diff options
Diffstat (limited to 'sqlglot')
42 files changed, 727 insertions, 275 deletions
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index e829517..04c3195 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -1,4 +1,6 @@ -"""## Python SQL parser, transpiler and optimizer.""" +""" +.. include:: ../README.md +""" from __future__ import annotations @@ -30,7 +32,7 @@ from sqlglot.parser import Parser from sqlglot.schema import MappingSchema from sqlglot.tokens import Tokenizer, TokenType -__version__ = "10.2.9" +__version__ = "10.4.2" pretty = False diff --git a/sqlglot/__main__.py b/sqlglot/__main__.py index 42a54bc..f9613b2 100644 --- a/sqlglot/__main__.py +++ b/sqlglot/__main__.py @@ -1,9 +1,15 @@ import argparse +import sys import sqlglot parser = argparse.ArgumentParser(description="Transpile SQL") -parser.add_argument("sql", metavar="sql", type=str, help="SQL string to transpile") +parser.add_argument( + "sql", + metavar="sql", + type=str, + help="SQL statement(s) to transpile, or - to parse stdin.", +) parser.add_argument( "--read", dest="read", @@ -48,14 +54,20 @@ parser.add_argument( args = parser.parse_args() error_level = sqlglot.ErrorLevel[args.error_level.upper()] +sql = sys.stdin.read() if args.sql == "-" else args.sql + if args.parse: sqls = [ repr(expression) - for expression in sqlglot.parse(args.sql, read=args.read, error_level=error_level) + for expression in sqlglot.parse( + sql, + read=args.read, + error_level=error_level, + ) ] else: sqls = sqlglot.transpile( - args.sql, + sql, read=args.read, write=args.write, identify=args.identify, diff --git a/sqlglot/dataframe/__init__.py b/sqlglot/dataframe/__init__.py index e69de29..a57e990 100644 --- a/sqlglot/dataframe/__init__.py +++ b/sqlglot/dataframe/__init__.py @@ -0,0 +1,3 @@ +""" +.. include:: ./README.md +""" diff --git a/sqlglot/dataframe/sql/_typing.pyi b/sqlglot/dataframe/sql/_typing.pyi index 67c8c09..1682ec1 100644 --- a/sqlglot/dataframe/sql/_typing.pyi +++ b/sqlglot/dataframe/sql/_typing.pyi @@ -9,18 +9,8 @@ if t.TYPE_CHECKING: from sqlglot.dataframe.sql.column import Column from sqlglot.dataframe.sql.types import StructType -ColumnLiterals = t.TypeVar( - "ColumnLiterals", - bound=t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime], -) -ColumnOrName = t.TypeVar("ColumnOrName", bound=t.Union[Column, str]) -ColumnOrLiteral = t.TypeVar( - "ColumnOrLiteral", - bound=t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime], -) -SchemaInput = t.TypeVar( - "SchemaInput", bound=t.Union[str, t.List[str], StructType, t.Dict[str, str]] -) -OutputExpressionContainer = t.TypeVar( - "OutputExpressionContainer", bound=t.Union[exp.Select, exp.Create, exp.Insert] -) +ColumnLiterals = t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime] +ColumnOrName = t.Union[Column, str] +ColumnOrLiteral = t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime] +SchemaInput = t.Union[str, t.List[str], StructType, t.Dict[str, t.Optional[str]]] +OutputExpressionContainer = t.Union[exp.Select, exp.Create, exp.Insert] diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py index 3c45741..a17bb9d 100644 --- a/sqlglot/dataframe/sql/dataframe.py +++ b/sqlglot/dataframe/sql/dataframe.py @@ -634,7 +634,7 @@ class DataFrame: all_columns = self._get_outer_select_columns(new_df.expression) all_column_mapping = {column.alias_or_name: column for column in all_columns} if isinstance(value, dict): - values = value.values() + values = list(value.values()) columns = self._ensure_and_normalize_cols(list(value)) if not columns: columns = self._ensure_and_normalize_cols(subset) if subset else all_columns diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 6be68ac..d10cc54 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -1,11 +1,15 @@ +"""Supports BigQuery Standard SQL.""" + from __future__ import annotations from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import ( Dialect, + datestrtodate_sql, inline_array_sql, no_ilike_sql, rename_func, + timestrtotime_sql, ) from sqlglot.helper import seq_get from sqlglot.tokens import TokenType @@ -120,13 +124,12 @@ class BigQuery(Dialect): "NOT DETERMINISTIC": TokenType.VOLATILE, "QUALIFY": TokenType.QUALIFY, "UNKNOWN": TokenType.NULL, - "WINDOW": TokenType.WINDOW, } KEYWORDS.pop("DIV") class Parser(parser.Parser): FUNCTIONS = { - **parser.Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, # type: ignore "DATE_TRUNC": _date_trunc, "DATE_ADD": _date_add(exp.DateAdd), "DATETIME_ADD": _date_add(exp.DatetimeAdd), @@ -144,31 +147,33 @@ class BigQuery(Dialect): } FUNCTION_PARSERS = { - **parser.Parser.FUNCTION_PARSERS, + **parser.Parser.FUNCTION_PARSERS, # type: ignore "ARRAY": lambda self: self.expression(exp.Array, expressions=[self._parse_statement()]), } FUNCTION_PARSERS.pop("TRIM") NO_PAREN_FUNCTIONS = { - **parser.Parser.NO_PAREN_FUNCTIONS, + **parser.Parser.NO_PAREN_FUNCTIONS, # type: ignore TokenType.CURRENT_DATETIME: exp.CurrentDatetime, TokenType.CURRENT_TIME: exp.CurrentTime, } NESTED_TYPE_TOKENS = { - *parser.Parser.NESTED_TYPE_TOKENS, + *parser.Parser.NESTED_TYPE_TOKENS, # type: ignore TokenType.TABLE, } class Generator(generator.Generator): TRANSFORMS = { - **generator.Generator.TRANSFORMS, + **generator.Generator.TRANSFORMS, # type: ignore exp.ArraySize: rename_func("ARRAY_LENGTH"), exp.DateAdd: _date_add_sql("DATE", "ADD"), exp.DateSub: _date_add_sql("DATE", "SUB"), exp.DatetimeAdd: _date_add_sql("DATETIME", "ADD"), exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"), exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})", + exp.DateStrToDate: datestrtodate_sql, + exp.GroupConcat: rename_func("STRING_AGG"), exp.ILike: no_ilike_sql, exp.IntDiv: rename_func("DIV"), exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})", @@ -176,6 +181,7 @@ class BigQuery(Dialect): exp.TimeSub: _date_add_sql("TIME", "SUB"), exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"), exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"), + exp.TimeStrToTime: timestrtotime_sql, exp.VariancePop: rename_func("VAR_POP"), exp.Values: _derived_table_values_to_unnest, exp.ReturnsProperty: _returnsproperty_sql, @@ -188,7 +194,7 @@ class BigQuery(Dialect): } TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, + **generator.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.TINYINT: "INT64", exp.DataType.Type.SMALLINT: "INT64", exp.DataType.Type.INT: "INT64", diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index cbed72e..7136340 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -35,13 +35,13 @@ class ClickHouse(Dialect): class Parser(parser.Parser): FUNCTIONS = { - **parser.Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, # type: ignore "MAP": parse_var_map, } - JOIN_KINDS = {*parser.Parser.JOIN_KINDS, TokenType.ANY, TokenType.ASOF} + JOIN_KINDS = {*parser.Parser.JOIN_KINDS, TokenType.ANY, TokenType.ASOF} # type: ignore - TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY} + TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY} # type: ignore def _parse_table(self, schema=False): this = super()._parse_table(schema) @@ -55,7 +55,7 @@ class ClickHouse(Dialect): STRUCT_DELIMITER = ("(", ")") TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, + **generator.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.NULLABLE: "Nullable", exp.DataType.Type.DATETIME: "DateTime64", exp.DataType.Type.MAP: "Map", @@ -70,7 +70,7 @@ class ClickHouse(Dialect): } TRANSFORMS = { - **generator.Generator.TRANSFORMS, + **generator.Generator.TRANSFORMS, # type: ignore exp.Array: inline_array_sql, exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})", exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL", diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index c87f8d8..e788852 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -198,7 +198,7 @@ class Dialect(metaclass=_Dialect): def rename_func(name): def _rename(self, expression): args = flatten(expression.args.values()) - return f"{name}({self.format_args(*args)})" + return f"{self.normalize_func(name)}({self.format_args(*args)})" return _rename @@ -217,11 +217,11 @@ def if_sql(self, expression): def arrow_json_extract_sql(self, expression): - return f"{self.sql(expression, 'this')}->{self.sql(expression, 'path')}" + return self.binary(expression, "->") def arrow_json_extract_scalar_sql(self, expression): - return f"{self.sql(expression, 'this')}->>{self.sql(expression, 'path')}" + return self.binary(expression, "->>") def inline_array_sql(self, expression): @@ -373,3 +373,11 @@ def strposition_to_local_sql(self, expression): expression.args.get("substr"), expression.this, expression.args.get("position") ) return f"LOCATE({args})" + + +def timestrtotime_sql(self, expression: exp.TimeStrToTime) -> str: + return f"CAST({self.sql(expression, 'this')} AS TIMESTAMP)" + + +def datestrtodate_sql(self, expression: exp.DateStrToDate) -> str: + return f"CAST({self.sql(expression, 'this')} AS DATE)" diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py index 358eced..4e3c0e1 100644 --- a/sqlglot/dialects/drill.py +++ b/sqlglot/dialects/drill.py @@ -6,13 +6,14 @@ from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import ( Dialect, create_with_partitions_sql, + datestrtodate_sql, format_time_lambda, no_pivot_sql, no_trycast_sql, rename_func, str_position_sql, + timestrtotime_sql, ) -from sqlglot.dialects.postgres import _lateral_sql def _to_timestamp(args): @@ -117,14 +118,14 @@ class Drill(Dialect): STRICT_CAST = False FUNCTIONS = { - **parser.Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, # type: ignore "TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list, "TO_CHAR": format_time_lambda(exp.TimeToStr, "drill"), } class Generator(generator.Generator): TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, + **generator.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.INT: "INTEGER", exp.DataType.Type.SMALLINT: "INTEGER", exp.DataType.Type.TINYINT: "INTEGER", @@ -139,14 +140,13 @@ class Drill(Dialect): ROOT_PROPERTIES = {exp.PartitionedByProperty} TRANSFORMS = { - **generator.Generator.TRANSFORMS, + **generator.Generator.TRANSFORMS, # type: ignore exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", - exp.Lateral: _lateral_sql, exp.ArrayContains: rename_func("REPEATED_CONTAINS"), exp.ArraySize: rename_func("REPEATED_COUNT"), exp.Create: create_with_partitions_sql, exp.DateAdd: _date_add_sql("ADD"), - exp.DateStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)", + exp.DateStrToDate: datestrtodate_sql, exp.DateSub: _date_add_sql("SUB"), exp.DateToDi: lambda self, e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.dateint_format}) AS INT)", exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.dateint_format})", @@ -160,7 +160,7 @@ class Drill(Dialect): exp.StrToDate: _str_to_date, exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)", - exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)", + exp.TimeStrToTime: timestrtotime_sql, exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"), exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"), diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index f1da72b..81941f7 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -6,6 +6,7 @@ from sqlglot.dialects.dialect import ( approx_count_distinct_sql, arrow_json_extract_scalar_sql, arrow_json_extract_sql, + datestrtodate_sql, format_time_lambda, no_pivot_sql, no_properties_sql, @@ -13,6 +14,7 @@ from sqlglot.dialects.dialect import ( no_tablesample_sql, rename_func, str_position_sql, + timestrtotime_sql, ) from sqlglot.helper import seq_get from sqlglot.tokens import TokenType @@ -83,11 +85,12 @@ class DuckDB(Dialect): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, ":=": TokenType.EQ, + "CHARACTER VARYING": TokenType.VARCHAR, } class Parser(parser.Parser): FUNCTIONS = { - **parser.Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, # type: ignore "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list, "ARRAY_LENGTH": exp.ArraySize.from_arg_list, "ARRAY_SORT": exp.SortArray.from_arg_list, @@ -119,16 +122,18 @@ class DuckDB(Dialect): STRUCT_DELIMITER = ("(", ")") TRANSFORMS = { - **generator.Generator.TRANSFORMS, + **generator.Generator.TRANSFORMS, # type: ignore exp.ApproxDistinct: approx_count_distinct_sql, - exp.Array: rename_func("LIST_VALUE"), + exp.Array: lambda self, e: f"{self.normalize_func('ARRAY')}({self.sql(e.expressions[0])})" + if isinstance(seq_get(e.expressions, 0), exp.Select) + else rename_func("LIST_VALUE")(self, e), exp.ArraySize: rename_func("ARRAY_LENGTH"), exp.ArraySort: _array_sort_sql, exp.ArraySum: rename_func("LIST_SUM"), exp.DataType: _datatype_sql, exp.DateAdd: _date_add, exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.format_args(e.args.get("unit") or "'day'", e.expression, e.this)})""", - exp.DateStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)", + exp.DateStrToDate: datestrtodate_sql, exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.dateint_format}) AS INT)", exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.dateint_format}) AS DATE)", exp.Explode: rename_func("UNNEST"), @@ -136,6 +141,7 @@ class DuckDB(Dialect): exp.JSONExtractScalar: arrow_json_extract_scalar_sql, exp.JSONBExtract: arrow_json_extract_sql, exp.JSONBExtractScalar: arrow_json_extract_scalar_sql, + exp.LogicalOr: rename_func("BOOL_OR"), exp.Pivot: no_pivot_sql, exp.Properties: no_properties_sql, exp.RegexpLike: rename_func("REGEXP_MATCHES"), @@ -150,7 +156,7 @@ class DuckDB(Dialect): exp.Struct: _struct_pack_sql, exp.TableSample: no_tablesample_sql, exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)", - exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)", + exp.TimeStrToTime: timestrtotime_sql, exp.TimeStrToUnix: lambda self, e: f"EPOCH(CAST({self.sql(e, 'this')} AS TIMESTAMP))", exp.TimeToStr: lambda self, e: f"STRFTIME({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimeToUnix: rename_func("EPOCH"), @@ -163,7 +169,7 @@ class DuckDB(Dialect): } TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, + **generator.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.VARCHAR: "TEXT", exp.DataType.Type.NVARCHAR: "TEXT", } diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 8d6e1ae..088555c 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -15,6 +15,7 @@ from sqlglot.dialects.dialect import ( rename_func, strposition_to_local_sql, struct_extract_sql, + timestrtotime_sql, var_map_sql, ) from sqlglot.helper import seq_get @@ -197,7 +198,7 @@ class Hive(Dialect): STRICT_CAST = False FUNCTIONS = { - **parser.Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, # type: ignore "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list, "COLLECT_LIST": exp.ArrayAgg.from_arg_list, "DATE_ADD": lambda args: exp.TsOrDsAdd( @@ -217,7 +218,12 @@ class Hive(Dialect): ), unit=exp.Literal.string("DAY"), ), - "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "hive"), + "DATE_FORMAT": lambda args: format_time_lambda(exp.TimeToStr, "hive")( + [ + exp.TimeStrToTime(this=seq_get(args, 0)), + seq_get(args, 1), + ] + ), "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))), "FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True), "GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list, @@ -240,7 +246,7 @@ class Hive(Dialect): } PROPERTY_PARSERS = { - **parser.Parser.PROPERTY_PARSERS, + **parser.Parser.PROPERTY_PARSERS, # type: ignore TokenType.SERDE_PROPERTIES: lambda self: exp.SerdeProperties( expressions=self._parse_wrapped_csv(self._parse_property) ), @@ -248,14 +254,14 @@ class Hive(Dialect): class Generator(generator.Generator): TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, + **generator.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.TEXT: "STRING", exp.DataType.Type.DATETIME: "TIMESTAMP", exp.DataType.Type.VARBINARY: "BINARY", } TRANSFORMS = { - **generator.Generator.TRANSFORMS, + **generator.Generator.TRANSFORMS, # type: ignore **transforms.UNALIAS_GROUP, # type: ignore exp.Property: _property_sql, exp.ApproxDistinct: approx_count_distinct_sql, @@ -294,7 +300,7 @@ class Hive(Dialect): exp.StructExtract: struct_extract_sql, exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'this')}", exp.TimeStrToDate: rename_func("TO_DATE"), - exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)", + exp.TimeStrToTime: timestrtotime_sql, exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"), exp.TimeToStr: _time_to_str, exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"), diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 7627b6e..0fd7992 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -161,8 +161,6 @@ class MySQL(Dialect): "_UCS2": TokenType.INTRODUCER, "_UJIS": TokenType.INTRODUCER, # https://dev.mysql.com/doc/refman/8.0/en/string-literals.html - "N": TokenType.INTRODUCER, - "n": TokenType.INTRODUCER, "_UTF8": TokenType.INTRODUCER, "_UTF16": TokenType.INTRODUCER, "_UTF16LE": TokenType.INTRODUCER, @@ -175,10 +173,10 @@ class MySQL(Dialect): COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SET, TokenType.SHOW} class Parser(parser.Parser): - FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA} + FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA} # type: ignore FUNCTIONS = { - **parser.Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, # type: ignore "DATE_ADD": _date_add(exp.DateAdd), "DATE_SUB": _date_add(exp.DateSub), "STR_TO_DATE": _str_to_date, @@ -190,7 +188,7 @@ class MySQL(Dialect): } FUNCTION_PARSERS = { - **parser.Parser.FUNCTION_PARSERS, + **parser.Parser.FUNCTION_PARSERS, # type: ignore "GROUP_CONCAT": lambda self: self.expression( exp.GroupConcat, this=self._parse_lambda(), @@ -199,12 +197,12 @@ class MySQL(Dialect): } PROPERTY_PARSERS = { - **parser.Parser.PROPERTY_PARSERS, + **parser.Parser.PROPERTY_PARSERS, # type: ignore TokenType.ENGINE: lambda self: self._parse_property_assignment(exp.EngineProperty), } STATEMENT_PARSERS = { - **parser.Parser.STATEMENT_PARSERS, + **parser.Parser.STATEMENT_PARSERS, # type: ignore TokenType.SHOW: lambda self: self._parse_show(), TokenType.SET: lambda self: self._parse_set(), } @@ -429,7 +427,7 @@ class MySQL(Dialect): NULL_ORDERING_SUPPORTED = False TRANSFORMS = { - **generator.Generator.TRANSFORMS, + **generator.Generator.TRANSFORMS, # type: ignore exp.CurrentDate: no_paren_current_date_sql, exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", exp.ILike: no_ilike_sql, diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index f507513..af3d353 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -39,13 +39,13 @@ class Oracle(Dialect): class Parser(parser.Parser): FUNCTIONS = { - **parser.Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, # type: ignore "DECODE": exp.Matches.from_arg_list, } class Generator(generator.Generator): TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, + **generator.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.TINYINT: "NUMBER", exp.DataType.Type.SMALLINT: "NUMBER", exp.DataType.Type.INT: "NUMBER", @@ -60,7 +60,7 @@ class Oracle(Dialect): } TRANSFORMS = { - **generator.Generator.TRANSFORMS, + **generator.Generator.TRANSFORMS, # type: ignore **transforms.UNALIAS_GROUP, # type: ignore exp.ILike: no_ilike_sql, exp.Limit: _limit_sql, diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index f276af1..a092cad 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -11,9 +11,19 @@ from sqlglot.dialects.dialect import ( no_trycast_sql, str_position_sql, ) +from sqlglot.helper import seq_get from sqlglot.tokens import TokenType from sqlglot.transforms import delegate, preprocess +DATE_DIFF_FACTOR = { + "MICROSECOND": " * 1000000", + "MILLISECOND": " * 1000", + "SECOND": "", + "MINUTE": " / 60", + "HOUR": " / 3600", + "DAY": " / 86400", +} + def _date_add_sql(kind): def func(self, expression): @@ -34,16 +44,30 @@ def _date_add_sql(kind): return func -def _lateral_sql(self, expression): - this = self.sql(expression, "this") - if isinstance(expression.this, exp.Subquery): - return f"LATERAL{self.sep()}{this}" - alias = expression.args["alias"] - table = alias.name - table = f" {table}" if table else table - columns = self.expressions(alias, key="columns", flat=True) - columns = f" AS {columns}" if columns else "" - return f"LATERAL{self.sep()}{this}{table}{columns}" +def _date_diff_sql(self, expression): + unit = expression.text("unit").upper() + factor = DATE_DIFF_FACTOR.get(unit) + + end = f"CAST({expression.this} AS TIMESTAMP)" + start = f"CAST({expression.expression} AS TIMESTAMP)" + + if factor is not None: + return f"CAST(EXTRACT(epoch FROM {end} - {start}){factor} AS BIGINT)" + + age = f"AGE({end}, {start})" + + if unit == "WEEK": + extract = f"EXTRACT(year FROM {age}) * 48 + EXTRACT(month FROM {age}) * 4 + EXTRACT(day FROM {age}) / 7" + elif unit == "MONTH": + extract = f"EXTRACT(year FROM {age}) * 12 + EXTRACT(month FROM {age})" + elif unit == "QUARTER": + extract = f"EXTRACT(year FROM {age}) * 4 + EXTRACT(month FROM {age}) / 3" + elif unit == "YEAR": + extract = f"EXTRACT(year FROM {age})" + else: + self.unsupported(f"Unsupported DATEDIFF unit {unit}") + + return f"CAST({extract} AS BIGINT)" def _substring_sql(self, expression): @@ -141,7 +165,7 @@ def _serial_to_generated(expression): def _to_timestamp(args): # TO_TIMESTAMP accepts either a single double argument or (text, text) - if len(args) == 1 and args[0].is_number: + if len(args) == 1: # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TABLE return exp.UnixToTime.from_arg_list(args) # https://www.postgresql.org/docs/current/functions-formatting.html @@ -211,11 +235,16 @@ class Postgres(Dialect): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, + "~~": TokenType.LIKE, + "~~*": TokenType.ILIKE, + "~*": TokenType.IRLIKE, + "~": TokenType.RLIKE, "ALWAYS": TokenType.ALWAYS, "BEGIN": TokenType.COMMAND, "BEGIN TRANSACTION": TokenType.BEGIN, "BIGSERIAL": TokenType.BIGSERIAL, "BY DEFAULT": TokenType.BY_DEFAULT, + "CHARACTER VARYING": TokenType.VARCHAR, "COMMENT ON": TokenType.COMMAND, "DECLARE": TokenType.COMMAND, "DO": TokenType.COMMAND, @@ -233,6 +262,7 @@ class Postgres(Dialect): "SMALLSERIAL": TokenType.SMALLSERIAL, "TEMP": TokenType.TEMPORARY, "UUID": TokenType.UUID, + "CSTRING": TokenType.PSEUDO_TYPE, **{f"CREATE {kind}": TokenType.COMMAND for kind in CREATABLES}, **{f"DROP {kind}": TokenType.COMMAND for kind in CREATABLES}, } @@ -244,17 +274,16 @@ class Postgres(Dialect): class Parser(parser.Parser): STRICT_CAST = False - LATERAL_FUNCTION_AS_VIEW = True FUNCTIONS = { - **parser.Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, # type: ignore "TO_TIMESTAMP": _to_timestamp, "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"), } class Generator(generator.Generator): TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, + **generator.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.TINYINT: "SMALLINT", exp.DataType.Type.FLOAT: "REAL", exp.DataType.Type.DOUBLE: "DOUBLE PRECISION", @@ -264,7 +293,7 @@ class Postgres(Dialect): } TRANSFORMS = { - **generator.Generator.TRANSFORMS, + **generator.Generator.TRANSFORMS, # type: ignore exp.ColumnDef: preprocess( [ _auto_increment_to_serial, @@ -274,13 +303,16 @@ class Postgres(Dialect): ), exp.JSONExtract: arrow_json_extract_sql, exp.JSONExtractScalar: arrow_json_extract_scalar_sql, - exp.JSONBExtract: lambda self, e: f"{self.sql(e, 'this')}#>{self.sql(e, 'path')}", - exp.JSONBExtractScalar: lambda self, e: f"{self.sql(e, 'this')}#>>{self.sql(e, 'path')}", + exp.JSONBExtract: lambda self, e: self.binary(e, "#>"), + exp.JSONBExtractScalar: lambda self, e: self.binary(e, "#>>"), + exp.JSONBContains: lambda self, e: self.binary(e, "?"), exp.CurrentDate: no_paren_current_date_sql, exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", exp.DateAdd: _date_add_sql("+"), exp.DateSub: _date_add_sql("-"), - exp.Lateral: _lateral_sql, + exp.DateDiff: _date_diff_sql, + exp.RegexpLike: lambda self, e: self.binary(e, "~"), + exp.RegexpILike: lambda self, e: self.binary(e, "~*"), exp.StrPosition: str_position_sql, exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.Substring: _substring_sql, @@ -291,5 +323,7 @@ class Postgres(Dialect): exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})", exp.DataType: _datatype_sql, exp.GroupConcat: _string_agg_sql, - exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]", + exp.Array: lambda self, e: f"{self.normalize_func('ARRAY')}({self.sql(e.expressions[0])})" + if isinstance(seq_get(e.expressions, 0), exp.Select) + else f"{self.normalize_func('ARRAY')}[{self.expressions(e, flat=True)}]", } diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 1a09037..e16ea1d 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import ( rename_func, str_position_sql, struct_extract_sql, + timestrtotime_sql, ) from sqlglot.dialects.mysql import MySQL from sqlglot.errors import UnsupportedError @@ -38,10 +39,6 @@ def _datatype_sql(self, expression): return sql -def _date_parse_sql(self, expression): - return f"DATE_PARSE({self.sql(expression, 'this')}, '%Y-%m-%d %H:%i:%s')" - - def _explode_to_unnest_sql(self, expression): if isinstance(expression.this, (exp.Explode, exp.Posexplode)): return self.sql( @@ -137,7 +134,7 @@ class Presto(Dialect): class Parser(parser.Parser): FUNCTIONS = { - **parser.Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, # type: ignore "APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list, "CARDINALITY": exp.ArraySize.from_arg_list, "CONTAINS": exp.ArrayContains.from_arg_list, @@ -174,7 +171,7 @@ class Presto(Dialect): ROOT_PROPERTIES = {exp.SchemaCommentProperty} TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, + **generator.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.INT: "INTEGER", exp.DataType.Type.FLOAT: "REAL", exp.DataType.Type.BINARY: "VARBINARY", @@ -184,7 +181,7 @@ class Presto(Dialect): } TRANSFORMS = { - **generator.Generator.TRANSFORMS, + **generator.Generator.TRANSFORMS, # type: ignore **transforms.UNALIAS_GROUP, # type: ignore exp.ApproxDistinct: _approx_distinct_sql, exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]", @@ -224,8 +221,8 @@ class Presto(Dialect): exp.StructExtract: struct_extract_sql, exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT='{e.name.upper()}'", exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'", - exp.TimeStrToDate: _date_parse_sql, - exp.TimeStrToTime: _date_parse_sql, + exp.TimeStrToDate: timestrtotime_sql, + exp.TimeStrToTime: timestrtotime_sql, exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.time_format}))", exp.TimeToStr: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimeToUnix: rename_func("TO_UNIXTIME"), diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index 55ed0a6..27dfb93 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -36,7 +36,6 @@ class Redshift(Postgres): "TIMETZ": TokenType.TIMESTAMPTZ, "UNLOAD": TokenType.COMMAND, "VARBYTE": TokenType.VARBINARY, - "SIMILAR TO": TokenType.SIMILAR_TO, } class Generator(Postgres.Generator): diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 75dc9dc..77b09e9 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -3,13 +3,15 @@ from __future__ import annotations from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import ( Dialect, + datestrtodate_sql, format_time_lambda, inline_array_sql, rename_func, + timestrtotime_sql, var_map_sql, ) from sqlglot.expressions import Literal -from sqlglot.helper import seq_get +from sqlglot.helper import flatten, seq_get from sqlglot.tokens import TokenType @@ -183,7 +185,7 @@ class Snowflake(Dialect): class Tokenizer(tokens.Tokenizer): QUOTES = ["'", "$$"] - ESCAPES = ["\\"] + ESCAPES = ["\\", "'"] SINGLE_TOKENS = { **tokens.Tokenizer.SINGLE_TOKENS, @@ -206,9 +208,10 @@ class Snowflake(Dialect): CREATE_TRANSIENT = True TRANSFORMS = { - **generator.Generator.TRANSFORMS, + **generator.Generator.TRANSFORMS, # type: ignore exp.Array: inline_array_sql, exp.ArrayConcat: rename_func("ARRAY_CAT"), + exp.DateStrToDate: datestrtodate_sql, exp.DataType: _datatype_sql, exp.If: rename_func("IFF"), exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), @@ -218,13 +221,14 @@ class Snowflake(Dialect): exp.Matches: rename_func("DECODE"), exp.StrPosition: rename_func("POSITION"), exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.TimeStrToTime: timestrtotime_sql, exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})", exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})", exp.UnixToTime: _unix_to_time_sql, } TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, + **generator.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ", } @@ -246,3 +250,47 @@ class Snowflake(Dialect): if not expression.args.get("distinct", False): self.unsupported("INTERSECT with All is not supported in Snowflake") return super().intersect_op(expression) + + def values_sql(self, expression: exp.Values) -> str: + """Due to a bug in Snowflake we want to make sure that all columns in a VALUES table alias are unquoted. + + We also want to make sure that after we find matches where we need to unquote a column that we prevent users + from adding quotes to the column by using the `identify` argument when generating the SQL. + """ + alias = expression.args.get("alias") + if alias and alias.args.get("columns"): + expression = expression.transform( + lambda node: exp.Identifier(**{**node.args, "quoted": False}) + if isinstance(node, exp.Identifier) + and isinstance(node.parent, exp.TableAlias) + and node.arg_key == "columns" + else node, + ) + return self.no_identify(lambda: super(self.__class__, self).values_sql(expression)) + return super().values_sql(expression) + + def select_sql(self, expression: exp.Select) -> str: + """Due to a bug in Snowflake we want to make sure that all columns in a VALUES table alias are unquoted and also + that all columns in a SELECT are unquoted. We also want to make sure that after we find matches where we need + to unquote a column that we prevent users from adding quotes to the column by using the `identify` argument when + generating the SQL. + + Note: We make an assumption that any columns referenced in a VALUES expression should be unquoted throughout the + expression. This might not be true in a case where the same column name can be sourced from another table that can + properly quote but should be true in most cases. + """ + values_expressions = expression.find_all(exp.Values) + values_identifiers = set( + flatten( + v.args.get("alias", exp.Alias()).args.get("columns", []) + for v in values_expressions + ) + ) + if values_identifiers: + expression = expression.transform( + lambda node: exp.Identifier(**{**node.args, "quoted": False}) + if isinstance(node, exp.Identifier) and node in values_identifiers + else node, + ) + return self.no_identify(lambda: super(self.__class__, self).select_sql(expression)) + return super().select_sql(expression) diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index 16083d1..7f05dea 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -76,7 +76,7 @@ class Spark(Hive): } FUNCTION_PARSERS = { - **parser.Parser.FUNCTION_PARSERS, + **parser.Parser.FUNCTION_PARSERS, # type: ignore "BROADCAST": lambda self: self._parse_join_hint("BROADCAST"), "BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"), "MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"), @@ -87,6 +87,16 @@ class Spark(Hive): "SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"), } + def _parse_add_column(self): + return self._match_text_seq("ADD", "COLUMNS") and self._parse_schema() + + def _parse_drop_column(self): + return self._match_text_seq("DROP", "COLUMNS") and self.expression( + exp.Drop, + this=self._parse_schema(), + kind="COLUMNS", + ) + class Generator(Hive.Generator): TYPE_MAPPING = { **Hive.Generator.TYPE_MAPPING, # type: ignore diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index bbb752b..a0c4942 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -42,13 +42,13 @@ class SQLite(Dialect): class Parser(parser.Parser): FUNCTIONS = { - **parser.Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, # type: ignore "EDITDIST3": exp.Levenshtein.from_arg_list, } class Generator(generator.Generator): TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, + **generator.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.BOOLEAN: "INTEGER", exp.DataType.Type.TINYINT: "INTEGER", exp.DataType.Type.SMALLINT: "INTEGER", @@ -70,7 +70,7 @@ class SQLite(Dialect): } TRANSFORMS = { - **generator.Generator.TRANSFORMS, + **generator.Generator.TRANSFORMS, # type: ignore exp.ILike: no_ilike_sql, exp.JSONExtract: arrow_json_extract_sql, exp.JSONExtractScalar: arrow_json_extract_scalar_sql, diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py index 3519c09..01e6357 100644 --- a/sqlglot/dialects/starrocks.py +++ b/sqlglot/dialects/starrocks.py @@ -8,7 +8,7 @@ from sqlglot.dialects.mysql import MySQL class StarRocks(MySQL): class Generator(MySQL.Generator): # type: ignore TYPE_MAPPING = { - **MySQL.Generator.TYPE_MAPPING, + **MySQL.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.TEXT: "STRING", exp.DataType.Type.TIMESTAMP: "DATETIME", exp.DataType.Type.TIMESTAMPTZ: "DATETIME", diff --git a/sqlglot/dialects/tableau.py b/sqlglot/dialects/tableau.py index 63e7275..36c085f 100644 --- a/sqlglot/dialects/tableau.py +++ b/sqlglot/dialects/tableau.py @@ -30,7 +30,7 @@ class Tableau(Dialect): class Parser(parser.Parser): FUNCTIONS = { - **parser.Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, # type: ignore "IFNULL": exp.Coalesce.from_arg_list, "COUNTD": lambda args: exp.Count(this=exp.Distinct(expressions=args)), } diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index a552e7b..7f0f2d7 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -224,11 +224,7 @@ class TSQL(Dialect): class Tokenizer(tokens.Tokenizer): IDENTIFIERS = ['"', ("[", "]")] - QUOTES = [ - (prefix + quote, quote) if prefix else quote - for quote in ["'", '"'] - for prefix in ["", "n", "N"] - ] + QUOTES = ["'", '"'] KEYWORDS = { **tokens.Tokenizer.KEYWORDS, @@ -253,7 +249,7 @@ class TSQL(Dialect): class Parser(parser.Parser): FUNCTIONS = { - **parser.Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, # type: ignore "CHARINDEX": exp.StrPosition.from_arg_list, "ISNULL": exp.Coalesce.from_arg_list, "DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL), @@ -314,7 +310,7 @@ class TSQL(Dialect): class Generator(generator.Generator): TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, + **generator.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.BOOLEAN: "BIT", exp.DataType.Type.INT: "INTEGER", exp.DataType.Type.DECIMAL: "NUMERIC", diff --git a/sqlglot/diff.py b/sqlglot/diff.py index 758ad1b..fa8bc1b 100644 --- a/sqlglot/diff.py +++ b/sqlglot/diff.py @@ -1,3 +1,7 @@ +""" +.. include:: ../posts/sql_diff.md +""" + from __future__ import annotations import typing as t diff --git a/sqlglot/executor/context.py b/sqlglot/executor/context.py index e9ff75b..8a58287 100644 --- a/sqlglot/executor/context.py +++ b/sqlglot/executor/context.py @@ -29,10 +29,10 @@ class Context: self._table: t.Optional[Table] = None self.range_readers = {name: table.range_reader for name, table in self.tables.items()} self.row_readers = {name: table.reader for name, table in tables.items()} - self.env = {**(env or {}), "scope": self.row_readers} + self.env = {**ENV, **(env or {}), "scope": self.row_readers} def eval(self, code): - return eval(code, ENV, self.env) + return eval(code, self.env) def eval_tuple(self, codes): return tuple(self.eval(code) for code in codes) diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py index ad9397e..04dc938 100644 --- a/sqlglot/executor/env.py +++ b/sqlglot/executor/env.py @@ -127,14 +127,16 @@ def interval(this, unit): ENV = { "exp": exp, # aggs - "SUM": filter_nulls(sum), + "ARRAYAGG": list, "AVG": filter_nulls(statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean), # type: ignore "COUNT": filter_nulls(lambda acc: sum(1 for _ in acc), False), "MAX": filter_nulls(max), "MIN": filter_nulls(min), + "SUM": filter_nulls(sum), # scalar functions "ABS": null_if_any(lambda this: abs(this)), "ADD": null_if_any(lambda e, this: e + this), + "ARRAYANY": null_if_any(lambda arr, func: any(func(e) for e in arr)), "BETWEEN": null_if_any(lambda this, low, high: low <= this and this <= high), "BITWISEAND": null_if_any(lambda this, e: this & e), "BITWISELEFTSHIFT": null_if_any(lambda this, e: this << e), diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py index 9f22c45..29848c6 100644 --- a/sqlglot/executor/python.py +++ b/sqlglot/executor/python.py @@ -394,6 +394,18 @@ def _case_sql(self, expression): return chain +def _lambda_sql(self, e: exp.Lambda) -> str: + names = {e.name.lower() for e in e.expressions} + + e = e.transform( + lambda n: exp.Var(this=n.name) + if isinstance(n, exp.Identifier) and n.name.lower() in names + else n + ) + + return f"lambda {self.expressions(e, flat=True)}: {self.sql(e, 'this')}" + + class Python(Dialect): class Tokenizer(tokens.Tokenizer): ESCAPES = ["\\"] @@ -414,6 +426,7 @@ class Python(Dialect): exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})", exp.In: lambda self, e: f"{self.sql(e, 'this')} in ({self.expressions(e, flat=True)})", exp.Is: lambda self, e: self.binary(e, "is"), + exp.Lambda: _lambda_sql, exp.Not: lambda self, e: f"not {self.sql(e.this)}", exp.Null: lambda *_: "None", exp.Or: lambda self, e: self.binary(e, "or"), diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index aeed218..711ec4b 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -1,6 +1,11 @@ +""" +.. include:: ../pdoc/docs/expressions.md +""" + from __future__ import annotations import datetime +import math import numbers import re import typing as t @@ -682,6 +687,10 @@ class CharacterSet(Expression): class With(Expression): arg_types = {"expressions": True, "recursive": False} + @property + def recursive(self) -> bool: + return bool(self.args.get("recursive")) + class WithinGroup(Expression): arg_types = {"this": True, "expression": False} @@ -724,6 +733,18 @@ class ColumnDef(Expression): "this": True, "kind": True, "constraints": False, + "exists": False, + } + + +class AlterColumn(Expression): + arg_types = { + "this": True, + "dtype": False, + "collate": False, + "using": False, + "default": False, + "drop": False, } @@ -877,6 +898,11 @@ class Introducer(Expression): arg_types = {"this": True, "expression": True} +# national char, like n'utf8' +class National(Expression): + pass + + class LoadData(Expression): arg_types = { "this": True, @@ -894,7 +920,7 @@ class Partition(Expression): class Fetch(Expression): - arg_types = {"direction": False, "count": True} + arg_types = {"direction": False, "count": False} class Group(Expression): @@ -1316,7 +1342,7 @@ QUERY_MODIFIERS = { "group": False, "having": False, "qualify": False, - "window": False, + "windows": False, "distribute": False, "sort": False, "cluster": False, @@ -1353,7 +1379,7 @@ class Union(Subqueryable): Example: >>> select("1").union(select("1")).limit(1).sql() - 'SELECT * FROM (SELECT 1 UNION SELECT 1) AS "_l_0" LIMIT 1' + 'SELECT * FROM (SELECT 1 UNION SELECT 1) AS _l_0 LIMIT 1' Args: expression (str | int | Expression): the SQL code string to parse. @@ -1889,6 +1915,18 @@ class Select(Subqueryable): **opts, ) + def window(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: + return _apply_list_builder( + *expressions, + instance=self, + arg="windows", + append=append, + into=Window, + dialect=dialect, + copy=copy, + **opts, + ) + def distinct(self, distinct=True, copy=True) -> Select: """ Set the OFFSET expression. @@ -2140,6 +2178,11 @@ class DataType(Expression): ) +# https://www.postgresql.org/docs/15/datatype-pseudo.html +class PseudoType(Expression): + pass + + class StructKwarg(Expression): arg_types = {"this": True, "expression": True} @@ -2167,18 +2210,26 @@ class Command(Expression): arg_types = {"this": True, "expression": False} -class Transaction(Command): +class Transaction(Expression): arg_types = {"this": False, "modes": False} -class Commit(Command): +class Commit(Expression): arg_types = {"chain": False} -class Rollback(Command): +class Rollback(Expression): arg_types = {"savepoint": False} +class AlterTable(Expression): + arg_types = { + "this": True, + "actions": True, + "exists": False, + } + + # Binary expressions like (ADD a b) class Binary(Expression): arg_types = {"this": True, "expression": True} @@ -2312,6 +2363,10 @@ class SimilarTo(Binary, Predicate): pass +class Slice(Binary): + arg_types = {"this": False, "expression": False} + + class Sub(Binary): pass @@ -2392,7 +2447,7 @@ class TimeUnit(Expression): class Interval(TimeUnit): - arg_types = {"this": True, "unit": False} + arg_types = {"this": False, "unit": False} class IgnoreNulls(Expression): @@ -2730,8 +2785,11 @@ class Initcap(Func): pass -class JSONExtract(Func): - arg_types = {"this": True, "path": True} +class JSONBContains(Binary): + _sql_names = ["JSONB_CONTAINS"] + + +class JSONExtract(Binary, Func): _sql_names = ["JSON_EXTRACT"] @@ -2776,6 +2834,10 @@ class Log10(Func): pass +class LogicalOr(AggFunc): + _sql_names = ["LOGICAL_OR", "BOOL_OR"] + + class Lower(Func): _sql_names = ["LOWER", "LCASE"] @@ -2846,6 +2908,10 @@ class RegexpLike(Func): arg_types = {"this": True, "expression": True, "flag": False} +class RegexpILike(Func): + arg_types = {"this": True, "expression": True, "flag": False} + + class RegexpSplit(Func): arg_types = {"this": True, "expression": True} @@ -3388,11 +3454,17 @@ def update(table, properties, where=None, from_=None, dialect=None, **opts) -> U ], ) if from_: - update.set("from", maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts)) + update.set( + "from", + maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts), + ) if isinstance(where, Condition): where = Where(this=where) if where: - update.set("where", maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts)) + update.set( + "where", + maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts), + ) return update @@ -3522,7 +3594,7 @@ def paren(expression) -> Paren: return Paren(this=expression) -SAFE_IDENTIFIER_RE = re.compile(r"^[a-zA-Z][\w]*$") +SAFE_IDENTIFIER_RE = re.compile(r"^[_a-zA-Z][\w]*$") def to_identifier(alias, quoted=None) -> t.Optional[Identifier]: @@ -3724,6 +3796,8 @@ def convert(value) -> Expression: return Boolean(this=value) if isinstance(value, str): return Literal.string(value) + if isinstance(value, float) and math.isnan(value): + return NULL if isinstance(value, numbers.Number): return Literal.number(value) if isinstance(value, tuple): @@ -3732,11 +3806,13 @@ def convert(value) -> Expression: return Array(expressions=[convert(v) for v in value]) if isinstance(value, dict): return Map( - keys=[convert(k) for k in value.keys()], + keys=[convert(k) for k in value], values=[convert(v) for v in value.values()], ) if isinstance(value, datetime.datetime): - datetime_literal = Literal.string(value.strftime("%Y-%m-%d %H:%M:%S.%f%z")) + datetime_literal = Literal.string( + (value if value.tzinfo else value.replace(tzinfo=datetime.timezone.utc)).isoformat() + ) return TimeStrToTime(this=datetime_literal) if isinstance(value, datetime.date): date_literal = Literal.string(value.strftime("%Y-%m-%d")) diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 2b4c575..0c1578a 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -361,10 +361,11 @@ class Generator: column = self.sql(expression, "this") kind = self.sql(expression, "kind") constraints = self.expressions(expression, key="constraints", sep=" ", flat=True) + exists = "IF NOT EXISTS " if expression.args.get("exists") else "" if not constraints: - return f"{column} {kind}" - return f"{column} {kind} {constraints}" + return f"{exists}{column} {kind}" + return f"{exists}{column} {kind} {constraints}" def columnconstraint_sql(self, expression: exp.ColumnConstraint) -> str: this = self.sql(expression, "this") @@ -549,6 +550,9 @@ class Generator: text = f"{self.identifier_start}{text}{self.identifier_end}" return text + def national_sql(self, expression: exp.National) -> str: + return f"N{self.sql(expression, 'this')}" + def partition_sql(self, expression: exp.Partition) -> str: keys = csv( *[ @@ -633,6 +637,9 @@ class Generator: def introducer_sql(self, expression: exp.Introducer) -> str: return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}" + def pseudotype_sql(self, expression: exp.PseudoType) -> str: + return expression.name.upper() + def rowformatdelimitedproperty_sql(self, expression: exp.RowFormatDelimitedProperty) -> str: fields = expression.args.get("fields") fields = f" FIELDS TERMINATED BY {fields}" if fields else "" @@ -793,19 +800,17 @@ class Generator: if isinstance(expression.this, exp.Subquery): return f"LATERAL {this}" - alias = expression.args["alias"] - table = alias.name - columns = self.expressions(alias, key="columns", flat=True) - if expression.args.get("view"): - table = f" {table}" if table else table + alias = expression.args["alias"] + columns = self.expressions(alias, key="columns", flat=True) + table = f" {alias.name}" if alias.name else "" columns = f" AS {columns}" if columns else "" op_sql = self.seg(f"LATERAL VIEW{' OUTER' if expression.args.get('outer') else ''}") return f"{op_sql}{self.sep()}{this}{table}{columns}" - table = f" AS {table}" if table else table - columns = f"({columns})" if columns else "" - return f"LATERAL {this}{table}{columns}" + alias = self.sql(expression, "alias") + alias = f" AS {alias}" if alias else "" + return f"LATERAL {this}{alias}" def limit_sql(self, expression: exp.Limit) -> str: this = self.sql(expression, "this") @@ -891,13 +896,15 @@ class Generator: def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str: return csv( *sqls, - *[self.sql(sql) for sql in expression.args.get("joins", [])], - *[self.sql(sql) for sql in expression.args.get("laterals", [])], + *[self.sql(sql) for sql in expression.args.get("joins") or []], + *[self.sql(sql) for sql in expression.args.get("laterals") or []], self.sql(expression, "where"), self.sql(expression, "group"), self.sql(expression, "having"), self.sql(expression, "qualify"), - self.sql(expression, "window"), + self.seg("WINDOW ") + self.expressions(expression, "windows", flat=True) + if expression.args.get("windows") + else "", self.sql(expression, "distribute"), self.sql(expression, "sort"), self.sql(expression, "cluster"), @@ -1008,11 +1015,7 @@ class Generator: spec_sql = " " + self.window_spec_sql(spec) if spec else "" alias = self.sql(expression, "alias") - - if expression.arg_key == "window": - this = this = f"{self.seg('WINDOW')} {this} AS" - else: - this = f"{this} OVER" + this = f"{this} {'AS' if expression.arg_key == 'windows' else 'OVER'}" if not partition and not order and not spec and alias: return f"{this} {alias}" @@ -1141,9 +1144,11 @@ class Generator: return f"(SELECT {self.sql(unnest)})" def interval_sql(self, expression: exp.Interval) -> str: + this = self.sql(expression, "this") + this = f" {this}" if this else "" unit = self.sql(expression, "unit") unit = f" {unit}" if unit else "" - return f"INTERVAL {self.sql(expression, 'this')}{unit}" + return f"INTERVAL{this}{unit}" def reference_sql(self, expression: exp.Reference) -> str: this = self.sql(expression, "this") @@ -1245,6 +1250,43 @@ class Generator: savepoint = f" TO {savepoint}" if savepoint else "" return f"ROLLBACK{savepoint}" + def altercolumn_sql(self, expression: exp.AlterColumn) -> str: + this = self.sql(expression, "this") + + dtype = self.sql(expression, "dtype") + if dtype: + collate = self.sql(expression, "collate") + collate = f" COLLATE {collate}" if collate else "" + using = self.sql(expression, "using") + using = f" USING {using}" if using else "" + return f"ALTER COLUMN {this} TYPE {dtype}{collate}{using}" + + default = self.sql(expression, "default") + if default: + return f"ALTER COLUMN {this} SET DEFAULT {default}" + + if not expression.args.get("drop"): + self.unsupported("Unsupported ALTER COLUMN syntax") + + return f"ALTER COLUMN {this} DROP DEFAULT" + + def altertable_sql(self, expression: exp.AlterTable) -> str: + actions = expression.args["actions"] + + if isinstance(actions[0], exp.ColumnDef): + actions = self.expressions(expression, "actions", prefix="ADD COLUMN ") + elif isinstance(actions[0], exp.Schema): + actions = self.expressions(expression, "actions", prefix="ADD COLUMNS ") + elif isinstance(actions[0], exp.Drop): + actions = self.expressions(expression, "actions") + elif isinstance(actions[0], exp.AlterColumn): + actions = self.sql(actions[0]) + else: + self.unsupported(f"Unsupported ALTER TABLE action {actions[0].__class__.__name__}") + + exists = " IF EXISTS" if expression.args.get("exists") else "" + return f"ALTER TABLE{exists} {self.sql(expression, 'this')} {actions}" + def distinct_sql(self, expression: exp.Distinct) -> str: this = self.expressions(expression, flat=True) this = f" {this}" if this else "" @@ -1327,6 +1369,9 @@ class Generator: def or_sql(self, expression: exp.Or) -> str: return self.connector_sql(expression, "OR") + def slice_sql(self, expression: exp.Slice) -> str: + return self.binary(expression, ":") + def sub_sql(self, expression: exp.Sub) -> str: return self.binary(expression, "-") @@ -1369,6 +1414,7 @@ class Generator: flat: bool = False, indent: bool = True, sep: str = ", ", + prefix: str = "", ) -> str: expressions = expression.args.get(key or "expressions") @@ -1391,11 +1437,13 @@ class Generator: if self.pretty: if self._leading_comma: - result_sqls.append(f"{sep if i > 0 else pad}{sql}{comments}") + result_sqls.append(f"{sep if i > 0 else pad}{prefix}{sql}{comments}") else: - result_sqls.append(f"{sql}{stripped_sep if i + 1 < num_sqls else ''}{comments}") + result_sqls.append( + f"{prefix}{sql}{stripped_sep if i + 1 < num_sqls else ''}{comments}" + ) else: - result_sqls.append(f"{sql}{comments}{sep if i + 1 < num_sqls else ''}") + result_sqls.append(f"{prefix}{sql}{comments}{sep if i + 1 < num_sqls else ''}") result_sql = "\n".join(result_sqls) if self.pretty else "".join(result_sqls) return self.indent(result_sql, skip_first=False) if indent else result_sql diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py index 33529a5..fc37a54 100644 --- a/sqlglot/optimizer/canonicalize.py +++ b/sqlglot/optimizer/canonicalize.py @@ -18,6 +18,9 @@ def canonicalize(expression: exp.Expression) -> exp.Expression: expression = coerce_type(expression) expression = remove_redundant_casts(expression) + if isinstance(expression, exp.Identifier): + expression.set("quoted", True) + return expression diff --git a/sqlglot/optimizer/eliminate_joins.py b/sqlglot/optimizer/eliminate_joins.py index de4e011..3b40710 100644 --- a/sqlglot/optimizer/eliminate_joins.py +++ b/sqlglot/optimizer/eliminate_joins.py @@ -129,10 +129,23 @@ def join_condition(join): """ name = join.this.alias_or_name on = (join.args.get("on") or exp.true()).copy() - on = on if isinstance(on, exp.And) else exp.and_(on, exp.true()) source_key = [] join_key = [] + def extract_condition(condition): + left, right = condition.unnest_operands() + left_tables = exp.column_table_names(left) + right_tables = exp.column_table_names(right) + + if name in left_tables and name not in right_tables: + join_key.append(left) + source_key.append(right) + condition.replace(exp.true()) + elif name in right_tables and name not in left_tables: + join_key.append(right) + source_key.append(left) + condition.replace(exp.true()) + # find the join keys # SELECT # FROM x @@ -141,20 +154,30 @@ def join_condition(join): # # should pull y.b as the join key and x.a as the source key if normalized(on): + on = on if isinstance(on, exp.And) else exp.and_(on, exp.true()) + for condition in on.flatten(): if isinstance(condition, exp.EQ): - left, right = condition.unnest_operands() - left_tables = exp.column_table_names(left) - right_tables = exp.column_table_names(right) - - if name in left_tables and name not in right_tables: - join_key.append(left) - source_key.append(right) - condition.replace(exp.true()) - elif name in right_tables and name not in left_tables: - join_key.append(right) - source_key.append(left) - condition.replace(exp.true()) + extract_condition(condition) + elif normalized(on, dnf=True): + conditions = None + + for condition in on.flatten(): + parts = [part for part in condition.flatten() if isinstance(part, exp.EQ)] + if conditions is None: + conditions = parts + else: + temp = [] + for p in parts: + cs = [c for c in conditions if p == c] + + if cs: + temp.append(p) + temp.extend(cs) + conditions = temp + + for condition in conditions: + extract_condition(condition) on = simplify(on) remaining_condition = None if on == exp.true() else on diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py index 39e252c..2245cc2 100644 --- a/sqlglot/optimizer/eliminate_subqueries.py +++ b/sqlglot/optimizer/eliminate_subqueries.py @@ -58,7 +58,9 @@ def eliminate_subqueries(expression): existing_ctes = {} with_ = root.expression.args.get("with") + recursive = False if with_: + recursive = with_.args.get("recursive") for cte in with_.expressions: existing_ctes[cte.this] = cte.alias new_ctes = [] @@ -88,7 +90,7 @@ def eliminate_subqueries(expression): new_ctes.append(new_cte) if new_ctes: - expression.set("with", exp.With(expressions=new_ctes)) + expression.set("with", exp.With(expressions=new_ctes, recursive=recursive)) return expression diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py index db538ef..f16f519 100644 --- a/sqlglot/optimizer/normalize.py +++ b/sqlglot/optimizer/normalize.py @@ -69,8 +69,9 @@ def _predicate_lengths(expression, dnf): left, right = expression.args.values() if isinstance(expression, exp.And if dnf else exp.Or): - x = [a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf)] - return x + return [ + a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf) + ] return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf) diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index 6819717..72e67d4 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -14,7 +14,6 @@ from sqlglot.optimizer.pushdown_predicates import pushdown_predicates from sqlglot.optimizer.pushdown_projections import pushdown_projections from sqlglot.optimizer.qualify_columns import qualify_columns from sqlglot.optimizer.qualify_tables import qualify_tables -from sqlglot.optimizer.quote_identities import quote_identities from sqlglot.optimizer.unnest_subqueries import unnest_subqueries RULES = ( @@ -34,7 +33,6 @@ RULES = ( eliminate_ctes, annotate_types, canonicalize, - quote_identities, ) diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py index f92e5c3..ba5c8b5 100644 --- a/sqlglot/optimizer/pushdown_predicates.py +++ b/sqlglot/optimizer/pushdown_predicates.py @@ -27,7 +27,14 @@ def pushdown_predicates(expression): select = scope.expression where = select.args.get("where") if where: - pushdown(where.this, scope.selected_sources, scope_ref_count) + selected_sources = scope.selected_sources + # a right join can only push down to itself and not the source FROM table + for k, (node, source) in selected_sources.items(): + parent = node.find_ancestor(exp.Join, exp.From) + if isinstance(parent, exp.Join) and parent.side == "RIGHT": + selected_sources = {k: (node, source)} + break + pushdown(where.this, selected_sources, scope_ref_count) # joins should only pushdown into itself, not to other joins # so we limit the selected sources to only itself @@ -148,10 +155,13 @@ def nodes_for_predicate(predicate, sources, scope_ref_count): # a node can reference a CTE which should be pushed down if isinstance(node, exp.From) and not isinstance(source, exp.Table): + with_ = source.parent.expression.args.get("with") + if with_ and with_.recursive: + return {} node = source.expression if isinstance(node, exp.Join): - if node.side: + if node.side and node.side != "RIGHT": return {} nodes[table] = node elif isinstance(node, exp.Select) and len(tables) == 1: diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py index abd9492..49789ac 100644 --- a/sqlglot/optimizer/pushdown_projections.py +++ b/sqlglot/optimizer/pushdown_projections.py @@ -6,7 +6,7 @@ from sqlglot.optimizer.scope import Scope, traverse_scope # Sentinel value that means an outer query selecting ALL columns SELECT_ALL = object() -# SELECTION TO USE IF SELECTION LIST IS EMPTY +# Selection to use if selection list is empty DEFAULT_SELECTION = alias("1", "_") @@ -91,7 +91,7 @@ def _remove_unused_selections(scope, parent_selections): # If there are no remaining selections, just select a single constant if not new_selections: - new_selections.append(DEFAULT_SELECTION) + new_selections.append(DEFAULT_SELECTION.copy()) scope.expression.set("expressions", new_selections) return removed_indexes @@ -102,5 +102,5 @@ def _remove_indexed_selections(scope, indexes_to_remove): selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove ] if not new_selections: - new_selections.append(DEFAULT_SELECTION) + new_selections.append(DEFAULT_SELECTION.copy()) scope.expression.set("expressions", new_selections) diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index e6e6dc9..e16a635 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -311,6 +311,9 @@ def _qualify_outputs(scope): alias_ = alias(exp.column(""), alias=selection.name) alias_.set("this", selection) selection = alias_ + elif isinstance(selection, exp.Subquery): + if not selection.alias: + selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}"))) elif not isinstance(selection, exp.Alias): alias_ = alias(exp.column(""), f"_col_{i}") alias_.set("this", selection) diff --git a/sqlglot/optimizer/quote_identities.py b/sqlglot/optimizer/quote_identities.py deleted file mode 100644 index 17623cc..0000000 --- a/sqlglot/optimizer/quote_identities.py +++ /dev/null @@ -1,25 +0,0 @@ -from sqlglot import exp - - -def quote_identities(expression): - """ - Rewrite sqlglot AST to ensure all identities are quoted. - - Example: - >>> import sqlglot - >>> expression = sqlglot.parse_one("SELECT x.a AS a FROM db.x") - >>> quote_identities(expression).sql() - 'SELECT "x"."a" AS "a" FROM "db"."x"' - - Args: - expression (sqlglot.Expression): expression to quote - Returns: - sqlglot.Expression: quoted expression - """ - - def qualify(node): - if isinstance(node, exp.Identifier): - node.set("quoted", True) - return node - - return expression.transform(qualify, copy=False) diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index 18848f3..6125e4e 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -511,9 +511,20 @@ def _traverse_union(scope): def _traverse_derived_tables(derived_tables, scope, scope_type): sources = {} + is_cte = scope_type == ScopeType.CTE for derived_table in derived_tables: - top = None + recursive_scope = None + + # if the scope is a recursive cte, it must be in the form of + # base_case UNION recursive. thus the recursive scope is the first + # section of the union. + if is_cte and scope.expression.args["with"].recursive: + union = derived_table.this + + if isinstance(union, exp.Union): + recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE) + for child_scope in _traverse_scope( scope.branch( derived_table if isinstance(derived_table, exp.UDTF) else derived_table.this, @@ -523,16 +534,23 @@ def _traverse_derived_tables(derived_tables, scope, scope_type): ) ): yield child_scope - top = child_scope + # Tables without aliases will be set as "" # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything. # Until then, this means that only a single, unaliased derived table is allowed (rather, # the latest one wins. - sources[derived_table.alias] = child_scope - if scope_type == ScopeType.CTE: - scope.cte_scopes.append(top) + alias = derived_table.alias + sources[alias] = child_scope + + if recursive_scope: + child_scope.add_source(alias, recursive_scope) + + # append the final child_scope yielded + if is_cte: + scope.cte_scopes.append(child_scope) else: - scope.derived_table_scopes.append(top) + scope.derived_table_scopes.append(child_scope) + scope.sources.update(sources) diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py index 2046917..8d78294 100644 --- a/sqlglot/optimizer/unnest_subqueries.py +++ b/sqlglot/optimizer/unnest_subqueries.py @@ -16,7 +16,7 @@ def unnest_subqueries(expression): >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ") >>> unnest_subqueries(expression).sql() 'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a)\ - AS "_u_0" ON x.a = "_u_0".a WHERE ("_u_0".a = 1 AND NOT "_u_0".a IS NULL)' + AS _u_0 ON x.a = _u_0.a WHERE (_u_0.a = 1 AND NOT _u_0.a IS NULL)' Args: expression (sqlglot.Expression): expression to unnest @@ -97,8 +97,8 @@ def decorrelate(select, parent_select, external_columns, sequence): table_alias = _alias(sequence) keys = [] - # for all external columns in the where statement, - # split out the relevant data to convert it into a join + # for all external columns in the where statement, find the relevant predicate + # keys to convert it into a join for column in external_columns: if column.find_ancestor(exp.Where) is not where: return @@ -122,6 +122,10 @@ def decorrelate(select, parent_select, external_columns, sequence): if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys): return + is_subquery_projection = any( + node is select.parent for node in parent_select.selects if isinstance(node, exp.Subquery) + ) + value = select.selects[0] key_aliases = {} group_by = [] @@ -142,9 +146,14 @@ def decorrelate(select, parent_select, external_columns, sequence): parent_predicate = select.find_ancestor(exp.Predicate) # if the value of the subquery is not an agg or a key, we need to collect it into an array - # so that it can be grouped + # so that it can be grouped. For subquery projections, we use a MAX aggregation instead. + agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg if not value.find(exp.AggFunc) and value.this not in group_by: - select.select(f"ARRAY_AGG({value.this}) AS {value.alias}", append=False, copy=False) + select.select( + exp.alias_(agg_func(this=value.this), value.alias, quoted=False), + append=False, + copy=False, + ) # exists queries should not have any selects as it only checks if there are any rows # all selects will be added by the optimizer and only used for join keys @@ -158,7 +167,7 @@ def decorrelate(select, parent_select, external_columns, sequence): if isinstance(parent_predicate, exp.Exists) or key != value.this: select.select(f"{key} AS {alias}", copy=False) else: - select.select(f"ARRAY_AGG({key}) AS {alias}", copy=False) + select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False) alias = exp.column(value.alias, table_alias) other = _other_operand(parent_predicate) @@ -186,12 +195,18 @@ def decorrelate(select, parent_select, external_columns, sequence): f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})", ) else: + if is_subquery_projection: + alias = exp.alias_(alias, select.parent.alias) select.parent.replace(alias) for key, column, predicate in keys: predicate.replace(exp.true()) nested = exp.column(key_aliases[key], table_alias) + if is_subquery_projection: + key.replace(nested) + continue + if key in group_by: key.replace(nested) parent_predicate = _replace( diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 29bc9c0..308f363 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -5,7 +5,7 @@ import typing as t from sqlglot import exp from sqlglot.errors import ErrorLevel, ParseError, concat_messages, merge_errors -from sqlglot.helper import apply_index_offset, ensure_collection, seq_get +from sqlglot.helper import apply_index_offset, ensure_collection, ensure_list, seq_get from sqlglot.tokens import Token, Tokenizer, TokenType from sqlglot.trie import in_trie, new_trie @@ -117,6 +117,7 @@ class Parser(metaclass=_Parser): TokenType.GEOMETRY, TokenType.HLLSKETCH, TokenType.HSTORE, + TokenType.PSEUDO_TYPE, TokenType.SUPER, TokenType.SERIAL, TokenType.SMALLSERIAL, @@ -153,6 +154,7 @@ class Parser(metaclass=_Parser): TokenType.CACHE, TokenType.CASCADE, TokenType.COLLATE, + TokenType.COLUMN, TokenType.COMMAND, TokenType.COMMIT, TokenType.COMPOUND, @@ -169,6 +171,7 @@ class Parser(metaclass=_Parser): TokenType.ESCAPE, TokenType.FALSE, TokenType.FIRST, + TokenType.FILTER, TokenType.FOLLOWING, TokenType.FORMAT, TokenType.FUNCTION, @@ -188,6 +191,7 @@ class Parser(metaclass=_Parser): TokenType.MERGE, TokenType.NATURAL, TokenType.NEXT, + TokenType.OFFSET, TokenType.ONLY, TokenType.OPTIONS, TokenType.ORDINALITY, @@ -222,12 +226,18 @@ class Parser(metaclass=_Parser): TokenType.PROPERTIES, TokenType.PROCEDURE, TokenType.VOLATILE, + TokenType.WINDOW, *SUBQUERY_PREDICATES, *TYPE_TOKENS, *NO_PAREN_FUNCTIONS, } - TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.NATURAL, TokenType.APPLY} + TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - { + TokenType.APPLY, + TokenType.NATURAL, + TokenType.OFFSET, + TokenType.WINDOW, + } UPDATE_ALIAS_TOKENS = TABLE_ALIAS_TOKENS - {TokenType.SET} @@ -257,6 +267,7 @@ class Parser(metaclass=_Parser): TokenType.TABLE, TokenType.TIMESTAMP, TokenType.TIMESTAMPTZ, + TokenType.WINDOW, *TYPE_TOKENS, *SUBQUERY_PREDICATES, } @@ -351,22 +362,27 @@ class Parser(metaclass=_Parser): TokenType.ARROW: lambda self, this, path: self.expression( exp.JSONExtract, this=this, - path=path, + expression=path, ), TokenType.DARROW: lambda self, this, path: self.expression( exp.JSONExtractScalar, this=this, - path=path, + expression=path, ), TokenType.HASH_ARROW: lambda self, this, path: self.expression( exp.JSONBExtract, this=this, - path=path, + expression=path, ), TokenType.DHASH_ARROW: lambda self, this, path: self.expression( exp.JSONBExtractScalar, this=this, - path=path, + expression=path, + ), + TokenType.PLACEHOLDER: lambda self, this, key: self.expression( + exp.JSONBContains, + this=this, + expression=key, ), } @@ -392,25 +408,27 @@ class Parser(metaclass=_Parser): exp.Ordered: lambda self: self._parse_ordered(), exp.Having: lambda self: self._parse_having(), exp.With: lambda self: self._parse_with(), + exp.Window: lambda self: self._parse_named_window(), "JOIN_TYPE": lambda self: self._parse_join_side_and_kind(), } STATEMENT_PARSERS = { + TokenType.ALTER: lambda self: self._parse_alter(), + TokenType.BEGIN: lambda self: self._parse_transaction(), + TokenType.CACHE: lambda self: self._parse_cache(), + TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(), TokenType.CREATE: lambda self: self._parse_create(), + TokenType.DELETE: lambda self: self._parse_delete(), TokenType.DESCRIBE: lambda self: self._parse_describe(), TokenType.DROP: lambda self: self._parse_drop(), + TokenType.END: lambda self: self._parse_commit_or_rollback(), TokenType.INSERT: lambda self: self._parse_insert(), TokenType.LOAD_DATA: lambda self: self._parse_load_data(), - TokenType.UPDATE: lambda self: self._parse_update(), - TokenType.DELETE: lambda self: self._parse_delete(), - TokenType.CACHE: lambda self: self._parse_cache(), + TokenType.MERGE: lambda self: self._parse_merge(), + TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(), TokenType.UNCACHE: lambda self: self._parse_uncache(), + TokenType.UPDATE: lambda self: self._parse_update(), TokenType.USE: lambda self: self.expression(exp.Use, this=self._parse_id_var()), - TokenType.BEGIN: lambda self: self._parse_transaction(), - TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(), - TokenType.END: lambda self: self._parse_commit_or_rollback(), - TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(), - TokenType.MERGE: lambda self: self._parse_merge(), } UNARY_PARSERS = { @@ -441,6 +459,7 @@ class Parser(metaclass=_Parser): TokenType.HEX_STRING: lambda self, token: self.expression(exp.HexString, this=token.text), TokenType.BYTE_STRING: lambda self, token: self.expression(exp.ByteString, this=token.text), TokenType.INTRODUCER: lambda self, token: self._parse_introducer(token), + TokenType.NATIONAL: lambda self, token: self._parse_national(token), TokenType.SESSION_PARAMETER: lambda self, _: self._parse_session_parameter(), } @@ -454,6 +473,9 @@ class Parser(metaclass=_Parser): TokenType.ILIKE: lambda self, this: self._parse_escape( self.expression(exp.ILike, this=this, expression=self._parse_bitwise()) ), + TokenType.IRLIKE: lambda self, this: self.expression( + exp.RegexpILike, this=this, expression=self._parse_bitwise() + ), TokenType.RLIKE: lambda self, this: self.expression( exp.RegexpLike, this=this, expression=self._parse_bitwise() ), @@ -535,8 +557,7 @@ class Parser(metaclass=_Parser): "group": lambda self: self._parse_group(), "having": lambda self: self._parse_having(), "qualify": lambda self: self._parse_qualify(), - "window": lambda self: self._match(TokenType.WINDOW) - and self._parse_window(self._parse_id_var(), alias=True), + "windows": lambda self: self._parse_window_clause(), "distribute": lambda self: self._parse_sort(TokenType.DISTRIBUTE_BY, exp.Distribute), "sort": lambda self: self._parse_sort(TokenType.SORT_BY, exp.Sort), "cluster": lambda self: self._parse_sort(TokenType.CLUSTER_BY, exp.Cluster), @@ -551,18 +572,18 @@ class Parser(metaclass=_Parser): MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table) CREATABLES = { - TokenType.TABLE, - TokenType.VIEW, + TokenType.COLUMN, TokenType.FUNCTION, TokenType.INDEX, TokenType.PROCEDURE, TokenType.SCHEMA, + TokenType.TABLE, + TokenType.VIEW, } TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"} STRICT_CAST = True - LATERAL_FUNCTION_AS_VIEW = False __slots__ = ( "error_level", @@ -782,13 +803,16 @@ class Parser(metaclass=_Parser): self._parse_query_modifiers(expression) return expression - def _parse_drop(self): + def _parse_drop(self, default_kind=None): temporary = self._match(TokenType.TEMPORARY) materialized = self._match(TokenType.MATERIALIZED) kind = self._match_set(self.CREATABLES) and self._prev.text if not kind: - self.raise_error(f"Expected {self.CREATABLES}") - return + if default_kind: + kind = default_kind + else: + self.raise_error(f"Expected {self.CREATABLES}") + return return self.expression( exp.Drop, @@ -876,7 +900,7 @@ class Parser(metaclass=_Parser): ) or self._match_pair(TokenType.STRING, TokenType.EQ, advance=False) if assignment: - key = self._parse_var() or self._parse_string() + key = self._parse_var_or_string() self._match(TokenType.EQ) return self.expression(exp.Property, this=key, value=self._parse_column()) @@ -1152,18 +1176,32 @@ class Parser(metaclass=_Parser): elif (table or nested) and self._match(TokenType.L_PAREN): this = self._parse_table() if table else self._parse_select(nested=True) self._parse_query_modifiers(this) + this = self._parse_set_operations(this) self._match_r_paren() - this = self._parse_subquery(this) + # early return so that subquery unions aren't parsed again + # SELECT * FROM (SELECT 1) UNION ALL SELECT 1 + # Union ALL should be a property of the top select node, not the subquery + return self._parse_subquery(this) elif self._match(TokenType.VALUES): + if self._curr.token_type == TokenType.L_PAREN: + # We don't consume the left paren because it's consumed in _parse_value + expressions = self._parse_csv(self._parse_value) + else: + # In presto we can have VALUES 1, 2 which results in 1 column & 2 rows. + # Source: https://prestodb.io/docs/current/sql/values.html + expressions = self._parse_csv( + lambda: self.expression(exp.Tuple, expressions=[self._parse_conjunction()]) + ) + this = self.expression( exp.Values, - expressions=self._parse_csv(self._parse_value), + expressions=expressions, alias=self._parse_table_alias(), ) else: this = None - return self._parse_set_operations(this) if this else None + return self._parse_set_operations(this) def _parse_with(self, skip_with_token=False): if not skip_with_token and not self._match(TokenType.WITH): @@ -1201,11 +1239,12 @@ class Parser(metaclass=_Parser): alias = self._parse_id_var( any_token=any_token, tokens=alias_tokens or self.TABLE_ALIAS_TOKENS ) - columns = None if self._match(TokenType.L_PAREN): - columns = self._parse_csv(lambda: self._parse_id_var(any_token)) + columns = self._parse_csv(lambda: self._parse_column_def(self._parse_id_var())) self._match_r_paren() + else: + columns = None if not alias and not columns: return None @@ -1295,26 +1334,19 @@ class Parser(metaclass=_Parser): expression=self._parse_function() or self._parse_id_var(any_token=False), ) - columns = None - table_alias = None - if view or self.LATERAL_FUNCTION_AS_VIEW: - table_alias = self._parse_id_var(any_token=False) - if self._match(TokenType.ALIAS): - columns = self._parse_csv(self._parse_id_var) + if view: + table = self._parse_id_var(any_token=False) + columns = self._parse_csv(self._parse_id_var) if self._match(TokenType.ALIAS) else [] + table_alias = self.expression(exp.TableAlias, this=table, columns=columns) else: - self._match(TokenType.ALIAS) - table_alias = self._parse_id_var(any_token=False) - - if self._match(TokenType.L_PAREN): - columns = self._parse_csv(self._parse_id_var) - self._match_r_paren() + table_alias = self._parse_table_alias() expression = self.expression( exp.Lateral, this=this, view=view, outer=outer, - alias=self.expression(exp.TableAlias, this=table_alias, columns=columns), + alias=table_alias, ) if outer_apply or cross_apply: @@ -1693,6 +1725,9 @@ class Parser(metaclass=_Parser): if negate: this = self.expression(exp.Not, this=this) + if self._match(TokenType.IS): + this = self._parse_is(this) + return this def _parse_is(self, this): @@ -1796,6 +1831,10 @@ class Parser(metaclass=_Parser): return None type_token = self._prev.token_type + + if type_token == TokenType.PSEUDO_TYPE: + return self.expression(exp.PseudoType, this=self._prev.text) + nested = type_token in self.NESTED_TYPE_TOKENS is_struct = type_token == TokenType.STRUCT expressions = None @@ -1851,6 +1890,8 @@ class Parser(metaclass=_Parser): if value is None: value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions) + elif type_token == TokenType.INTERVAL: + value = self.expression(exp.Interval, unit=self._parse_var()) if maybe_func and check_func: index2 = self._index @@ -1924,7 +1965,16 @@ class Parser(metaclass=_Parser): def _parse_primary(self): if self._match_set(self.PRIMARY_PARSERS): - return self.PRIMARY_PARSERS[self._prev.token_type](self, self._prev) + token_type = self._prev.token_type + primary = self.PRIMARY_PARSERS[token_type](self, self._prev) + + if token_type == TokenType.STRING: + expressions = [primary] + while self._match(TokenType.STRING): + expressions.append(exp.Literal.string(self._prev.text)) + if len(expressions) > 1: + return self.expression(exp.Concat, expressions=expressions) + return primary if self._match_pair(TokenType.DOT, TokenType.NUMBER): return exp.Literal.number(f"0.{self._prev.text}") @@ -2027,6 +2077,9 @@ class Parser(metaclass=_Parser): return self.expression(exp.Identifier, this=token.text) + def _parse_national(self, token): + return self.expression(exp.National, this=exp.Literal.string(token.text)) + def _parse_session_parameter(self): kind = None this = self._parse_id_var() or self._parse_primary() @@ -2051,7 +2104,9 @@ class Parser(metaclass=_Parser): if self._match(TokenType.L_PAREN): expressions = self._parse_csv(self._parse_id_var) - self._match(TokenType.R_PAREN) + + if not self._match(TokenType.R_PAREN): + self._retreat(index) else: expressions = [self._parse_id_var()] @@ -2065,14 +2120,14 @@ class Parser(metaclass=_Parser): exp.Distinct, expressions=self._parse_csv(self._parse_conjunction) ) else: - this = self._parse_conjunction() + this = self._parse_select_or_expression() if self._match(TokenType.IGNORE_NULLS): this = self.expression(exp.IgnoreNulls, this=this) else: self._match(TokenType.RESPECT_NULLS) - return self._parse_alias(self._parse_limit(self._parse_order(this))) + return self._parse_limit(self._parse_order(this)) def _parse_schema(self, this=None): index = self._index @@ -2081,7 +2136,8 @@ class Parser(metaclass=_Parser): return this args = self._parse_csv( - lambda: self._parse_constraint() or self._parse_column_def(self._parse_field(True)) + lambda: self._parse_constraint() + or self._parse_column_def(self._parse_field(any_token=True)) ) self._match_r_paren() return self.expression(exp.Schema, this=this, expressions=args) @@ -2120,7 +2176,7 @@ class Parser(metaclass=_Parser): elif self._match(TokenType.ENCODE): kind = self.expression(exp.EncodeColumnConstraint, this=self._parse_var()) elif self._match(TokenType.DEFAULT): - kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_conjunction()) + kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_bitwise()) elif self._match_pair(TokenType.NOT, TokenType.NULL): kind = exp.NotNullColumnConstraint() elif self._match(TokenType.NULL): @@ -2211,7 +2267,10 @@ class Parser(metaclass=_Parser): if not self._match(TokenType.L_BRACKET): return this - expressions = self._parse_csv(self._parse_conjunction) + if self._match(TokenType.COLON): + expressions = [self.expression(exp.Slice, expression=self._parse_conjunction())] + else: + expressions = self._parse_csv(lambda: self._parse_slice(self._parse_conjunction())) if not this or this.name.upper() == "ARRAY": this = self.expression(exp.Array, expressions=expressions) @@ -2225,6 +2284,11 @@ class Parser(metaclass=_Parser): this.comments = self._prev_comments return self._parse_bracket(this) + def _parse_slice(self, this): + if self._match(TokenType.COLON): + return self.expression(exp.Slice, this=this, expression=self._parse_conjunction()) + return this + def _parse_case(self): ifs = [] default = None @@ -2386,6 +2450,12 @@ class Parser(metaclass=_Parser): collation=collation, ) + def _parse_window_clause(self): + return self._match(TokenType.WINDOW) and self._parse_csv(self._parse_named_window) + + def _parse_named_window(self): + return self._parse_window(self._parse_id_var(), alias=True) + def _parse_window(self, this, alias=False): if self._match(TokenType.FILTER): where = self._parse_wrapped(self._parse_where) @@ -2501,11 +2571,9 @@ class Parser(metaclass=_Parser): if identifier: return identifier - if any_token and self._curr and self._curr.token_type not in self.RESERVED_KEYWORDS: - self._advance() - elif not self._match_set(tokens or self.ID_VAR_TOKENS): - return None - return exp.Identifier(this=self._prev.text, quoted=False) + if (any_token and self._advance_any()) or self._match_set(tokens or self.ID_VAR_TOKENS): + return exp.Identifier(this=self._prev.text, quoted=False) + return None def _parse_string(self): if self._match(TokenType.STRING): @@ -2522,11 +2590,17 @@ class Parser(metaclass=_Parser): return self.expression(exp.Identifier, this=self._prev.text, quoted=True) return self._parse_placeholder() - def _parse_var(self): - if self._match(TokenType.VAR): + def _parse_var(self, any_token=False): + if (any_token and self._advance_any()) or self._match(TokenType.VAR): return self.expression(exp.Var, this=self._prev.text) return self._parse_placeholder() + def _advance_any(self): + if self._curr and self._curr.token_type not in self.RESERVED_KEYWORDS: + self._advance() + return self._prev + return None + def _parse_var_or_string(self): return self._parse_var() or self._parse_string() @@ -2551,8 +2625,9 @@ class Parser(metaclass=_Parser): if self._match(TokenType.PLACEHOLDER): return self.expression(exp.Placeholder) elif self._match(TokenType.COLON): - self._advance() - return self.expression(exp.Placeholder, this=self._prev.text) + if self._match_set((TokenType.NUMBER, TokenType.VAR)): + return self.expression(exp.Placeholder, this=self._prev.text) + self._advance(-1) return None def _parse_except(self): @@ -2647,6 +2722,54 @@ class Parser(metaclass=_Parser): return self.expression(exp.Rollback, savepoint=savepoint) return self.expression(exp.Commit, chain=chain) + def _parse_add_column(self): + if not self._match_text_seq("ADD"): + return None + + self._match(TokenType.COLUMN) + exists_column = self._parse_exists(not_=True) + expression = self._parse_column_def(self._parse_field(any_token=True)) + expression.set("exists", exists_column) + return expression + + def _parse_drop_column(self): + return self._match(TokenType.DROP) and self._parse_drop(default_kind="COLUMN") + + def _parse_alter(self): + if not self._match(TokenType.TABLE): + return None + + exists = self._parse_exists() + this = self._parse_table(schema=True) + + actions = None + if self._match_text_seq("ADD", advance=False): + actions = self._parse_csv(self._parse_add_column) + elif self._match_text_seq("DROP", advance=False): + actions = self._parse_csv(self._parse_drop_column) + elif self._match_text_seq("ALTER"): + self._match(TokenType.COLUMN) + column = self._parse_field(any_token=True) + + if self._match_pair(TokenType.DROP, TokenType.DEFAULT): + actions = self.expression(exp.AlterColumn, this=column, drop=True) + elif self._match_pair(TokenType.SET, TokenType.DEFAULT): + actions = self.expression( + exp.AlterColumn, this=column, default=self._parse_conjunction() + ) + else: + self._match_text_seq("SET", "DATA") + actions = self.expression( + exp.AlterColumn, + this=column, + dtype=self._match_text_seq("TYPE") and self._parse_types(), + collate=self._match(TokenType.COLLATE) and self._parse_term(), + using=self._match(TokenType.USING) and self._parse_conjunction(), + ) + + actions = ensure_list(actions) + return self.expression(exp.AlterTable, this=this, exists=exists, actions=actions) + def _parse_show(self): parser = self._find_parser(self.SHOW_PARSERS, self._show_trie) if parser: @@ -2782,7 +2905,7 @@ class Parser(metaclass=_Parser): return True return False - def _match_text_seq(self, *texts): + def _match_text_seq(self, *texts, advance=True): index = self._index for text in texts: if self._curr and self._curr.text.upper() == text: @@ -2790,6 +2913,10 @@ class Parser(metaclass=_Parser): else: self._retreat(index) return False + + if not advance: + self._retreat(index) + return True def _replace_columns_with_dots(self, this): diff --git a/sqlglot/schema.py b/sqlglot/schema.py index c223ee0..d9a4004 100644 --- a/sqlglot/schema.py +++ b/sqlglot/schema.py @@ -160,9 +160,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): super().__init__(schema) self.visible = visible or {} self.dialect = dialect - self._type_mapping_cache: t.Dict[str, exp.DataType] = { - "STR": exp.DataType.build("text"), - } + self._type_mapping_cache: t.Dict[str, exp.DataType] = {} @classmethod def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema: diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index b25ef8d..0efa7d0 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -48,6 +48,7 @@ class TokenType(AutoName): DOLLAR = auto() PARAMETER = auto() SESSION_PARAMETER = auto() + NATIONAL = auto() BLOCK_START = auto() BLOCK_END = auto() @@ -111,6 +112,7 @@ class TokenType(AutoName): # keywords ALIAS = auto() + ALTER = auto() ALWAYS = auto() ALL = auto() ANTI = auto() @@ -196,6 +198,7 @@ class TokenType(AutoName): INTERVAL = auto() INTO = auto() INTRODUCER = auto() + IRLIKE = auto() IS = auto() ISNULL = auto() JOIN = auto() @@ -241,6 +244,7 @@ class TokenType(AutoName): PRIMARY_KEY = auto() PROCEDURE = auto() PROPERTIES = auto() + PSEUDO_TYPE = auto() QUALIFY = auto() QUOTE = auto() RANGE = auto() @@ -346,7 +350,11 @@ class _Tokenizer(type): def __new__(cls, clsname, bases, attrs): # type: ignore klass = super().__new__(cls, clsname, bases, attrs) - klass._QUOTES = cls._delimeter_list_to_dict(klass.QUOTES) + klass._QUOTES = { + f"{prefix}{s}": e + for s, e in cls._delimeter_list_to_dict(klass.QUOTES).items() + for prefix in (("",) if s[0].isalpha() else ("", "n", "N")) + } klass._BIT_STRINGS = cls._delimeter_list_to_dict(klass.BIT_STRINGS) klass._HEX_STRINGS = cls._delimeter_list_to_dict(klass.HEX_STRINGS) klass._BYTE_STRINGS = cls._delimeter_list_to_dict(klass.BYTE_STRINGS) @@ -470,6 +478,7 @@ class Tokenizer(metaclass=_Tokenizer): "CHECK": TokenType.CHECK, "CLUSTER BY": TokenType.CLUSTER_BY, "COLLATE": TokenType.COLLATE, + "COLUMN": TokenType.COLUMN, "COMMENT": TokenType.SCHEMA_COMMENT, "COMMIT": TokenType.COMMIT, "COMPOUND": TokenType.COMPOUND, @@ -587,6 +596,7 @@ class Tokenizer(metaclass=_Tokenizer): "SEMI": TokenType.SEMI, "SET": TokenType.SET, "SHOW": TokenType.SHOW, + "SIMILAR TO": TokenType.SIMILAR_TO, "SOME": TokenType.SOME, "SORTKEY": TokenType.SORTKEY, "SORT BY": TokenType.SORT_BY, @@ -614,6 +624,7 @@ class Tokenizer(metaclass=_Tokenizer): "VOLATILE": TokenType.VOLATILE, "WHEN": TokenType.WHEN, "WHERE": TokenType.WHERE, + "WINDOW": TokenType.WINDOW, "WITH": TokenType.WITH, "WITH TIME ZONE": TokenType.WITH_TIME_ZONE, "WITH LOCAL TIME ZONE": TokenType.WITH_LOCAL_TIME_ZONE, @@ -652,6 +663,7 @@ class Tokenizer(metaclass=_Tokenizer): "VARCHAR2": TokenType.VARCHAR, "NVARCHAR": TokenType.NVARCHAR, "NVARCHAR2": TokenType.NVARCHAR, + "STR": TokenType.TEXT, "STRING": TokenType.TEXT, "TEXT": TokenType.TEXT, "CLOB": TokenType.TEXT, @@ -667,7 +679,16 @@ class Tokenizer(metaclass=_Tokenizer): "UNIQUE": TokenType.UNIQUE, "STRUCT": TokenType.STRUCT, "VARIANT": TokenType.VARIANT, - "ALTER": TokenType.COMMAND, + "ALTER": TokenType.ALTER, + "ALTER AGGREGATE": TokenType.COMMAND, + "ALTER DEFAULT": TokenType.COMMAND, + "ALTER DOMAIN": TokenType.COMMAND, + "ALTER ROLE": TokenType.COMMAND, + "ALTER RULE": TokenType.COMMAND, + "ALTER SEQUENCE": TokenType.COMMAND, + "ALTER TYPE": TokenType.COMMAND, + "ALTER USER": TokenType.COMMAND, + "ALTER VIEW": TokenType.COMMAND, "ANALYZE": TokenType.COMMAND, "CALL": TokenType.COMMAND, "EXPLAIN": TokenType.COMMAND, @@ -967,7 +988,7 @@ class Tokenizer(metaclass=_Tokenizer): text = self._extract_string(quote_end) text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text # type: ignore text = text.replace("\\\\", "\\") if self._replace_backslash else text - self._add(TokenType.STRING, text) + self._add(TokenType.NATIONAL if quote[0].upper() == "N" else TokenType.STRING, text) return True # X'1234, b'0110', E'\\\\\' etc. |