From d3bb537b2b73788ba06bf4158f473ecc5bb556cc Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Tue, 17 Jan 2023 11:32:16 +0100 Subject: Merging upstream version 10.5.2. Signed-off-by: Daniel Baumann --- sqlglot/__init__.py | 13 +- sqlglot/dialects/bigquery.py | 7 +- sqlglot/dialects/clickhouse.py | 35 +- sqlglot/dialects/dialect.py | 17 + sqlglot/dialects/hive.py | 23 +- sqlglot/dialects/oracle.py | 3 +- sqlglot/dialects/postgres.py | 21 +- sqlglot/dialects/snowflake.py | 8 +- sqlglot/dialects/tsql.py | 22 +- sqlglot/expressions.py | 117 +++++- sqlglot/generator.py | 69 +++- sqlglot/helper.py | 20 +- sqlglot/optimizer/annotate_types.py | 2 +- sqlglot/optimizer/eliminate_joins.py | 4 +- sqlglot/optimizer/merge_subqueries.py | 54 ++- sqlglot/optimizer/optimizer.py | 6 +- sqlglot/optimizer/pushdown_projections.py | 4 + sqlglot/optimizer/qualify_columns.py | 4 +- sqlglot/optimizer/simplify.py | 19 +- sqlglot/optimizer/unnest_subqueries.py | 38 +- sqlglot/parser.py | 652 +++++++++++++++++++++--------- sqlglot/schema.py | 45 ++- sqlglot/serde.py | 67 +++ sqlglot/tokens.py | 19 +- sqlglot/transforms.py | 24 ++ sqlglot/trie.py | 2 +- 26 files changed, 984 insertions(+), 311 deletions(-) create mode 100644 sqlglot/serde.py (limited to 'sqlglot') diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index 04c3195..87fa081 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -32,7 +32,7 @@ from sqlglot.parser import Parser from sqlglot.schema import MappingSchema from sqlglot.tokens import Tokenizer, TokenType -__version__ = "10.4.2" +__version__ = "10.5.2" pretty = False @@ -60,9 +60,9 @@ def parse( def parse_one( sql: str, read: t.Optional[str | Dialect] = None, - into: t.Optional[Expression | str] = None, + into: t.Optional[t.Type[Expression] | str] = None, **opts, -) -> t.Optional[Expression]: +) -> Expression: """ Parses the given SQL string and returns a syntax tree for the first parsed SQL statement. 
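With this change, parse_one is typed to always return an Expression: when nothing can be parsed it raises ParseError instead of returning None, and `into` now accepts an Expression subclass rather than an instance. A minimal sketch of the new contract (illustrative usage, not part of the patch):

    import sqlglot
    from sqlglot.errors import ParseError

    tree = sqlglot.parse_one("SELECT 1")  # an Expression, never None
    try:
        sqlglot.parse_one("")  # nothing to parse, so this should raise
    except ParseError as err:
        print(err)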
@@ -83,7 +83,12 @@ def parse_one( else: result = dialect.parse(sql, **opts) - return result[0] if result else None + for expression in result: + if not expression: + raise ParseError(f"No expression was parsed from '{sql}'") + return expression + else: + raise ParseError(f"No expression was parsed from '{sql}'") def transpile( diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index d10cc54..f0089e1 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -2,7 +2,7 @@ from __future__ import annotations -from sqlglot import exp, generator, parser, tokens +from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, datestrtodate_sql, @@ -46,8 +46,9 @@ def _date_add_sql(data_type, kind): def _derived_table_values_to_unnest(self, expression): if not isinstance(expression.unnest().parent, exp.From): + expression = transforms.remove_precision_parameterized_types(expression) return self.values_sql(expression) - rows = [list(tuple_exp.find_all(exp.Literal)) for tuple_exp in expression.find_all(exp.Tuple)] + rows = [tuple_exp.expressions for tuple_exp in expression.find_all(exp.Tuple)] structs = [] for row in rows: aliases = [ @@ -118,6 +119,7 @@ class BigQuery(Dialect): "BEGIN TRANSACTION": TokenType.BEGIN, "CURRENT_DATETIME": TokenType.CURRENT_DATETIME, "CURRENT_TIME": TokenType.CURRENT_TIME, + "DECLARE": TokenType.COMMAND, "GEOGRAPHY": TokenType.GEOGRAPHY, "FLOAT64": TokenType.DOUBLE, "INT64": TokenType.BIGINT, @@ -166,6 +168,7 @@ class BigQuery(Dialect): class Generator(generator.Generator): TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore + **transforms.REMOVE_PRECISION_PARAMETERIZED_TYPES, # type: ignore exp.ArraySize: rename_func("ARRAY_LENGTH"), exp.DateAdd: _date_add_sql("DATE", "ADD"), exp.DateSub: _date_add_sql("DATE", "SUB"), diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index 7136340..04d46d2 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -1,5 +1,7 @@ from __future__ import annotations +import typing as t + from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql from sqlglot.parser import parse_var_map @@ -22,6 +24,7 @@ class ClickHouse(Dialect): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, "ASOF": TokenType.ASOF, + "GLOBAL": TokenType.GLOBAL, "DATETIME64": TokenType.DATETIME, "FINAL": TokenType.FINAL, "FLOAT32": TokenType.FLOAT, @@ -37,14 +40,32 @@ class ClickHouse(Dialect): FUNCTIONS = { **parser.Parser.FUNCTIONS, # type: ignore "MAP": parse_var_map, + "QUANTILE": lambda params, args: exp.Quantile(this=args, quantile=params), + "QUANTILES": lambda params, args: exp.Quantiles(parameters=params, expressions=args), + "QUANTILEIF": lambda params, args: exp.QuantileIf(parameters=params, expressions=args), + } + + RANGE_PARSERS = { + **parser.Parser.RANGE_PARSERS, + TokenType.GLOBAL: lambda self, this: self._match(TokenType.IN) + and self._parse_in(this, is_global=True), } JOIN_KINDS = {*parser.Parser.JOIN_KINDS, TokenType.ANY, TokenType.ASOF} # type: ignore TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY} # type: ignore - def _parse_table(self, schema=False): - this = super()._parse_table(schema) + def _parse_in( + self, this: t.Optional[exp.Expression], is_global: bool = False + ) -> exp.Expression: + this = super()._parse_in(this) + this.set("is_global", is_global) + return this + + def _parse_table( + self, schema: 
bool = False, alias_tokens: t.Optional[t.Collection[TokenType]] = None + ) -> t.Optional[exp.Expression]: + this = super()._parse_table(schema=schema, alias_tokens=alias_tokens) if self._match(TokenType.FINAL): this = self.expression(exp.Final, this=this) @@ -76,6 +97,16 @@ class ClickHouse(Dialect): exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL", exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)), exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)), + exp.Quantile: lambda self, e: f"quantile{self._param_args_sql(e, 'quantile', 'this')}", + exp.Quantiles: lambda self, e: f"quantiles{self._param_args_sql(e, 'parameters', 'expressions')}", + exp.QuantileIf: lambda self, e: f"quantileIf{self._param_args_sql(e, 'parameters', 'expressions')}", } EXPLICIT_UNION = True + + def _param_args_sql( + self, expression: exp.Expression, params_name: str, args_name: str + ) -> str: + params = self.format_args(self.expressions(expression, params_name)) + args = self.format_args(self.expressions(expression, args_name)) + return f"({params})({args})" diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index e788852..1c840da 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -381,3 +381,20 @@ def timestrtotime_sql(self, expression: exp.TimeStrToTime) -> str: def datestrtodate_sql(self, expression: exp.DateStrToDate) -> str: return f"CAST({self.sql(expression, 'this')} AS DATE)" + + +def trim_sql(self, expression): + target = self.sql(expression, "this") + trim_type = self.sql(expression, "position") + remove_chars = self.sql(expression, "expression") + collation = self.sql(expression, "collation") + + # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific + if not remove_chars and not collation: + return self.trim_sql(expression) + + trim_type = f"{trim_type} " if trim_type else "" + remove_chars = f"{remove_chars} " if remove_chars else "" + from_part = "FROM " if trim_type or remove_chars else "" + collation = f" COLLATE {collation}" if collation else "" + return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})" diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 088555c..ead13b1 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -175,14 +175,6 @@ class Hive(Dialect): ESCAPES = ["\\"] ENCODE = "utf-8" - NUMERIC_LITERALS = { - "L": "BIGINT", - "S": "SMALLINT", - "Y": "TINYINT", - "D": "DOUBLE", - "F": "FLOAT", - "BD": "DECIMAL", - } KEYWORDS = { **tokens.Tokenizer.KEYWORDS, "ADD ARCHIVE": TokenType.COMMAND, @@ -191,9 +183,21 @@ class Hive(Dialect): "ADD FILES": TokenType.COMMAND, "ADD JAR": TokenType.COMMAND, "ADD JARS": TokenType.COMMAND, + "MSCK REPAIR": TokenType.COMMAND, "WITH SERDEPROPERTIES": TokenType.SERDE_PROPERTIES, } + NUMERIC_LITERALS = { + "L": "BIGINT", + "S": "SMALLINT", + "Y": "TINYINT", + "D": "DOUBLE", + "F": "FLOAT", + "BD": "DECIMAL", + } + + IDENTIFIER_CAN_START_WITH_DIGIT = True + class Parser(parser.Parser): STRICT_CAST = False @@ -315,6 +319,7 @@ class Hive(Dialect): exp.RowFormatSerdeProperty: lambda self, e: f"ROW FORMAT SERDE {self.sql(e, 'this')}", exp.SerdeProperties: lambda self, e: self.properties(e, prefix="WITH SERDEPROPERTIES"), exp.NumberToStr: rename_func("FORMAT_NUMBER"), + exp.LastDateOfMonth: rename_func("LAST_DAY"), } WITH_PROPERTIES = {exp.Property} @@ -342,4 +347,6 @@ class Hive(Dialect): and not expression.expressions ): expression = exp.DataType.build("text") + elif expression.this in exp.DataType.TEMPORAL_TYPES: + 
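+                # Hive temporal types take no parameters, so a parameterized type such as
+                # TIMESTAMP(9) is rebuilt here as a bare TIMESTAMP, keeping only the base type.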
expression = exp.DataType.build(expression.this) return super().datatype_sql(expression) diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index af3d353..86caa6b 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -1,7 +1,7 @@ from __future__ import annotations from sqlglot import exp, generator, parser, tokens, transforms -from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func +from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func, trim_sql from sqlglot.helper import csv from sqlglot.tokens import TokenType @@ -64,6 +64,7 @@ class Oracle(Dialect): **transforms.UNALIAS_GROUP, # type: ignore exp.ILike: no_ilike_sql, exp.Limit: _limit_sql, + exp.Trim: trim_sql, exp.Matches: rename_func("DECODE"), exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index a092cad..f3fec31 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import ( no_tablesample_sql, no_trycast_sql, str_position_sql, + trim_sql, ) from sqlglot.helper import seq_get from sqlglot.tokens import TokenType @@ -81,23 +82,6 @@ def _substring_sql(self, expression): return f"SUBSTRING({this}{from_part}{for_part})" -def _trim_sql(self, expression): - target = self.sql(expression, "this") - trim_type = self.sql(expression, "position") - remove_chars = self.sql(expression, "expression") - collation = self.sql(expression, "collation") - - # Use TRIM/LTRIM/RTRIM syntax if the expression isn't postgres-specific - if not remove_chars and not collation: - return self.trim_sql(expression) - - trim_type = f"{trim_type} " if trim_type else "" - remove_chars = f"{remove_chars} " if remove_chars else "" - from_part = "FROM " if trim_type or remove_chars else "" - collation = f" COLLATE {collation}" if collation else "" - return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})" - - def _string_agg_sql(self, expression): expression = expression.copy() separator = expression.args.get("separator") or exp.Literal.string(",") @@ -248,7 +232,6 @@ class Postgres(Dialect): "COMMENT ON": TokenType.COMMAND, "DECLARE": TokenType.COMMAND, "DO": TokenType.COMMAND, - "DOUBLE PRECISION": TokenType.DOUBLE, "GENERATED": TokenType.GENERATED, "GRANT": TokenType.COMMAND, "HSTORE": TokenType.HSTORE, @@ -318,7 +301,7 @@ class Postgres(Dialect): exp.Substring: _substring_sql, exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", exp.TableSample: no_tablesample_sql, - exp.Trim: _trim_sql, + exp.Trim: trim_sql, exp.TryCast: no_trycast_sql, exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})", exp.DataType: _datatype_sql, diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 77b09e9..24d3bdf 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -195,7 +195,6 @@ class Snowflake(Dialect): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, "QUALIFY": TokenType.QUALIFY, - "DOUBLE PRECISION": TokenType.DOUBLE, "TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ, "TIMESTAMP_NTZ": TokenType.TIMESTAMP, "TIMESTAMP_TZ": TokenType.TIMESTAMPTZ, @@ -294,3 +293,10 @@ class Snowflake(Dialect): ) return self.no_identify(lambda: super(self.__class__, self).select_sql(expression)) return super().select_sql(expression) + + def 
describe_sql(self, expression: exp.Describe) -> str:
+        # Default to table if kind is unknown
+        kind_value = expression.args.get("kind") or "TABLE"
+        kind = f" {kind_value}" if kind_value else ""
+        this = f" {self.sql(expression, 'this')}"
+        return f"DESCRIBE{kind}{this}"
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index 7f0f2d7..465f534 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -75,6 +75,20 @@ def _parse_format(args):
     )
 
 
+def _parse_eomonth(args):
+    date = seq_get(args, 0)
+    month_lag = seq_get(args, 1)
+    unit = DATE_DELTA_INTERVAL.get("month")
+
+    if month_lag is None:
+        return exp.LastDateOfMonth(this=date)
+
+    # Remove the month lag argument in the parser, as it's compared with the number of arguments of the resulting class
+    args.remove(month_lag)
+
+    return exp.LastDateOfMonth(this=exp.DateAdd(this=date, expression=month_lag, unit=unit))
+
+
 def generate_date_delta_with_unit_sql(self, e):
     func = "DATEADD" if isinstance(e, exp.DateAdd) else "DATEDIFF"
     return f"{func}({self.format_args(e.text('unit'), e.expression, e.this)})"
@@ -256,12 +270,14 @@ class TSQL(Dialect):
         "DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
         "DATENAME": _format_time_lambda(exp.TimeToStr, full_format_mapping=True),
         "DATEPART": _format_time_lambda(exp.TimeToStr),
-        "GETDATE": exp.CurrentDate.from_arg_list,
+        "GETDATE": exp.CurrentTimestamp.from_arg_list,
+        "SYSDATETIME": exp.CurrentTimestamp.from_arg_list,
         "IIF": exp.If.from_arg_list,
         "LEN": exp.Length.from_arg_list,
         "REPLICATE": exp.Repeat.from_arg_list,
         "JSON_VALUE": exp.JSONExtractScalar.from_arg_list,
         "FORMAT": _parse_format,
+        "EOMONTH": _parse_eomonth,
     }
 
     VAR_LENGTH_DATATYPES = {
@@ -271,6 +287,9 @@ class TSQL(Dialect):
         DataType.Type.NCHAR,
     }
 
+    # https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-tables-temporary#create-a-temporary-table
+    TABLE_PREFIX_TOKENS = {TokenType.HASH}
+
     def _parse_convert(self, strict):
         to = self._parse_types()
         self._match(TokenType.COMMA)
@@ -323,6 +342,7 @@ class TSQL(Dialect):
         exp.DateAdd: generate_date_delta_with_unit_sql,
         exp.DateDiff: generate_date_delta_with_unit_sql,
         exp.CurrentDate: rename_func("GETDATE"),
+        exp.CurrentTimestamp: rename_func("GETDATE"),
         exp.If: rename_func("IIF"),
         exp.NumberToStr: _format_sql,
         exp.TimeToStr: _format_sql,
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 711ec4b..d093e29 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -22,6 +22,7 @@ from sqlglot.helper import (
     split_num_words,
     subclasses,
 )
+from sqlglot.tokens import Token
 
 if t.TYPE_CHECKING:
     from sqlglot.dialects.dialect import Dialect
@@ -457,6 +458,23 @@ class Expression(metaclass=_Expression):
         assert isinstance(self, type_)
         return self
 
+    def dump(self):
+        """
+        Dump this Expression to a JSON-serializable dict.
+        """
+        from sqlglot.serde import dump
+
+        return dump(self)
+
+    @classmethod
+    def load(cls, obj):
+        """
+        Load a dict (as returned by `Expression.dump`) into an Expression instance.
+ """ + from sqlglot.serde import load + + return load(obj) + class Condition(Expression): def and_(self, *expressions, dialect=None, **opts): @@ -631,11 +649,15 @@ class Create(Expression): "replace": False, "unique": False, "materialized": False, + "data": False, + "statistics": False, + "no_primary_index": False, + "indexes": False, } class Describe(Expression): - pass + arg_types = {"this": True, "kind": False} class Set(Expression): @@ -731,7 +753,7 @@ class Column(Condition): class ColumnDef(Expression): arg_types = { "this": True, - "kind": True, + "kind": False, "constraints": False, "exists": False, } @@ -879,7 +901,15 @@ class Identifier(Expression): class Index(Expression): - arg_types = {"this": False, "table": False, "where": False, "columns": False} + arg_types = { + "this": False, + "table": False, + "where": False, + "columns": False, + "unique": False, + "primary": False, + "amp": False, # teradata + } class Insert(Expression): @@ -1361,6 +1391,7 @@ class Table(Expression): "laterals": False, "joins": False, "pivots": False, + "hints": False, } @@ -1818,7 +1849,12 @@ class Select(Subqueryable): join.this.replace(join.this.subquery()) if join_type: + natural: t.Optional[Token] + side: t.Optional[Token] + kind: t.Optional[Token] + natural, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args) # type: ignore + if natural: join.set("natural", True) if side: @@ -2111,6 +2147,7 @@ class DataType(Expression): JSON = auto() JSONB = auto() INTERVAL = auto() + TIME = auto() TIMESTAMP = auto() TIMESTAMPTZ = auto() TIMESTAMPLTZ = auto() @@ -2171,11 +2208,24 @@ class DataType(Expression): } @classmethod - def build(cls, dtype, **kwargs) -> DataType: - return DataType( - this=dtype if isinstance(dtype, DataType.Type) else DataType.Type[dtype.upper()], - **kwargs, - ) + def build( + cls, dtype: str | DataType.Type, dialect: t.Optional[str | Dialect] = None, **kwargs + ) -> DataType: + from sqlglot import parse_one + + if isinstance(dtype, str): + data_type_exp: t.Optional[Expression] + if dtype.upper() in cls.Type.__members__: + data_type_exp = DataType(this=DataType.Type[dtype.upper()]) + else: + data_type_exp = parse_one(dtype, read=dialect, into=DataType) + if data_type_exp is None: + raise ValueError(f"Unparsable data type value: {dtype}") + elif isinstance(dtype, DataType.Type): + data_type_exp = DataType(this=dtype) + else: + raise ValueError(f"Invalid data type: {type(dtype)}. 
Expected str or DataType.Type") + return DataType(**{**data_type_exp.args, **kwargs}) # https://www.postgresql.org/docs/15/datatype-pseudo.html @@ -2429,6 +2479,7 @@ class In(Predicate): "query": False, "unnest": False, "field": False, + "is_global": False, } @@ -2678,6 +2729,10 @@ class DatetimeTrunc(Func, TimeUnit): arg_types = {"this": True, "unit": True, "zone": False} +class LastDateOfMonth(Func): + pass + + class Extract(Func): arg_types = {"this": True, "expression": True} @@ -2815,7 +2870,13 @@ class Length(Func): class Levenshtein(Func): - arg_types = {"this": True, "expression": False} + arg_types = { + "this": True, + "expression": False, + "ins_cost": False, + "del_cost": False, + "sub_cost": False, + } class Ln(Func): @@ -2890,6 +2951,16 @@ class Quantile(AggFunc): arg_types = {"this": True, "quantile": True} +# Clickhouse-specific: +# https://clickhouse.com/docs/en/sql-reference/aggregate-functions/reference/quantiles/#quantiles +class Quantiles(AggFunc): + arg_types = {"parameters": True, "expressions": True} + + +class QuantileIf(AggFunc): + arg_types = {"parameters": True, "expressions": True} + + class ApproxQuantile(Quantile): arg_types = {"this": True, "quantile": True, "accuracy": False} @@ -2962,8 +3033,10 @@ class StrToTime(Func): arg_types = {"this": True, "format": True} +# Spark allows unix_timestamp() +# https://spark.apache.org/docs/3.1.3/api/python/reference/api/pyspark.sql.functions.unix_timestamp.html class StrToUnix(Func): - arg_types = {"this": True, "format": True} + arg_types = {"this": False, "format": False} class NumberToStr(Func): @@ -3131,7 +3204,7 @@ def maybe_parse( dialect=None, prefix=None, **opts, -) -> t.Optional[Expression]: +) -> Expression: """Gracefully handle a possible string or expression. Example: @@ -3627,11 +3700,11 @@ def to_table(sql_path: t.Optional[str | Table], **kwargs) -> t.Optional[Table]: if not isinstance(sql_path, str): raise ValueError(f"Invalid type provided for a table: {type(sql_path)}") - catalog, db, table_name = [to_identifier(x) for x in split_num_words(sql_path, ".", 3)] + catalog, db, table_name = (to_identifier(x) for x in split_num_words(sql_path, ".", 3)) return Table(this=table_name, db=db, catalog=catalog, **kwargs) -def to_column(sql_path: str, **kwargs) -> Column: +def to_column(sql_path: str | Column, **kwargs) -> Column: """ Create a column from a `[table].[column]` sql path. Schema is optional. @@ -3646,7 +3719,7 @@ def to_column(sql_path: str, **kwargs) -> Column: return sql_path if not isinstance(sql_path, str): raise ValueError(f"Invalid type provided for column: {type(sql_path)}") - table_name, column_name = [to_identifier(x) for x in split_num_words(sql_path, ".", 2)] + table_name, column_name = (to_identifier(x) for x in split_num_words(sql_path, ".", 2)) return Column(this=column_name, table=table_name, **kwargs) @@ -3748,7 +3821,7 @@ def table_(table, db=None, catalog=None, quoted=None, alias=None) -> Table: def values( values: t.Iterable[t.Tuple[t.Any, ...]], alias: t.Optional[str] = None, - columns: t.Optional[t.Iterable[str]] = None, + columns: t.Optional[t.Iterable[str] | t.Dict[str, DataType]] = None, ) -> Values: """Build VALUES statement. @@ -3759,7 +3832,10 @@ def values( Args: values: values statements that will be converted to SQL alias: optional alias - columns: Optional list of ordered column names. An alias is required when providing column names. + columns: Optional list of ordered column names or ordered dictionary of column names to types. 
+            If either is provided, then an alias is also required.
+            If a dictionary is provided, then the first column of the values will be cast to the expected type
+            in order to help with type inference.
 
     Returns:
         Values: the Values expression object
@@ -3771,8 +3847,15 @@ def values(
         if columns
         else TableAlias(this=to_identifier(alias) if alias else None)
     )
+    expressions = [convert(tup) for tup in values]
+    if columns and isinstance(columns, dict):
+        types = list(columns.values())
+        expressions[0].set(
+            "expressions",
+            [Cast(this=x, to=types[i]) for i, x in enumerate(expressions[0].expressions)],
+        )
     return Values(
-        expressions=[convert(tup) for tup in values],
+        expressions=expressions,
         alias=table_alias,
    )
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index 0c1578a..3935133 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -50,7 +50,7 @@ class Generator:
             The default is on the smaller end because the length only represents a segment and not the true
             line length.
             Default: 80
-        comments: Whether or not to preserve comments in the ouput SQL code.
+        comments: Whether or not to preserve comments in the output SQL code.
             Default: True
     """
@@ -236,7 +236,10 @@ class Generator:
             return sql
 
         sep = "\n" if self.pretty else " "
-        comments = sep.join(f"/*{self.pad_comment(comment)}*/" for comment in comments)
+        comments = sep.join(f"/*{self.pad_comment(comment)}*/" for comment in comments if comment)
+
+        if not comments:
+            return sql
 
         if isinstance(expression, self.WITH_SEPARATED_COMMENTS):
             return f"{comments}{self.sep()}{sql}"
@@ -362,10 +365,10 @@ class Generator:
         kind = self.sql(expression, "kind")
         constraints = self.expressions(expression, key="constraints", sep=" ", flat=True)
         exists = "IF NOT EXISTS " if expression.args.get("exists") else ""
+        kind = f" {kind}" if kind else ""
+        constraints = f" {constraints}" if constraints else ""
 
-        if not constraints:
-            return f"{exists}{column} {kind}"
-        return f"{exists}{column} {kind} {constraints}"
+        return f"{exists}{column}{kind}{constraints}"
 
     def columnconstraint_sql(self, expression: exp.ColumnConstraint) -> str:
         this = self.sql(expression, "this")
@@ -416,7 +419,7 @@ class Generator:
         this = self.sql(expression, "this")
         kind = self.sql(expression, "kind").upper()
         expression_sql = self.sql(expression, "expression")
-        expression_sql = f"AS{self.sep()}{expression_sql}" if expression_sql else ""
+        expression_sql = f" AS{self.sep()}{expression_sql}" if expression_sql else ""
         temporary = " TEMPORARY" if expression.args.get("temporary") else ""
         transient = (
             " TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else ""
         )
@@ -427,6 +430,40 @@ class Generator:
         unique = " UNIQUE" if expression.args.get("unique") else ""
         materialized = " MATERIALIZED" if expression.args.get("materialized") else ""
         properties = self.sql(expression, "properties")
+        data = expression.args.get("data")
+        if data is None:
+            data = ""
+        elif data:
+            data = " WITH DATA"
+        else:
+            data = " WITH NO DATA"
+        statistics = expression.args.get("statistics")
+        if statistics is None:
+            statistics = ""
+        elif statistics:
+            statistics = " AND STATISTICS"
+        else:
+            statistics = " AND NO STATISTICS"
+        no_primary_index = " NO PRIMARY INDEX" if expression.args.get("no_primary_index") else ""
+
+        indexes = expression.args.get("indexes")
+        index_sql = ""
+        if indexes is not None:
+            indexes_sql = []
+            for index in indexes:
+                ind_unique = " UNIQUE" if index.args.get("unique") else ""
+                ind_primary = " PRIMARY" if index.args.get("primary") else ""
+                ind_amp = " AMP"
if index.args.get("amp") else "" + ind_name = f" {index.name}" if index.name else "" + ind_columns = ( + f' ({self.expressions(index, key="columns", flat=True)})' + if index.args.get("columns") + else "" + ) + indexes_sql.append( + f"{ind_unique}{ind_primary}{ind_amp} INDEX{ind_name}{ind_columns}" + ) + index_sql = "".join(indexes_sql) modifiers = "".join( ( @@ -438,7 +475,10 @@ class Generator: materialized, ) ) - expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties} {expression_sql}" + + post_expression_modifiers = "".join((data, statistics, no_primary_index)) + + expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties}{expression_sql}{post_expression_modifiers}{index_sql}" return self.prepend_ctes(expression, expression_sql) def describe_sql(self, expression: exp.Describe) -> str: @@ -668,6 +708,8 @@ class Generator: alias = self.sql(expression, "alias") alias = f"{sep}{alias}" if alias else "" + hints = self.expressions(expression, key="hints", sep=", ", flat=True) + hints = f" WITH ({hints})" if hints else "" laterals = self.expressions(expression, key="laterals", sep="") joins = self.expressions(expression, key="joins", sep="") pivots = self.expressions(expression, key="pivots", sep="") @@ -676,7 +718,7 @@ class Generator: pivots = f"{pivots}{alias}" alias = "" - return f"{table}{alias}{laterals}{joins}{pivots}" + return f"{table}{alias}{hints}{laterals}{joins}{pivots}" def tablesample_sql(self, expression: exp.TableSample) -> str: if self.alias_post_tablesample and expression.this.alias: @@ -1020,7 +1062,9 @@ class Generator: if not partition and not order and not spec and alias: return f"{this} {alias}" - return f"{this} ({alias}{partition_sql}{order_sql}{spec_sql})" + window_args = alias + partition_sql + order_sql + spec_sql + + return f"{this} ({window_args.strip()})" def window_spec_sql(self, expression: exp.WindowSpec) -> str: kind = self.sql(expression, "kind") @@ -1130,6 +1174,8 @@ class Generator: query = expression.args.get("query") unnest = expression.args.get("unnest") field = expression.args.get("field") + is_global = " GLOBAL" if expression.args.get("is_global") else "" + if query: in_sql = self.wrap(query) elif unnest: @@ -1138,7 +1184,8 @@ class Generator: in_sql = self.sql(field) else: in_sql = f"({self.expressions(expression, flat=True)})" - return f"{self.sql(expression, 'this')} IN {in_sql}" + + return f"{self.sql(expression, 'this')}{is_global} IN {in_sql}" def in_unnest_op(self, unnest: exp.Unnest) -> str: return f"(SELECT {self.sql(unnest)})" @@ -1433,7 +1480,7 @@ class Generator: result_sqls = [] for i, e in enumerate(expressions): sql = self.sql(e, comment=False) - comments = self.maybe_comment("", e) + comments = self.maybe_comment("", e) if isinstance(e, exp.Expression) else "" if self.pretty: if self._leading_comma: diff --git a/sqlglot/helper.py b/sqlglot/helper.py index ed37e6c..5a0f2ac 100644 --- a/sqlglot/helper.py +++ b/sqlglot/helper.py @@ -131,7 +131,7 @@ def subclasses( ] -def apply_index_offset(expressions: t.List[E], offset: int) -> t.List[E]: +def apply_index_offset(expressions: t.List[t.Optional[E]], offset: int) -> t.List[t.Optional[E]]: """ Applies an offset to a given integer literal expression. 
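The widened Optional typing above lets index expressions that failed to parse (None entries) flow through apply_index_offset untouched, while a single integer literal still gets shifted. A rough sketch of the intended behavior (assumed usage, not from this patch):

    from sqlglot import exp
    from sqlglot.helper import apply_index_offset

    one = exp.Literal.number(1)
    apply_index_offset([one], 1)   # returns a copied literal whose text is now "2"
    apply_index_offset([None], 1)  # returns [None]: non-literal entries pass through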
@@ -148,10 +148,10 @@ def apply_index_offset(expressions: t.List[E], offset: int) -> t.List[E]: expression = expressions[0] - if expression.is_int: + if expression and expression.is_int: expression = expression.copy() logger.warning("Applying array index offset (%s)", offset) - expression.args["this"] = str(int(expression.this) + offset) + expression.args["this"] = str(int(expression.this) + offset) # type: ignore return [expression] return expressions @@ -225,7 +225,7 @@ def open_file(file_name: str) -> t.TextIO: return gzip.open(file_name, "rt", newline="") - return open(file_name, "rt", encoding="utf-8", newline="") + return open(file_name, encoding="utf-8", newline="") @contextmanager @@ -256,7 +256,7 @@ def csv_reader(read_csv: exp.ReadCSV) -> t.Any: file.close() -def find_new_name(taken: t.Sequence[str], base: str) -> str: +def find_new_name(taken: t.Collection[str], base: str) -> str: """ Searches for a new name. @@ -356,6 +356,15 @@ def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Generator[t.Any, yield value +def count_params(function: t.Callable) -> int: + """ + Returns the number of formal parameters expected by a function, without counting "self" + and "cls", in case of instance and class methods, respectively. + """ + count = function.__code__.co_argcount + return count - 1 if inspect.ismethod(function) else count + + def dict_depth(d: t.Dict) -> int: """ Get the nesting depth of a dictionary. @@ -374,6 +383,7 @@ def dict_depth(d: t.Dict) -> int: Args: d (dict): dictionary + Returns: int: depth """ diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py index be17f15..bfb2bb8 100644 --- a/sqlglot/optimizer/annotate_types.py +++ b/sqlglot/optimizer/annotate_types.py @@ -43,7 +43,7 @@ class TypeAnnotator: }, exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]), exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]), - exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr), + exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.copy()), exp.Alias: lambda self, expr: self._annotate_unary(expr), exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN), diff --git a/sqlglot/optimizer/eliminate_joins.py b/sqlglot/optimizer/eliminate_joins.py index 3b40710..8e6a520 100644 --- a/sqlglot/optimizer/eliminate_joins.py +++ b/sqlglot/optimizer/eliminate_joins.py @@ -57,7 +57,7 @@ def _join_is_used(scope, join, alias): # But columns in the ON clause shouldn't count. 
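+    # e.g. in SELECT x.a FROM x JOIN y ON x.b = y.b, every reference to y sits
+    # in the ON clause, so this helper reports the join as unused.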
on = join.args.get("on") if on: - on_clause_columns = set(id(column) for column in on.find_all(exp.Column)) + on_clause_columns = {id(column) for column in on.find_all(exp.Column)} else: on_clause_columns = set() return any( @@ -71,7 +71,7 @@ def _is_joined_on_all_unique_outputs(scope, join): return False _, join_keys, _ = join_condition(join) - remaining_unique_outputs = unique_outputs - set(c.name for c in join_keys) + remaining_unique_outputs = unique_outputs - {c.name for c in join_keys} return not remaining_unique_outputs diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py index 9ae4966..16aaf17 100644 --- a/sqlglot/optimizer/merge_subqueries.py +++ b/sqlglot/optimizer/merge_subqueries.py @@ -67,11 +67,9 @@ def merge_ctes(expression, leave_tables_isolated=False): singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1] for outer_scope, inner_scope, table in singular_cte_selections: - inner_select = inner_scope.expression.unnest() from_or_join = table.find_ancestor(exp.From, exp.Join) - if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): + if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join): alias = table.alias_or_name - _rename_inner_sources(outer_scope, inner_scope, alias) _merge_from(outer_scope, inner_scope, table, alias) _merge_expressions(outer_scope, inner_scope, alias) @@ -80,18 +78,17 @@ def merge_ctes(expression, leave_tables_isolated=False): _merge_order(outer_scope, inner_scope) _merge_hints(outer_scope, inner_scope) _pop_cte(inner_scope) + outer_scope.clear_cache() return expression def merge_derived_tables(expression, leave_tables_isolated=False): for outer_scope in traverse_scope(expression): for subquery in outer_scope.derived_tables: - inner_select = subquery.unnest() from_or_join = subquery.find_ancestor(exp.From, exp.Join) - if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): - alias = subquery.alias_or_name - inner_scope = outer_scope.sources[alias] - + alias = subquery.alias_or_name + inner_scope = outer_scope.sources[alias] + if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join): _rename_inner_sources(outer_scope, inner_scope, alias) _merge_from(outer_scope, inner_scope, subquery, alias) _merge_expressions(outer_scope, inner_scope, alias) @@ -99,21 +96,23 @@ def merge_derived_tables(expression, leave_tables_isolated=False): _merge_where(outer_scope, inner_scope, from_or_join) _merge_order(outer_scope, inner_scope) _merge_hints(outer_scope, inner_scope) + outer_scope.clear_cache() return expression -def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): +def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join): """ Return True if `inner_select` can be merged into outer query. Args: outer_scope (Scope) - inner_select (exp.Select) + inner_scope (Scope) leave_tables_isolated (bool) from_or_join (exp.From|exp.Join) Returns: bool: True if can be merged """ + inner_select = inner_scope.expression.unnest() def _is_a_window_expression_in_unmergable_operation(): window_expressions = inner_select.find_all(exp.Window) @@ -133,10 +132,40 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): ] return any(window_expressions_in_unmergable) + def _outer_select_joins_on_inner_select_join(): + """ + All columns from the inner select in the ON clause must be from the first FROM table. 
+ + That is, this can be merged: + SELECT * FROM x JOIN (SELECT y.a AS a FROM y JOIN z) AS q ON x.a = q.a + ^^^ ^ + But this can't: + SELECT * FROM x JOIN (SELECT z.a AS a FROM y JOIN z) AS q ON x.a = q.a + ^^^ ^ + """ + if not isinstance(from_or_join, exp.Join): + return False + + alias = from_or_join.this.alias_or_name + + on = from_or_join.args.get("on") + if not on: + return False + selections = [c.name for c in on.find_all(exp.Column) if c.table == alias] + inner_from = inner_scope.expression.args.get("from") + if not inner_from: + return False + inner_from_table = inner_from.expressions[0].alias_or_name + inner_projections = {s.alias_or_name: s for s in inner_scope.selects} + return any( + col.table != inner_from_table + for selection in selections + for col in inner_projections[selection].find_all(exp.Column) + ) + return ( isinstance(outer_scope.expression, exp.Select) and isinstance(inner_select, exp.Select) - and isinstance(inner_select, exp.Select) and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS) and inner_select.args.get("from") and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions) @@ -153,6 +182,7 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join): j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", []) ) ) + and not _outer_select_joins_on_inner_select_join() and not _is_a_window_expression_in_unmergable_operation() ) @@ -168,7 +198,7 @@ def _rename_inner_sources(outer_scope, inner_scope, alias): """ taken = set(outer_scope.selected_sources) conflicts = taken.intersection(set(inner_scope.selected_sources)) - conflicts = conflicts - {alias} + conflicts -= {alias} for conflict in conflicts: new_name = find_new_name(taken, conflict) diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index 72e67d4..46b6b30 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -15,6 +15,7 @@ from sqlglot.optimizer.pushdown_projections import pushdown_projections from sqlglot.optimizer.qualify_columns import qualify_columns from sqlglot.optimizer.qualify_tables import qualify_tables from sqlglot.optimizer.unnest_subqueries import unnest_subqueries +from sqlglot.schema import ensure_schema RULES = ( lower_identities, @@ -51,12 +52,13 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar If no schema is provided then the default schema defined at `sqlgot.schema` will be used db (str): specify the default database, as might be set by a `USE DATABASE db` statement catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement - rules (list): sequence of optimizer rules to use + rules (sequence): sequence of optimizer rules to use **kwargs: If a rule has a keyword argument with a same name in **kwargs, it will be passed in. 
Returns: sqlglot.Expression: optimized expression """ - possible_kwargs = {"db": db, "catalog": catalog, "schema": schema or sqlglot.schema, **kwargs} + schema = ensure_schema(schema or sqlglot.schema) + possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs} expression = expression.copy() for rule in rules: diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py index 49789ac..a73647c 100644 --- a/sqlglot/optimizer/pushdown_projections.py +++ b/sqlglot/optimizer/pushdown_projections.py @@ -79,6 +79,7 @@ def _remove_unused_selections(scope, parent_selections): order_refs = set() new_selections = [] + removed = False for i, selection in enumerate(scope.selects): if ( SELECT_ALL in parent_selections @@ -88,12 +89,15 @@ def _remove_unused_selections(scope, parent_selections): new_selections.append(selection) else: removed_indexes.append(i) + removed = True # If there are no remaining selections, just select a single constant if not new_selections: new_selections.append(DEFAULT_SELECTION.copy()) scope.expression.set("expressions", new_selections) + if removed: + scope.clear_cache() return removed_indexes diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index e16a635..f4568c2 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -365,9 +365,9 @@ class _Resolver: def all_columns(self): """All available columns of all sources in this scope""" if self._all_columns is None: - self._all_columns = set( + self._all_columns = { column for columns in self._get_all_source_columns().values() for column in columns - ) + } return self._all_columns def get_source_columns(self, name, only_visible=False): diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index c0719f2..f560760 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -361,7 +361,7 @@ def _simplify_binary(expression, a, b): return boolean elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval): a, b = extract_date(a), extract_interval(b) - if b: + if a and b: if isinstance(expression, exp.Add): return date_literal(a + b) if isinstance(expression, exp.Sub): @@ -369,7 +369,7 @@ def _simplify_binary(expression, a, b): elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast): a, b = extract_interval(a), extract_date(b) # you cannot subtract a date from an interval - if a and isinstance(expression, exp.Add): + if a and b and isinstance(expression, exp.Add): return date_literal(a + b) return None @@ -424,9 +424,15 @@ def eval_boolean(expression, a, b): def extract_date(cast): - if cast.args["to"].this == exp.DataType.Type.DATE: - return datetime.date.fromisoformat(cast.name) - return None + # The "fromisoformat" conversion could fail if the cast is used on an identifier, + # so in that case we can't extract the date. 
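+    # e.g. CAST('2023-01-17' AS DATE) yields datetime.date(2023, 1, 17), while
+    # CAST(col AS DATE) has the non-ISO name "col" and falls into the except branch.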
+    try:
+        if cast.args["to"].this == exp.DataType.Type.DATE:
+            return datetime.date.fromisoformat(cast.name)
+        if cast.args["to"].this == exp.DataType.Type.DATETIME:
+            return datetime.datetime.fromisoformat(cast.name)
+    except ValueError:
+        return None
 
 
 def extract_interval(interval):
@@ -450,7 +456,8 @@ def extract_interval(interval):
 
 
 def date_literal(date):
-    return exp.Cast(this=exp.Literal.string(date), to=exp.DataType.build("DATE"))
+    expr_type = exp.DataType.build("DATETIME" if isinstance(date, datetime.datetime) else "DATE")
+    return exp.Cast(this=exp.Literal.string(date), to=expr_type)
 
 
 def boolean_literal(condition):
diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py
index 8d78294..a515489 100644
--- a/sqlglot/optimizer/unnest_subqueries.py
+++ b/sqlglot/optimizer/unnest_subqueries.py
@@ -15,8 +15,7 @@ def unnest_subqueries(expression):
         >>> import sqlglot
         >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
         >>> unnest_subqueries(expression).sql()
-        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a)\
-        AS _u_0 ON x.a = _u_0.a WHERE (_u_0.a = 1 AND NOT _u_0.a IS NULL)'
+        'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'
 
     Args:
         expression (sqlglot.Expression): expression to unnest
@@ -173,10 +172,8 @@ def decorrelate(select, parent_select, external_columns, sequence):
     other = _other_operand(parent_predicate)
 
     if isinstance(parent_predicate, exp.Exists):
-        if value.this in group_by:
-            parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
-        else:
-            parent_predicate = _replace(parent_predicate, "TRUE")
+        alias = exp.column(list(key_aliases.values())[0], table_alias)
+        parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
     elif isinstance(parent_predicate, exp.All):
         parent_predicate = _replace(
             parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
@@ -197,6 +194,23 @@ def decorrelate(select, parent_select, external_columns, sequence):
     else:
         if is_subquery_projection:
             alias = exp.alias_(alias, select.parent.alias)
+
+        # COUNT always returns 0 on empty datasets, so we need to take that into consideration here
+        # by transforming all counts into 0 and using that as the coalesced value
+        if value.find(exp.Count):
+
+            def remove_aggs(node):
+                if isinstance(node, exp.Count):
+                    return exp.Literal.number(0)
+                elif isinstance(node, exp.AggFunc):
+                    return exp.null()
+                return node
+
+            alias = exp.Coalesce(
+                this=alias,
+                expressions=[value.this.transform(remove_aggs)],
+            )
+
         select.parent.replace(alias)
 
     for key, column, predicate in keys:
@@ -209,9 +223,6 @@ def decorrelate(select, parent_select, external_columns, sequence):
 
             if key in group_by:
                 key.replace(nested)
-                parent_predicate = _replace(
-                    parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)"
-                )
             elif isinstance(predicate, exp.EQ):
                 parent_predicate = _replace(
                     parent_predicate,
@@ -245,7 +256,14 @@ def _other_operand(expression):
     if isinstance(expression, exp.In):
         return expression.this
 
+    if isinstance(expression, (exp.Any, exp.All)):
+        return _other_operand(expression.parent)
+
     if isinstance(expression, exp.Binary):
-        return expression.right if expression.arg_key == "this" else expression.left
+        return (
+            expression.right
+            if isinstance(expression.left, (exp.Subquery, exp.Any, exp.Exists, exp.All))
+            else expression.left
+        )
 
     return None
diff --git a/sqlglot/parser.py
b/sqlglot/parser.py index 308f363..bd95db8 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -5,7 +5,13 @@ import typing as t from sqlglot import exp from sqlglot.errors import ErrorLevel, ParseError, concat_messages, merge_errors -from sqlglot.helper import apply_index_offset, ensure_collection, ensure_list, seq_get +from sqlglot.helper import ( + apply_index_offset, + count_params, + ensure_collection, + ensure_list, + seq_get, +) from sqlglot.tokens import Token, Tokenizer, TokenType from sqlglot.trie import in_trie, new_trie @@ -54,7 +60,7 @@ class Parser(metaclass=_Parser): Default: "nulls_are_small" """ - FUNCTIONS = { + FUNCTIONS: t.Dict[str, t.Callable] = { **{name: f.from_arg_list for f in exp.ALL_FUNCTIONS for name in f.sql_names()}, "DATE_TO_DATE_STR": lambda args: exp.Cast( this=seq_get(args, 0), @@ -106,6 +112,7 @@ class Parser(metaclass=_Parser): TokenType.JSON, TokenType.JSONB, TokenType.INTERVAL, + TokenType.TIME, TokenType.TIMESTAMP, TokenType.TIMESTAMPTZ, TokenType.TIMESTAMPLTZ, @@ -164,6 +171,7 @@ class Parser(metaclass=_Parser): TokenType.DELETE, TokenType.DESCRIBE, TokenType.DETERMINISTIC, + TokenType.DIV, TokenType.DISTKEY, TokenType.DISTSTYLE, TokenType.EXECUTE, @@ -252,6 +260,7 @@ class Parser(metaclass=_Parser): TokenType.FIRST, TokenType.FORMAT, TokenType.IDENTIFIER, + TokenType.INDEX, TokenType.ISNULL, TokenType.MERGE, TokenType.OFFSET, @@ -312,6 +321,7 @@ class Parser(metaclass=_Parser): } TIMESTAMPS = { + TokenType.TIME, TokenType.TIMESTAMP, TokenType.TIMESTAMPTZ, TokenType.TIMESTAMPLTZ, @@ -387,6 +397,7 @@ class Parser(metaclass=_Parser): } EXPRESSION_PARSERS = { + exp.Column: lambda self: self._parse_column(), exp.DataType: lambda self: self._parse_types(), exp.From: lambda self: self._parse_from(), exp.Group: lambda self: self._parse_group(), @@ -419,6 +430,7 @@ class Parser(metaclass=_Parser): TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(), TokenType.CREATE: lambda self: self._parse_create(), TokenType.DELETE: lambda self: self._parse_delete(), + TokenType.DESC: lambda self: self._parse_describe(), TokenType.DESCRIBE: lambda self: self._parse_describe(), TokenType.DROP: lambda self: self._parse_drop(), TokenType.END: lambda self: self._parse_commit_or_rollback(), @@ -583,6 +595,11 @@ class Parser(metaclass=_Parser): TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"} + WINDOW_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.ROWS} + + # allows tables to have special tokens as prefixes + TABLE_PREFIX_TOKENS: t.Set[TokenType] = set() + STRICT_CAST = True __slots__ = ( @@ -608,13 +625,13 @@ class Parser(metaclass=_Parser): def __init__( self, - error_level=None, - error_message_context=100, - index_offset=0, - unnest_column_only=False, - alias_post_tablesample=False, - max_errors=3, - null_ordering=None, + error_level: t.Optional[ErrorLevel] = None, + error_message_context: int = 100, + index_offset: int = 0, + unnest_column_only: bool = False, + alias_post_tablesample: bool = False, + max_errors: int = 3, + null_ordering: t.Optional[str] = None, ): self.error_level = error_level or ErrorLevel.IMMEDIATE self.error_message_context = error_message_context @@ -636,23 +653,43 @@ class Parser(metaclass=_Parser): self._prev = None self._prev_comments = None - def parse(self, raw_tokens, sql=None): + def parse( + self, raw_tokens: t.List[Token], sql: t.Optional[str] = None + ) -> t.List[t.Optional[exp.Expression]]: """ - Parses the given list of tokens and returns a list of syntax trees, one tree + Parses a list of tokens and returns a list of 
syntax trees, one tree
         per parsed SQL statement.
 
-        Args
-            raw_tokens (list): the list of tokens (:class:`~sqlglot.tokens.Token`).
-            sql (str): the original SQL string. Used to produce helpful debug messages.
+        Args:
+            raw_tokens: the list of tokens.
+            sql: the original SQL string, used to produce helpful debug messages.
 
-        Returns
-            the list of syntax trees (:class:`~sqlglot.expressions.Expression`).
+        Returns:
+            The list of syntax trees.
         """
         return self._parse(
             parse_method=self.__class__._parse_statement, raw_tokens=raw_tokens, sql=sql
         )
 
-    def parse_into(self, expression_types, raw_tokens, sql=None):
+    def parse_into(
+        self,
+        expression_types: str | exp.Expression | t.Collection[exp.Expression | str],
+        raw_tokens: t.List[Token],
+        sql: t.Optional[str] = None,
+    ) -> t.List[t.Optional[exp.Expression]]:
+        """
+        Parses a list of tokens into a given Expression type. If a collection of Expression
+        types is given instead, this method will try to parse the token list into each one
+        of them, stopping at the first for which the parsing succeeds.
+
+        Args:
+            expression_types: the expression type(s) to try and parse the token list into.
+            raw_tokens: the list of tokens.
+            sql: the original SQL string, used to produce helpful debug messages.
+
+        Returns:
+            The target Expression.
+        """
         errors = []
         for expression_type in ensure_collection(expression_types):
             parser = self.EXPRESSION_PARSERS.get(expression_type)
@@ -668,7 +705,12 @@ class Parser(metaclass=_Parser):
             errors=merge_errors(errors),
         ) from errors[-1]
 
-    def _parse(self, parse_method, raw_tokens, sql=None):
+    def _parse(
+        self,
+        parse_method: t.Callable[[Parser], t.Optional[exp.Expression]],
+        raw_tokens: t.List[Token],
+        sql: t.Optional[str] = None,
+    ) -> t.List[t.Optional[exp.Expression]]:
         self.reset()
         self.sql = sql or ""
         total = len(raw_tokens)
@@ -686,6 +728,7 @@ class Parser(metaclass=_Parser):
             self._index = -1
             self._tokens = tokens
             self._advance()
+
             expressions.append(parse_method(self))
 
             if self._index < len(self._tokens):
@@ -695,7 +738,10 @@ class Parser(metaclass=_Parser):
 
         return expressions
 
-    def check_errors(self):
+    def check_errors(self) -> None:
+        """
+        Logs or raises any found errors, depending on the chosen error level setting.
+        """
         if self.error_level == ErrorLevel.WARN:
             for error in self.errors:
                 logger.error(str(error))
@@ -705,13 +751,18 @@ class Parser(metaclass=_Parser):
                 errors=merge_errors(self.errors),
             )
 
-    def raise_error(self, message, token=None):
+    def raise_error(self, message: str, token: t.Optional[Token] = None) -> None:
+        """
+        Appends an error to the list of recorded errors or raises it, depending on the chosen
+        error level setting.
+        """
         token = token or self._curr or self._prev or Token.string("")
         start = self._find_token(token, self.sql)
         end = start + len(token.text)
         start_context = self.sql[max(start - self.error_message_context, 0) : start]
         highlight = self.sql[start:end]
         end_context = self.sql[end : end + self.error_message_context]
+
         error = ParseError.new(
             f"{message}. Line {token.line}, Col: {token.col}.\n"
             f"  {start_context}\033[4m{highlight}\033[0m{end_context}",
             description=message,
             line=token.line,
             col=token.col,
             start_context=start_context,
             highlight=highlight,
             end_context=end_context,
         )
+
         if self.error_level == ErrorLevel.IMMEDIATE:
             raise error
+
         self.errors.append(error)
 
-    def expression(self, exp_class, comments=None, **kwargs):
+    def expression(
+        self, exp_class: t.Type[exp.Expression], comments: t.Optional[t.List[str]] = None, **kwargs
+    ) -> exp.Expression:
+        """
+        Creates a new, validated Expression.
+ + Args: + exp_class: the expression class to instantiate. + comments: an optional list of comments to attach to the expression. + kwargs: the arguments to set for the expression along with their respective values. + + Returns: + The target expression. + """ instance = exp_class(**kwargs) if self._prev_comments: instance.comments = self._prev_comments @@ -736,7 +802,17 @@ class Parser(metaclass=_Parser): self.validate_expression(instance) return instance - def validate_expression(self, expression, args=None): + def validate_expression( + self, expression: exp.Expression, args: t.Optional[t.List] = None + ) -> None: + """ + Validates an already instantiated expression, making sure that all its mandatory arguments + are set. + + Args: + expression: the expression to validate. + args: an optional list of items that was used to instantiate the expression, if it's a Func. + """ if self.error_level == ErrorLevel.IGNORE: return @@ -748,13 +824,18 @@ class Parser(metaclass=_Parser): if mandatory and (v is None or (isinstance(v, list) and not v)): self.raise_error(f"Required keyword: '{k}' missing for {expression.__class__}") - if args and len(args) > len(expression.arg_types) and not expression.is_var_len_args: + if ( + args + and isinstance(expression, exp.Func) + and len(args) > len(expression.arg_types) + and not expression.is_var_len_args + ): self.raise_error( f"The number of provided arguments ({len(args)}) is greater than " f"the maximum number of supported arguments ({len(expression.arg_types)})" ) - def _find_token(self, token, sql): + def _find_token(self, token: Token, sql: str) -> int: line = 1 col = 1 index = 0 @@ -769,7 +850,7 @@ class Parser(metaclass=_Parser): return index - def _advance(self, times=1): + def _advance(self, times: int = 1) -> None: self._index += times self._curr = seq_get(self._tokens, self._index) self._next = seq_get(self._tokens, self._index + 1) @@ -780,10 +861,10 @@ class Parser(metaclass=_Parser): self._prev = None self._prev_comments = None - def _retreat(self, index): + def _retreat(self, index: int) -> None: self._advance(index - self._index) - def _parse_statement(self): + def _parse_statement(self) -> t.Optional[exp.Expression]: if self._curr is None: return None @@ -803,7 +884,7 @@ class Parser(metaclass=_Parser): self._parse_query_modifiers(expression) return expression - def _parse_drop(self, default_kind=None): + def _parse_drop(self, default_kind: t.Optional[str] = None) -> t.Optional[exp.Expression]: temporary = self._match(TokenType.TEMPORARY) materialized = self._match(TokenType.MATERIALIZED) kind = self._match_set(self.CREATABLES) and self._prev.text @@ -812,7 +893,7 @@ class Parser(metaclass=_Parser): kind = default_kind else: self.raise_error(f"Expected {self.CREATABLES}") - return + return None return self.expression( exp.Drop, @@ -824,14 +905,14 @@ class Parser(metaclass=_Parser): cascade=self._match(TokenType.CASCADE), ) - def _parse_exists(self, not_=False): + def _parse_exists(self, not_: bool = False) -> t.Optional[bool]: return ( self._match(TokenType.IF) and (not not_ or self._match(TokenType.NOT)) and self._match(TokenType.EXISTS) ) - def _parse_create(self): + def _parse_create(self) -> t.Optional[exp.Expression]: replace = self._match_pair(TokenType.OR, TokenType.REPLACE) temporary = self._match(TokenType.TEMPORARY) transient = self._match_text_seq("TRANSIENT") @@ -846,12 +927,16 @@ class Parser(metaclass=_Parser): if not create_token: self.raise_error(f"Expected {self.CREATABLES}") - return + return None exists = 
self._parse_exists(not_=True) this = None expression = None properties = None + data = None + statistics = None + no_primary_index = None + indexes = None if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE): this = self._parse_user_defined_function() @@ -868,7 +953,28 @@ class Parser(metaclass=_Parser): this = self._parse_table(schema=True) properties = self._parse_properties() if self._match(TokenType.ALIAS): - expression = self._parse_select(nested=True) + expression = self._parse_ddl_select() + + if create_token.token_type == TokenType.TABLE: + if self._match_text_seq("WITH", "DATA"): + data = True + elif self._match_text_seq("WITH", "NO", "DATA"): + data = False + + if self._match_text_seq("AND", "STATISTICS"): + statistics = True + elif self._match_text_seq("AND", "NO", "STATISTICS"): + statistics = False + + no_primary_index = self._match_text_seq("NO", "PRIMARY", "INDEX") + + indexes = [] + while True: + index = self._parse_create_table_index() + if not index: + break + else: + indexes.append(index) return self.expression( exp.Create, @@ -883,9 +989,13 @@ class Parser(metaclass=_Parser): replace=replace, unique=unique, materialized=materialized, + data=data, + statistics=statistics, + no_primary_index=no_primary_index, + indexes=indexes, ) - def _parse_property(self): + def _parse_property(self) -> t.Optional[exp.Expression]: if self._match_set(self.PROPERTY_PARSERS): return self.PROPERTY_PARSERS[self._prev.token_type](self) @@ -906,7 +1016,7 @@ class Parser(metaclass=_Parser): return None - def _parse_property_assignment(self, exp_class): + def _parse_property_assignment(self, exp_class: t.Type[exp.Expression]) -> exp.Expression: self._match(TokenType.EQ) self._match(TokenType.ALIAS) return self.expression( @@ -914,42 +1024,50 @@ class Parser(metaclass=_Parser): this=self._parse_var_or_string() or self._parse_number() or self._parse_id_var(), ) - def _parse_partitioned_by(self): + def _parse_partitioned_by(self) -> exp.Expression: self._match(TokenType.EQ) return self.expression( exp.PartitionedByProperty, this=self._parse_schema() or self._parse_bracket(self._parse_field()), ) - def _parse_distkey(self): + def _parse_distkey(self) -> exp.Expression: return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_id_var)) - def _parse_create_like(self): + def _parse_create_like(self) -> t.Optional[exp.Expression]: table = self._parse_table(schema=True) options = [] while self._match_texts(("INCLUDING", "EXCLUDING")): + this = self._prev.text.upper() + id_var = self._parse_id_var() + + if not id_var: + return None + options.append( self.expression( exp.Property, - this=self._prev.text.upper(), - value=exp.Var(this=self._parse_id_var().this.upper()), + this=this, + value=exp.Var(this=id_var.this.upper()), ) ) return self.expression(exp.LikeProperty, this=table, expressions=options) - def _parse_sortkey(self, compound=False): + def _parse_sortkey(self, compound: bool = False) -> exp.Expression: return self.expression( exp.SortKeyProperty, this=self._parse_wrapped_csv(self._parse_id_var), compound=compound ) - def _parse_character_set(self, default=False): + def _parse_character_set(self, default: bool = False) -> exp.Expression: self._match(TokenType.EQ) return self.expression( exp.CharacterSetProperty, this=self._parse_var_or_string(), default=default ) - def _parse_returns(self): + def _parse_returns(self) -> exp.Expression: + value: t.Optional[exp.Expression] is_table = self._match(TokenType.TABLE) + if is_table: if self._match(TokenType.LT): 
value = self.expression( @@ -960,13 +1078,13 @@ class Parser(metaclass=_Parser): if not self._match(TokenType.GT): self.raise_error("Expecting >") else: - value = self._parse_schema("TABLE") + value = self._parse_schema(exp.Literal.string("TABLE")) else: value = self._parse_types() return self.expression(exp.ReturnsProperty, this=value, is_table=is_table) - def _parse_properties(self): + def _parse_properties(self) -> t.Optional[exp.Expression]: properties = [] while True: @@ -978,15 +1096,21 @@ class Parser(metaclass=_Parser): if properties: return self.expression(exp.Properties, expressions=properties) + return None - def _parse_describe(self): - self._match(TokenType.TABLE) - return self.expression(exp.Describe, this=self._parse_id_var()) + def _parse_describe(self) -> exp.Expression: + kind = self._match_set(self.CREATABLES) and self._prev.text + this = self._parse_table() - def _parse_insert(self): + return self.expression(exp.Describe, this=this, kind=kind) + + def _parse_insert(self) -> exp.Expression: overwrite = self._match(TokenType.OVERWRITE) local = self._match(TokenType.LOCAL) + + this: t.Optional[exp.Expression] + if self._match_text_seq("DIRECTORY"): this = self.expression( exp.Directory, @@ -998,21 +1122,22 @@ class Parser(metaclass=_Parser): self._match(TokenType.INTO) self._match(TokenType.TABLE) this = self._parse_table(schema=True) + return self.expression( exp.Insert, this=this, exists=self._parse_exists(), partition=self._parse_partition(), - expression=self._parse_select(nested=True), + expression=self._parse_ddl_select(), overwrite=overwrite, ) - def _parse_row(self): + def _parse_row(self) -> t.Optional[exp.Expression]: if not self._match(TokenType.FORMAT): return None return self._parse_row_format() - def _parse_row_format(self, match_row=False): + def _parse_row_format(self, match_row: bool = False) -> t.Optional[exp.Expression]: if match_row and not self._match_pair(TokenType.ROW, TokenType.FORMAT): return None @@ -1035,9 +1160,10 @@ class Parser(metaclass=_Parser): kwargs["lines"] = self._parse_string() if self._match_text_seq("NULL", "DEFINED", "AS"): kwargs["null"] = self._parse_string() - return self.expression(exp.RowFormatDelimitedProperty, **kwargs) - def _parse_load_data(self): + return self.expression(exp.RowFormatDelimitedProperty, **kwargs) # type: ignore + + def _parse_load_data(self) -> exp.Expression: local = self._match(TokenType.LOCAL) self._match_text_seq("INPATH") inpath = self._parse_string() @@ -1055,7 +1181,7 @@ class Parser(metaclass=_Parser): serde=self._match_text_seq("SERDE") and self._parse_string(), ) - def _parse_delete(self): + def _parse_delete(self) -> exp.Expression: self._match(TokenType.FROM) return self.expression( @@ -1065,10 +1191,10 @@ class Parser(metaclass=_Parser): where=self._parse_where(), ) - def _parse_update(self): + def _parse_update(self) -> exp.Expression: return self.expression( exp.Update, - **{ + **{ # type: ignore "this": self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS), "expressions": self._match(TokenType.SET) and self._parse_csv(self._parse_equality), "from": self._parse_from(), @@ -1076,16 +1202,17 @@ class Parser(metaclass=_Parser): }, ) - def _parse_uncache(self): + def _parse_uncache(self) -> exp.Expression: if not self._match(TokenType.TABLE): self.raise_error("Expecting TABLE after UNCACHE") + return self.expression( exp.Uncache, exists=self._parse_exists(), this=self._parse_table(schema=True), ) - def _parse_cache(self): + def _parse_cache(self) -> exp.Expression: lazy = 
self._match(TokenType.LAZY) self._match(TokenType.TABLE) table = self._parse_table(schema=True) @@ -1108,21 +1235,23 @@ class Parser(metaclass=_Parser): expression=self._parse_select(nested=True), ) - def _parse_partition(self): + def _parse_partition(self) -> t.Optional[exp.Expression]: if not self._match(TokenType.PARTITION): return None - def parse_values(): + def parse_values() -> exp.Property: props = self._parse_csv(self._parse_var_or_string, sep=TokenType.EQ) return exp.Property(this=seq_get(props, 0), value=seq_get(props, 1)) return self.expression(exp.Partition, this=self._parse_wrapped_csv(parse_values)) - def _parse_value(self): + def _parse_value(self) -> exp.Expression: expressions = self._parse_wrapped_csv(self._parse_conjunction) return self.expression(exp.Tuple, expressions=expressions) - def _parse_select(self, nested=False, table=False): + def _parse_select( + self, nested: bool = False, table: bool = False, parse_subquery_alias: bool = True + ) -> t.Optional[exp.Expression]: cte = self._parse_with() if cte: this = self._parse_statement() @@ -1178,10 +1307,11 @@ class Parser(metaclass=_Parser): self._parse_query_modifiers(this) this = self._parse_set_operations(this) self._match_r_paren() + # early return so that subquery unions aren't parsed again # SELECT * FROM (SELECT 1) UNION ALL SELECT 1 # Union ALL should be a property of the top select node, not the subquery - return self._parse_subquery(this) + return self._parse_subquery(this, parse_alias=parse_subquery_alias) elif self._match(TokenType.VALUES): if self._curr.token_type == TokenType.L_PAREN: # We don't consume the left paren because it's consumed in _parse_value @@ -1203,7 +1333,7 @@ class Parser(metaclass=_Parser): return self._parse_set_operations(this) - def _parse_with(self, skip_with_token=False): + def _parse_with(self, skip_with_token: bool = False) -> t.Optional[exp.Expression]: if not skip_with_token and not self._match(TokenType.WITH): return None @@ -1220,7 +1350,7 @@ class Parser(metaclass=_Parser): return self.expression(exp.With, expressions=expressions, recursive=recursive) - def _parse_cte(self): + def _parse_cte(self) -> exp.Expression: alias = self._parse_table_alias() if not alias or not alias.this: self.raise_error("Expected CTE to have alias") @@ -1234,7 +1364,9 @@ class Parser(metaclass=_Parser): alias=alias, ) - def _parse_table_alias(self, alias_tokens=None): + def _parse_table_alias( + self, alias_tokens: t.Optional[t.Collection[TokenType]] = None + ) -> t.Optional[exp.Expression]: any_token = self._match(TokenType.ALIAS) alias = self._parse_id_var( any_token=any_token, tokens=alias_tokens or self.TABLE_ALIAS_TOKENS @@ -1251,15 +1383,17 @@ class Parser(metaclass=_Parser): return self.expression(exp.TableAlias, this=alias, columns=columns) - def _parse_subquery(self, this): + def _parse_subquery( + self, this: t.Optional[exp.Expression], parse_alias: bool = True + ) -> exp.Expression: return self.expression( exp.Subquery, this=this, pivots=self._parse_pivots(), - alias=self._parse_table_alias(), + alias=self._parse_table_alias() if parse_alias else None, ) - def _parse_query_modifiers(self, this): + def _parse_query_modifiers(self, this: t.Optional[exp.Expression]) -> None: if not isinstance(this, self.MODIFIABLES): return @@ -1284,15 +1418,16 @@ class Parser(metaclass=_Parser): if expression: this.set(key, expression) - def _parse_hint(self): + def _parse_hint(self) -> t.Optional[exp.Expression]: if self._match(TokenType.HINT): hints = self._parse_csv(self._parse_function) if not 
self._match_pair(TokenType.STAR, TokenType.SLASH): self.raise_error("Expected */ after HINT") return self.expression(exp.Hint, expressions=hints) + return None - def _parse_into(self): + def _parse_into(self) -> t.Optional[exp.Expression]: if not self._match(TokenType.INTO): return None @@ -1304,14 +1439,15 @@ class Parser(metaclass=_Parser): exp.Into, this=self._parse_table(schema=True), temporary=temp, unlogged=unlogged ) - def _parse_from(self): + def _parse_from(self) -> t.Optional[exp.Expression]: if not self._match(TokenType.FROM): return None + return self.expression( exp.From, comments=self._prev_comments, expressions=self._parse_csv(self._parse_table) ) - def _parse_lateral(self): + def _parse_lateral(self) -> t.Optional[exp.Expression]: outer_apply = self._match_pair(TokenType.OUTER, TokenType.APPLY) cross_apply = self._match_pair(TokenType.CROSS, TokenType.APPLY) @@ -1334,6 +1470,8 @@ class Parser(metaclass=_Parser): expression=self._parse_function() or self._parse_id_var(any_token=False), ) + table_alias: t.Optional[exp.Expression] + if view: table = self._parse_id_var(any_token=False) columns = self._parse_csv(self._parse_id_var) if self._match(TokenType.ALIAS) else [] @@ -1354,20 +1492,24 @@ class Parser(metaclass=_Parser): return expression - def _parse_join_side_and_kind(self): + def _parse_join_side_and_kind( + self, + ) -> t.Tuple[t.Optional[Token], t.Optional[Token], t.Optional[Token]]: return ( self._match(TokenType.NATURAL) and self._prev, self._match_set(self.JOIN_SIDES) and self._prev, self._match_set(self.JOIN_KINDS) and self._prev, ) - def _parse_join(self, skip_join_token=False): + def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Expression]: natural, side, kind = self._parse_join_side_and_kind() if not skip_join_token and not self._match(TokenType.JOIN): return None - kwargs = {"this": self._parse_table()} + kwargs: t.Dict[ + str, t.Optional[exp.Expression] | bool | str | t.List[t.Optional[exp.Expression]] + ] = {"this": self._parse_table()} if natural: kwargs["natural"] = True @@ -1381,12 +1523,13 @@ class Parser(metaclass=_Parser): elif self._match(TokenType.USING): kwargs["using"] = self._parse_wrapped_id_vars() - return self.expression(exp.Join, **kwargs) + return self.expression(exp.Join, **kwargs) # type: ignore - def _parse_index(self): + def _parse_index(self) -> exp.Expression: index = self._parse_id_var() self._match(TokenType.ON) self._match(TokenType.TABLE) # hive + return self.expression( exp.Index, this=index, @@ -1394,7 +1537,28 @@ class Parser(metaclass=_Parser): columns=self._parse_expression(), ) - def _parse_table(self, schema=False, alias_tokens=None): + def _parse_create_table_index(self) -> t.Optional[exp.Expression]: + unique = self._match(TokenType.UNIQUE) + primary = self._match_text_seq("PRIMARY") + amp = self._match_text_seq("AMP") + if not self._match(TokenType.INDEX): + return None + index = self._parse_id_var() + columns = None + if self._curr and self._curr.token_type == TokenType.L_PAREN: + columns = self._parse_wrapped_csv(self._parse_column) + return self.expression( + exp.Index, + this=index, + columns=columns, + unique=unique, + primary=primary, + amp=amp, + ) + + def _parse_table( + self, schema: bool = False, alias_tokens: t.Optional[t.Collection[TokenType]] = None + ) -> t.Optional[exp.Expression]: lateral = self._parse_lateral() if lateral: @@ -1417,7 +1581,9 @@ class Parser(metaclass=_Parser): catalog = None db = None - table = (not schema and self._parse_function()) or self._parse_id_var(False) + 
table = (not schema and self._parse_function()) or self._parse_id_var( + any_token=False, prefix_tokens=self.TABLE_PREFIX_TOKENS + ) while self._match(TokenType.DOT): if catalog: @@ -1446,6 +1612,14 @@ class Parser(metaclass=_Parser): if alias: this.set("alias", alias) + if self._match(TokenType.WITH): + this.set( + "hints", + self._parse_wrapped_csv( + lambda: self._parse_function() or self._parse_var(any_token=True) + ), + ) + if not self.alias_post_tablesample: table_sample = self._parse_table_sample() @@ -1455,7 +1629,7 @@ class Parser(metaclass=_Parser): return this - def _parse_unnest(self): + def _parse_unnest(self) -> t.Optional[exp.Expression]: if not self._match(TokenType.UNNEST): return None @@ -1473,7 +1647,7 @@ class Parser(metaclass=_Parser): exp.Unnest, expressions=expressions, ordinality=ordinality, alias=alias ) - def _parse_derived_table_values(self): + def _parse_derived_table_values(self) -> t.Optional[exp.Expression]: is_derived = self._match_pair(TokenType.L_PAREN, TokenType.VALUES) if not is_derived and not self._match(TokenType.VALUES): return None @@ -1485,7 +1659,7 @@ class Parser(metaclass=_Parser): return self.expression(exp.Values, expressions=expressions, alias=self._parse_table_alias()) - def _parse_table_sample(self): + def _parse_table_sample(self) -> t.Optional[exp.Expression]: if not self._match(TokenType.TABLE_SAMPLE): return None @@ -1533,10 +1707,10 @@ class Parser(metaclass=_Parser): seed=seed, ) - def _parse_pivots(self): + def _parse_pivots(self) -> t.List[t.Optional[exp.Expression]]: return list(iter(self._parse_pivot, None)) - def _parse_pivot(self): + def _parse_pivot(self) -> t.Optional[exp.Expression]: index = self._index if self._match(TokenType.PIVOT): @@ -1572,16 +1746,18 @@ class Parser(metaclass=_Parser): return self.expression(exp.Pivot, expressions=expressions, field=field, unpivot=unpivot) - def _parse_where(self, skip_where_token=False): + def _parse_where(self, skip_where_token: bool = False) -> t.Optional[exp.Expression]: if not skip_where_token and not self._match(TokenType.WHERE): return None + return self.expression( exp.Where, comments=self._prev_comments, this=self._parse_conjunction() ) - def _parse_group(self, skip_group_by_token=False): + def _parse_group(self, skip_group_by_token: bool = False) -> t.Optional[exp.Expression]: if not skip_group_by_token and not self._match(TokenType.GROUP_BY): return None + return self.expression( exp.Group, expressions=self._parse_csv(self._parse_conjunction), @@ -1590,29 +1766,33 @@ class Parser(metaclass=_Parser): rollup=self._match(TokenType.ROLLUP) and self._parse_wrapped_id_vars(), ) - def _parse_grouping_sets(self): + def _parse_grouping_sets(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]: if not self._match(TokenType.GROUPING_SETS): return None + return self._parse_wrapped_csv(self._parse_grouping_set) - def _parse_grouping_set(self): + def _parse_grouping_set(self) -> t.Optional[exp.Expression]: if self._match(TokenType.L_PAREN): grouping_set = self._parse_csv(self._parse_id_var) self._match_r_paren() return self.expression(exp.Tuple, expressions=grouping_set) + return self._parse_id_var() - def _parse_having(self, skip_having_token=False): + def _parse_having(self, skip_having_token: bool = False) -> t.Optional[exp.Expression]: if not skip_having_token and not self._match(TokenType.HAVING): return None return self.expression(exp.Having, this=self._parse_conjunction()) - def _parse_qualify(self): + def _parse_qualify(self) -> t.Optional[exp.Expression]: if not 
self._match(TokenType.QUALIFY): return None return self.expression(exp.Qualify, this=self._parse_conjunction()) - def _parse_order(self, this=None, skip_order_token=False): + def _parse_order( + self, this: t.Optional[exp.Expression] = None, skip_order_token: bool = False + ) -> t.Optional[exp.Expression]: if not skip_order_token and not self._match(TokenType.ORDER_BY): return this @@ -1620,12 +1800,14 @@ class Parser(metaclass=_Parser): exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered) ) - def _parse_sort(self, token_type, exp_class): + def _parse_sort( + self, token_type: TokenType, exp_class: t.Type[exp.Expression] + ) -> t.Optional[exp.Expression]: if not self._match(token_type): return None return self.expression(exp_class, expressions=self._parse_csv(self._parse_ordered)) - def _parse_ordered(self): + def _parse_ordered(self) -> exp.Expression: this = self._parse_conjunction() self._match(TokenType.ASC) is_desc = self._match(TokenType.DESC) @@ -1647,7 +1829,9 @@ class Parser(metaclass=_Parser): return self.expression(exp.Ordered, this=this, desc=desc, nulls_first=nulls_first) - def _parse_limit(self, this=None, top=False): + def _parse_limit( + self, this: t.Optional[exp.Expression] = None, top: bool = False + ) -> t.Optional[exp.Expression]: if self._match(TokenType.TOP if top else TokenType.LIMIT): limit_paren = self._match(TokenType.L_PAREN) limit_exp = self.expression(exp.Limit, this=this, expression=self._parse_number()) @@ -1667,7 +1851,7 @@ class Parser(metaclass=_Parser): return this - def _parse_offset(self, this=None): + def _parse_offset(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]: if not self._match_set((TokenType.OFFSET, TokenType.COMMA)): return this @@ -1675,7 +1859,7 @@ class Parser(metaclass=_Parser): self._match_set((TokenType.ROW, TokenType.ROWS)) return self.expression(exp.Offset, this=this, expression=count) - def _parse_set_operations(self, this): + def _parse_set_operations(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: if not self._match_set(self.SET_OPERATIONS): return this @@ -1695,19 +1879,19 @@ class Parser(metaclass=_Parser): expression=self._parse_select(nested=True), ) - def _parse_expression(self): + def _parse_expression(self) -> t.Optional[exp.Expression]: return self._parse_alias(self._parse_conjunction()) - def _parse_conjunction(self): + def _parse_conjunction(self) -> t.Optional[exp.Expression]: return self._parse_tokens(self._parse_equality, self.CONJUNCTION) - def _parse_equality(self): + def _parse_equality(self) -> t.Optional[exp.Expression]: return self._parse_tokens(self._parse_comparison, self.EQUALITY) - def _parse_comparison(self): + def _parse_comparison(self) -> t.Optional[exp.Expression]: return self._parse_tokens(self._parse_range, self.COMPARISON) - def _parse_range(self): + def _parse_range(self) -> t.Optional[exp.Expression]: this = self._parse_bitwise() negate = self._match(TokenType.NOT) @@ -1730,7 +1914,7 @@ class Parser(metaclass=_Parser): return this - def _parse_is(self, this): + def _parse_is(self, this: t.Optional[exp.Expression]) -> exp.Expression: negate = self._match(TokenType.NOT) if self._match(TokenType.DISTINCT_FROM): klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ @@ -1743,7 +1927,7 @@ class Parser(metaclass=_Parser): ) return self.expression(exp.Not, this=this) if negate else this - def _parse_in(self, this): + def _parse_in(self, this: t.Optional[exp.Expression]) -> exp.Expression: unnest = self._parse_unnest() if unnest: this = 
self.expression(exp.In, this=this, unnest=unnest) @@ -1761,18 +1945,18 @@ class Parser(metaclass=_Parser): return this - def _parse_between(self, this): + def _parse_between(self, this: exp.Expression) -> exp.Expression: low = self._parse_bitwise() self._match(TokenType.AND) high = self._parse_bitwise() return self.expression(exp.Between, this=this, low=low, high=high) - def _parse_escape(self, this): + def _parse_escape(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: if not self._match(TokenType.ESCAPE): return this return self.expression(exp.Escape, this=this, expression=self._parse_string()) - def _parse_bitwise(self): + def _parse_bitwise(self) -> t.Optional[exp.Expression]: this = self._parse_term() while True: @@ -1795,18 +1979,18 @@ class Parser(metaclass=_Parser): return this - def _parse_term(self): + def _parse_term(self) -> t.Optional[exp.Expression]: return self._parse_tokens(self._parse_factor, self.TERM) - def _parse_factor(self): + def _parse_factor(self) -> t.Optional[exp.Expression]: return self._parse_tokens(self._parse_unary, self.FACTOR) - def _parse_unary(self): + def _parse_unary(self) -> t.Optional[exp.Expression]: if self._match_set(self.UNARY_PARSERS): return self.UNARY_PARSERS[self._prev.token_type](self) return self._parse_at_time_zone(self._parse_type()) - def _parse_type(self): + def _parse_type(self) -> t.Optional[exp.Expression]: if self._match(TokenType.INTERVAL): return self.expression(exp.Interval, this=self._parse_term(), unit=self._parse_var()) @@ -1824,7 +2008,7 @@ class Parser(metaclass=_Parser): return this - def _parse_types(self, check_func=False): + def _parse_types(self, check_func: bool = False) -> t.Optional[exp.Expression]: index = self._index if not self._match_set(self.TYPE_TOKENS): @@ -1875,7 +2059,7 @@ class Parser(metaclass=_Parser): if not self._match(TokenType.GT): self.raise_error("Expecting >") - value = None + value: t.Optional[exp.Expression] = None if type_token in self.TIMESTAMPS: if self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ: value = exp.DataType(this=exp.DataType.Type.TIMESTAMPTZ, expressions=expressions) @@ -1884,7 +2068,10 @@ class Parser(metaclass=_Parser): ): value = exp.DataType(this=exp.DataType.Type.TIMESTAMPLTZ, expressions=expressions) elif self._match(TokenType.WITHOUT_TIME_ZONE): - value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions) + if type_token == TokenType.TIME: + value = exp.DataType(this=exp.DataType.Type.TIME, expressions=expressions) + else: + value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions) maybe_func = maybe_func and value is None @@ -1912,7 +2099,7 @@ class Parser(metaclass=_Parser): nested=nested, ) - def _parse_struct_kwargs(self): + def _parse_struct_kwargs(self) -> t.Optional[exp.Expression]: this = self._parse_id_var() self._match(TokenType.COLON) data_type = self._parse_types() @@ -1921,12 +2108,12 @@ class Parser(metaclass=_Parser): return None return self.expression(exp.StructKwarg, this=this, expression=data_type) - def _parse_at_time_zone(self, this): + def _parse_at_time_zone(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: if not self._match(TokenType.AT_TIME_ZONE): return this return self.expression(exp.AtTimeZone, this=this, zone=self._parse_unary()) - def _parse_column(self): + def _parse_column(self) -> t.Optional[exp.Expression]: this = self._parse_field() if isinstance(this, exp.Identifier): this = self.expression(exp.Column, this=this) @@ -1943,7 
+2130,8 @@ class Parser(metaclass=_Parser): if not field: self.raise_error("Expected type") elif op: - field = exp.Literal.string(self._advance() or self._prev.text) + self._advance() + field = exp.Literal.string(self._prev.text) else: field = self._parse_star() or self._parse_function() or self._parse_id_var() @@ -1963,7 +2151,7 @@ class Parser(metaclass=_Parser): return this - def _parse_primary(self): + def _parse_primary(self) -> t.Optional[exp.Expression]: if self._match_set(self.PRIMARY_PARSERS): token_type = self._prev.token_type primary = self.PRIMARY_PARSERS[token_type](self, self._prev) @@ -1995,21 +2183,27 @@ class Parser(metaclass=_Parser): self._match_r_paren() if isinstance(this, exp.Subqueryable): - this = self._parse_set_operations(self._parse_subquery(this)) + this = self._parse_set_operations( + self._parse_subquery(this=this, parse_alias=False) + ) elif len(expressions) > 1: this = self.expression(exp.Tuple, expressions=expressions) else: this = self.expression(exp.Paren, this=this) - if comments: + + if this and comments: this.comments = comments + return this return None - def _parse_field(self, any_token=False): + def _parse_field(self, any_token: bool = False) -> t.Optional[exp.Expression]: return self._parse_primary() or self._parse_function() or self._parse_id_var(any_token) - def _parse_function(self, functions=None): + def _parse_function( + self, functions: t.Optional[t.Dict[str, t.Callable]] = None + ) -> t.Optional[exp.Expression]: if not self._curr: return None @@ -2020,7 +2214,9 @@ class Parser(metaclass=_Parser): if not self._next or self._next.token_type != TokenType.L_PAREN: if token_type in self.NO_PAREN_FUNCTIONS: - return self.expression(self._advance() or self.NO_PAREN_FUNCTIONS[token_type]) + self._advance() + return self.expression(self.NO_PAREN_FUNCTIONS[token_type]) + return None if token_type not in self.FUNC_TOKENS: @@ -2049,7 +2245,18 @@ class Parser(metaclass=_Parser): args = self._parse_csv(self._parse_lambda) if function: - this = function(args) + + # Clickhouse supports function calls like foo(x, y)(z), so for these we need to also parse the + # second parameter list (i.e. "(z)") and the corresponding function will receive both arg lists. 
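+            # Concretely, ClickHouse's quantiles(0.5, 0.9)(x) yields two lists:
+            # the first was parsed into `args` above, and the trailing "(x)" is
+            # parsed into `params` below; both are passed to the builder lambda.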
+ if count_params(function) == 2: + params = None + if self._match_pair(TokenType.R_PAREN, TokenType.L_PAREN): + params = self._parse_csv(self._parse_lambda) + + this = function(args, params) + else: + this = function(args) + self.validate_expression(this, args) else: this = self.expression(exp.Anonymous, this=this, expressions=args) @@ -2057,7 +2264,7 @@ class Parser(metaclass=_Parser): self._match_r_paren(this) return self._parse_window(this) - def _parse_user_defined_function(self): + def _parse_user_defined_function(self) -> t.Optional[exp.Expression]: this = self._parse_id_var() while self._match(TokenType.DOT): @@ -2070,27 +2277,27 @@ class Parser(metaclass=_Parser): self._match_r_paren() return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions) - def _parse_introducer(self, token): + def _parse_introducer(self, token: Token) -> t.Optional[exp.Expression]: literal = self._parse_primary() if literal: return self.expression(exp.Introducer, this=token.text, expression=literal) return self.expression(exp.Identifier, this=token.text) - def _parse_national(self, token): + def _parse_national(self, token: Token) -> exp.Expression: return self.expression(exp.National, this=exp.Literal.string(token.text)) - def _parse_session_parameter(self): + def _parse_session_parameter(self) -> exp.Expression: kind = None this = self._parse_id_var() or self._parse_primary() - if self._match(TokenType.DOT): + if this and self._match(TokenType.DOT): kind = this.name this = self._parse_var() or self._parse_primary() return self.expression(exp.SessionParameter, this=this, kind=kind) - def _parse_udf_kwarg(self): + def _parse_udf_kwarg(self) -> t.Optional[exp.Expression]: this = self._parse_id_var() kind = self._parse_types() @@ -2099,7 +2306,7 @@ class Parser(metaclass=_Parser): return self.expression(exp.UserDefinedFunctionKwarg, this=this, kind=kind) - def _parse_lambda(self): + def _parse_lambda(self) -> t.Optional[exp.Expression]: index = self._index if self._match(TokenType.L_PAREN): @@ -2115,6 +2322,8 @@ class Parser(metaclass=_Parser): self._retreat(index) + this: t.Optional[exp.Expression] + if self._match(TokenType.DISTINCT): this = self.expression( exp.Distinct, expressions=self._parse_csv(self._parse_conjunction) @@ -2129,7 +2338,7 @@ class Parser(metaclass=_Parser): return self._parse_limit(self._parse_order(this)) - def _parse_schema(self, this=None): + def _parse_schema(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]: index = self._index if not self._match(TokenType.L_PAREN) or self._match(TokenType.SELECT): self._retreat(index) @@ -2140,14 +2349,15 @@ class Parser(metaclass=_Parser): or self._parse_column_def(self._parse_field(any_token=True)) ) self._match_r_paren() + + if isinstance(this, exp.Literal): + this = this.name + return self.expression(exp.Schema, this=this, expressions=args) - def _parse_column_def(self, this): + def _parse_column_def(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: kind = self._parse_types() - if not kind: - return this - constraints = [] while True: constraint = self._parse_column_constraint() @@ -2155,9 +2365,12 @@ class Parser(metaclass=_Parser): break constraints.append(constraint) + if not kind and not constraints: + return this + return self.expression(exp.ColumnDef, this=this, kind=kind, constraints=constraints) - def _parse_column_constraint(self): + def _parse_column_constraint(self) -> t.Optional[exp.Expression]: this = self._parse_references() if this: @@ -2166,6 +2379,8 @@ 
class Parser(metaclass=_Parser): if self._match(TokenType.CONSTRAINT): this = self._parse_id_var() + kind: exp.Expression + if self._match(TokenType.AUTO_INCREMENT): kind = exp.AutoIncrementColumnConstraint() elif self._match(TokenType.CHECK): @@ -2202,7 +2417,7 @@ class Parser(metaclass=_Parser): return self.expression(exp.ColumnConstraint, this=this, kind=kind) - def _parse_constraint(self): + def _parse_constraint(self) -> t.Optional[exp.Expression]: if not self._match(TokenType.CONSTRAINT): return self._parse_unnamed_constraint() @@ -2217,24 +2432,25 @@ class Parser(metaclass=_Parser): return self.expression(exp.Constraint, this=this, expressions=expressions) - def _parse_unnamed_constraint(self): + def _parse_unnamed_constraint(self) -> t.Optional[exp.Expression]: if not self._match_set(self.CONSTRAINT_PARSERS): return None return self.CONSTRAINT_PARSERS[self._prev.token_type](self) - def _parse_unique(self): + def _parse_unique(self) -> exp.Expression: return self.expression(exp.Unique, expressions=self._parse_wrapped_id_vars()) - def _parse_references(self): + def _parse_references(self) -> t.Optional[exp.Expression]: if not self._match(TokenType.REFERENCES): return None + return self.expression( exp.Reference, this=self._parse_id_var(), expressions=self._parse_wrapped_id_vars(), ) - def _parse_foreign_key(self): + def _parse_foreign_key(self) -> exp.Expression: expressions = self._parse_wrapped_id_vars() reference = self._parse_references() options = {} @@ -2260,13 +2476,15 @@ class Parser(metaclass=_Parser): exp.ForeignKey, expressions=expressions, reference=reference, - **options, + **options, # type: ignore ) - def _parse_bracket(self, this): + def _parse_bracket(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: if not self._match(TokenType.L_BRACKET): return this + expressions: t.List[t.Optional[exp.Expression]] + if self._match(TokenType.COLON): expressions = [self.expression(exp.Slice, expression=self._parse_conjunction())] else: @@ -2284,12 +2502,12 @@ class Parser(metaclass=_Parser): this.comments = self._prev_comments return self._parse_bracket(this) - def _parse_slice(self, this): + def _parse_slice(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: if self._match(TokenType.COLON): return self.expression(exp.Slice, this=this, expression=self._parse_conjunction()) return this - def _parse_case(self): + def _parse_case(self) -> t.Optional[exp.Expression]: ifs = [] default = None @@ -2311,7 +2529,7 @@ class Parser(metaclass=_Parser): self.expression(exp.Case, this=expression, ifs=ifs, default=default) ) - def _parse_if(self): + def _parse_if(self) -> t.Optional[exp.Expression]: if self._match(TokenType.L_PAREN): args = self._parse_csv(self._parse_conjunction) this = exp.If.from_arg_list(args) @@ -2324,9 +2542,10 @@ class Parser(metaclass=_Parser): false = self._parse_conjunction() if self._match(TokenType.ELSE) else None self._match(TokenType.END) this = self.expression(exp.If, this=condition, true=true, false=false) + return self._parse_window(this) - def _parse_extract(self): + def _parse_extract(self) -> exp.Expression: this = self._parse_function() or self._parse_var() or self._parse_type() if self._match(TokenType.FROM): @@ -2337,7 +2556,7 @@ class Parser(metaclass=_Parser): return self.expression(exp.Extract, this=this, expression=self._parse_bitwise()) - def _parse_cast(self, strict): + def _parse_cast(self, strict: bool) -> exp.Expression: this = self._parse_conjunction() if not self._match(TokenType.ALIAS): @@ -2353,7 +2572,9 
@@ class Parser(metaclass=_Parser): return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to) - def _parse_string_agg(self): + def _parse_string_agg(self) -> exp.Expression: + expression: t.Optional[exp.Expression] + if self._match(TokenType.DISTINCT): args = self._parse_csv(self._parse_conjunction) expression = self.expression(exp.Distinct, expressions=[seq_get(args, 0)]) @@ -2380,8 +2601,10 @@ class Parser(metaclass=_Parser): order = self._parse_order(this=expression) return self.expression(exp.GroupConcat, this=order, separator=seq_get(args, 1)) - def _parse_convert(self, strict): + def _parse_convert(self, strict: bool) -> exp.Expression: + to: t.Optional[exp.Expression] this = self._parse_column() + if self._match(TokenType.USING): to = self.expression(exp.CharacterSet, this=self._parse_var()) elif self._match(TokenType.COMMA): @@ -2390,7 +2613,7 @@ class Parser(metaclass=_Parser): to = None return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to) - def _parse_position(self): + def _parse_position(self) -> exp.Expression: args = self._parse_csv(self._parse_bitwise) if self._match(TokenType.IN): @@ -2402,11 +2625,11 @@ class Parser(metaclass=_Parser): return this - def _parse_join_hint(self, func_name): + def _parse_join_hint(self, func_name: str) -> exp.Expression: args = self._parse_csv(self._parse_table) return exp.JoinHint(this=func_name.upper(), expressions=args) - def _parse_substring(self): + def _parse_substring(self) -> exp.Expression: # Postgres supports the form: substring(string [from int] [for int]) # https://www.postgresql.org/docs/9.1/functions-string.html @ Table 9-6 @@ -2422,7 +2645,7 @@ class Parser(metaclass=_Parser): return this - def _parse_trim(self): + def _parse_trim(self) -> exp.Expression: # https://www.w3resource.com/sql/character-functions/trim.php # https://docs.oracle.com/javadb/10.8.3.0/ref/rreftrimfunc.html @@ -2450,13 +2673,15 @@ class Parser(metaclass=_Parser): collation=collation, ) - def _parse_window_clause(self): + def _parse_window_clause(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]: return self._match(TokenType.WINDOW) and self._parse_csv(self._parse_named_window) - def _parse_named_window(self): + def _parse_named_window(self) -> t.Optional[exp.Expression]: return self._parse_window(self._parse_id_var(), alias=True) - def _parse_window(self, this, alias=False): + def _parse_window( + self, this: t.Optional[exp.Expression], alias: bool = False + ) -> t.Optional[exp.Expression]: if self._match(TokenType.FILTER): where = self._parse_wrapped(self._parse_where) this = self.expression(exp.Filter, this=this, expression=where) @@ -2495,7 +2720,7 @@ class Parser(metaclass=_Parser): if not self._match(TokenType.L_PAREN): return self.expression(exp.Window, this=this, alias=self._parse_id_var(False)) - alias = self._parse_id_var(False) + window_alias = self._parse_id_var(any_token=False, tokens=self.WINDOW_ALIAS_TOKENS) partition = None if self._match(TokenType.PARTITION_BY): @@ -2529,10 +2754,10 @@ class Parser(metaclass=_Parser): partition_by=partition, order=order, spec=spec, - alias=alias, + alias=window_alias, ) - def _parse_window_spec(self): + def _parse_window_spec(self) -> t.Dict[str, t.Optional[str | exp.Expression]]: self._match(TokenType.BETWEEN) return { @@ -2543,7 +2768,9 @@ class Parser(metaclass=_Parser): "side": self._match_set((TokenType.PRECEDING, TokenType.FOLLOWING)) and self._prev.text, } - def _parse_alias(self, this, explicit=False): + def _parse_alias( + self, this: 
t.Optional[exp.Expression], explicit: bool = False + ) -> t.Optional[exp.Expression]: any_token = self._match(TokenType.ALIAS) if explicit and not any_token: @@ -2565,63 +2792,74 @@ class Parser(metaclass=_Parser): return this - def _parse_id_var(self, any_token=True, tokens=None): + def _parse_id_var( + self, + any_token: bool = True, + tokens: t.Optional[t.Collection[TokenType]] = None, + prefix_tokens: t.Optional[t.Collection[TokenType]] = None, + ) -> t.Optional[exp.Expression]: identifier = self._parse_identifier() if identifier: return identifier + prefix = "" + + if prefix_tokens: + while self._match_set(prefix_tokens): + prefix += self._prev.text + if (any_token and self._advance_any()) or self._match_set(tokens or self.ID_VAR_TOKENS): - return exp.Identifier(this=self._prev.text, quoted=False) + return exp.Identifier(this=prefix + self._prev.text, quoted=False) return None - def _parse_string(self): + def _parse_string(self) -> t.Optional[exp.Expression]: if self._match(TokenType.STRING): return self.PRIMARY_PARSERS[TokenType.STRING](self, self._prev) return self._parse_placeholder() - def _parse_number(self): + def _parse_number(self) -> t.Optional[exp.Expression]: if self._match(TokenType.NUMBER): return self.PRIMARY_PARSERS[TokenType.NUMBER](self, self._prev) return self._parse_placeholder() - def _parse_identifier(self): + def _parse_identifier(self) -> t.Optional[exp.Expression]: if self._match(TokenType.IDENTIFIER): return self.expression(exp.Identifier, this=self._prev.text, quoted=True) return self._parse_placeholder() - def _parse_var(self, any_token=False): + def _parse_var(self, any_token: bool = False) -> t.Optional[exp.Expression]: if (any_token and self._advance_any()) or self._match(TokenType.VAR): return self.expression(exp.Var, this=self._prev.text) return self._parse_placeholder() - def _advance_any(self): + def _advance_any(self) -> t.Optional[Token]: if self._curr and self._curr.token_type not in self.RESERVED_KEYWORDS: self._advance() return self._prev return None - def _parse_var_or_string(self): + def _parse_var_or_string(self) -> t.Optional[exp.Expression]: return self._parse_var() or self._parse_string() - def _parse_null(self): + def _parse_null(self) -> t.Optional[exp.Expression]: if self._match(TokenType.NULL): return self.PRIMARY_PARSERS[TokenType.NULL](self, self._prev) return None - def _parse_boolean(self): + def _parse_boolean(self) -> t.Optional[exp.Expression]: if self._match(TokenType.TRUE): return self.PRIMARY_PARSERS[TokenType.TRUE](self, self._prev) if self._match(TokenType.FALSE): return self.PRIMARY_PARSERS[TokenType.FALSE](self, self._prev) return None - def _parse_star(self): + def _parse_star(self) -> t.Optional[exp.Expression]: if self._match(TokenType.STAR): return self.PRIMARY_PARSERS[TokenType.STAR](self, self._prev) return None - def _parse_placeholder(self): + def _parse_placeholder(self) -> t.Optional[exp.Expression]: if self._match(TokenType.PLACEHOLDER): return self.expression(exp.Placeholder) elif self._match(TokenType.COLON): @@ -2630,18 +2868,20 @@ class Parser(metaclass=_Parser): self._advance(-1) return None - def _parse_except(self): + def _parse_except(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]: if not self._match(TokenType.EXCEPT): return None return self._parse_wrapped_id_vars() - def _parse_replace(self): + def _parse_replace(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]: if not self._match(TokenType.REPLACE): return None return self._parse_wrapped_csv(lambda: 
self._parse_alias(self._parse_expression())) - def _parse_csv(self, parse_method, sep=TokenType.COMMA): + def _parse_csv( + self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA + ) -> t.List[t.Optional[exp.Expression]]: parse_result = parse_method() items = [parse_result] if parse_result is not None else [] @@ -2655,7 +2895,9 @@ class Parser(metaclass=_Parser): return items - def _parse_tokens(self, parse_method, expressions): + def _parse_tokens( + self, parse_method: t.Callable, expressions: t.Dict + ) -> t.Optional[exp.Expression]: this = parse_method() while self._match_set(expressions): @@ -2668,22 +2910,29 @@ class Parser(metaclass=_Parser): return this - def _parse_wrapped_id_vars(self): + def _parse_wrapped_id_vars(self) -> t.List[t.Optional[exp.Expression]]: return self._parse_wrapped_csv(self._parse_id_var) - def _parse_wrapped_csv(self, parse_method, sep=TokenType.COMMA): + def _parse_wrapped_csv( + self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA + ) -> t.List[t.Optional[exp.Expression]]: return self._parse_wrapped(lambda: self._parse_csv(parse_method, sep=sep)) - def _parse_wrapped(self, parse_method): + def _parse_wrapped(self, parse_method: t.Callable) -> t.Any: self._match_l_paren() parse_result = parse_method() self._match_r_paren() return parse_result - def _parse_select_or_expression(self): + def _parse_select_or_expression(self) -> t.Optional[exp.Expression]: return self._parse_select() or self._parse_expression() - def _parse_transaction(self): + def _parse_ddl_select(self) -> t.Optional[exp.Expression]: + return self._parse_set_operations( + self._parse_select(nested=True, parse_subquery_alias=False) + ) + + def _parse_transaction(self) -> exp.Expression: this = None if self._match_texts(self.TRANSACTION_KIND): this = self._prev.text @@ -2703,7 +2952,7 @@ class Parser(metaclass=_Parser): return self.expression(exp.Transaction, this=this, modes=modes) - def _parse_commit_or_rollback(self): + def _parse_commit_or_rollback(self) -> exp.Expression: chain = None savepoint = None is_rollback = self._prev.token_type == TokenType.ROLLBACK @@ -2722,27 +2971,30 @@ class Parser(metaclass=_Parser): return self.expression(exp.Rollback, savepoint=savepoint) return self.expression(exp.Commit, chain=chain) - def _parse_add_column(self): + def _parse_add_column(self) -> t.Optional[exp.Expression]: if not self._match_text_seq("ADD"): return None self._match(TokenType.COLUMN) exists_column = self._parse_exists(not_=True) expression = self._parse_column_def(self._parse_field(any_token=True)) - expression.set("exists", exists_column) + + if expression: + expression.set("exists", exists_column) + return expression - def _parse_drop_column(self): + def _parse_drop_column(self) -> t.Optional[exp.Expression]: return self._match(TokenType.DROP) and self._parse_drop(default_kind="COLUMN") - def _parse_alter(self): + def _parse_alter(self) -> t.Optional[exp.Expression]: if not self._match(TokenType.TABLE): return None exists = self._parse_exists() this = self._parse_table(schema=True) - actions = None + actions: t.Optional[exp.Expression | t.List[t.Optional[exp.Expression]]] = None if self._match_text_seq("ADD", advance=False): actions = self._parse_csv(self._parse_add_column) elif self._match_text_seq("DROP", advance=False): @@ -2770,24 +3022,24 @@ class Parser(metaclass=_Parser): actions = ensure_list(actions) return self.expression(exp.AlterTable, this=this, exists=exists, actions=actions) - def _parse_show(self): - parser = self._find_parser(self.SHOW_PARSERS, 
self._show_trie) + def _parse_show(self) -> t.Optional[exp.Expression]: + parser = self._find_parser(self.SHOW_PARSERS, self._show_trie) # type: ignore if parser: return parser(self) self._advance() return self.expression(exp.Show, this=self._prev.text.upper()) - def _default_parse_set_item(self): + def _default_parse_set_item(self) -> exp.Expression: return self.expression( exp.SetItem, this=self._parse_statement(), ) - def _parse_set_item(self): - parser = self._find_parser(self.SET_PARSERS, self._set_trie) + def _parse_set_item(self) -> t.Optional[exp.Expression]: + parser = self._find_parser(self.SET_PARSERS, self._set_trie) # type: ignore return parser(self) if parser else self._default_parse_set_item() - def _parse_merge(self): + def _parse_merge(self) -> exp.Expression: self._match(TokenType.INTO) target = self._parse_table(schema=True) @@ -2835,10 +3087,12 @@ class Parser(metaclass=_Parser): expressions=whens, ) - def _parse_set(self): + def _parse_set(self) -> exp.Expression: return self.expression(exp.Set, expressions=self._parse_csv(self._parse_set_item)) - def _find_parser(self, parsers, trie): + def _find_parser( + self, parsers: t.Dict[str, t.Callable], trie: t.Dict + ) -> t.Optional[t.Callable]: index = self._index this = [] while True: diff --git a/sqlglot/schema.py b/sqlglot/schema.py index d9a4004..a0d69a7 100644 --- a/sqlglot/schema.py +++ b/sqlglot/schema.py @@ -3,6 +3,7 @@ from __future__ import annotations import abc import typing as t +import sqlglot from sqlglot import expressions as exp from sqlglot.errors import SchemaError from sqlglot.helper import dict_depth @@ -157,10 +158,10 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): visible: t.Optional[t.Dict] = None, dialect: t.Optional[str] = None, ) -> None: - super().__init__(schema) - self.visible = visible or {} self.dialect = dialect + self.visible = visible or {} self._type_mapping_cache: t.Dict[str, exp.DataType] = {} + super().__init__(self._normalize(schema or {})) @classmethod def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema: @@ -180,6 +181,33 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): } ) + def _normalize(self, schema: t.Dict) -> t.Dict: + """ + Converts all identifiers in the schema into lowercase, unless they're quoted. + + Args: + schema: the schema to normalize. + + Returns: + The normalized schema mapping. 
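+
+        For example, {"Foo": {"Id": "int"}} normalizes to {"foo": {"id": "int"}},
+        while a quoted name such as '"Foo"' keeps its original casing.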
+ """ + flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1) + + normalized_mapping: t.Dict = {} + for keys in flattened_schema: + columns = _nested_get(schema, *zip(keys, keys)) + assert columns is not None + + normalized_keys = [self._normalize_name(key) for key in keys] + for column_name, column_type in columns.items(): + _nested_set( + normalized_mapping, + normalized_keys + [self._normalize_name(column_name)], + column_type, + ) + + return normalized_mapping + def add_table( self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None ) -> None: @@ -204,6 +232,19 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): ) self.mapping_trie = self._build_trie(self.mapping) + def _normalize_name(self, name: str) -> str: + try: + identifier: t.Optional[exp.Expression] = sqlglot.parse_one( + name, read=self.dialect, into=exp.Identifier + ) + except: + identifier = exp.to_identifier(name) + assert isinstance(identifier, exp.Identifier) + + if identifier.quoted: + return identifier.name + return identifier.name.lower() + def _depth(self) -> int: # The columns themselves are a mapping, but we don't want to include those return super()._depth() - 1 diff --git a/sqlglot/serde.py b/sqlglot/serde.py new file mode 100644 index 0000000..a47ffdb --- /dev/null +++ b/sqlglot/serde.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +import typing as t + +from sqlglot import expressions as exp + +if t.TYPE_CHECKING: + JSON = t.Union[dict, list, str, float, int, bool] + Node = t.Union[t.List["Node"], exp.DataType.Type, exp.Expression, JSON] + + +def dump(node: Node) -> JSON: + """ + Recursively dump an AST into a JSON-serializable dict. + """ + if isinstance(node, list): + return [dump(i) for i in node] + if isinstance(node, exp.DataType.Type): + return { + "class": "DataType.Type", + "value": node.value, + } + if isinstance(node, exp.Expression): + klass = node.__class__.__qualname__ + if node.__class__.__module__ != exp.__name__: + klass = f"{node.__module__}.{klass}" + obj = { + "class": klass, + "args": {k: dump(v) for k, v in node.args.items() if v is not None and v != []}, + } + if node.type: + obj["type"] = node.type.sql() + if node.comments: + obj["comments"] = node.comments + return obj + return node + + +def load(obj: JSON) -> Node: + """ + Recursively load a dict (as returned by `dump`) into an AST. + """ + if isinstance(obj, list): + return [load(i) for i in obj] + if isinstance(obj, dict): + class_name = obj["class"] + + if class_name == "DataType.Type": + return exp.DataType.Type(obj["value"]) + + if "." 
in class_name: + module_path, class_name = class_name.rsplit(".", maxsplit=1) + module = __import__(module_path, fromlist=[class_name]) + else: + module = exp + + klass = getattr(module, class_name) + + expression = klass(**{k: load(v) for k, v in obj["args"].items()}) + type_ = obj.get("type") + if type_: + expression.type = exp.DataType.build(type_) + comments = obj.get("comments") + if comments: + expression.comments = load(comments) + return expression + return obj diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index 0efa7d0..8e312a7 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -86,6 +86,7 @@ class TokenType(AutoName): VARBINARY = auto() JSON = auto() JSONB = auto() + TIME = auto() TIMESTAMP = auto() TIMESTAMPTZ = auto() TIMESTAMPLTZ = auto() @@ -181,6 +182,7 @@ class TokenType(AutoName): FUNCTION = auto() FROM = auto() GENERATED = auto() + GLOBAL = auto() GROUP_BY = auto() GROUPING_SETS = auto() HAVING = auto() @@ -656,6 +658,7 @@ class Tokenizer(metaclass=_Tokenizer): "FLOAT4": TokenType.FLOAT, "FLOAT8": TokenType.DOUBLE, "DOUBLE": TokenType.DOUBLE, + "DOUBLE PRECISION": TokenType.DOUBLE, "JSON": TokenType.JSON, "CHAR": TokenType.CHAR, "NCHAR": TokenType.NCHAR, @@ -671,6 +674,7 @@ class Tokenizer(metaclass=_Tokenizer): "BLOB": TokenType.VARBINARY, "BYTEA": TokenType.VARBINARY, "VARBINARY": TokenType.VARBINARY, + "TIME": TokenType.TIME, "TIMESTAMP": TokenType.TIMESTAMP, "TIMESTAMPTZ": TokenType.TIMESTAMPTZ, "TIMESTAMPLTZ": TokenType.TIMESTAMPLTZ, @@ -721,6 +725,8 @@ class Tokenizer(metaclass=_Tokenizer): COMMENTS = ["--", ("/*", "*/")] KEYWORD_TRIE = None # autofilled + IDENTIFIER_CAN_START_WITH_DIGIT = False + __slots__ = ( "sql", "size", @@ -938,17 +944,24 @@ class Tokenizer(metaclass=_Tokenizer): elif self._peek.upper() == "E" and not scientific: # type: ignore scientific += 1 self._advance() - elif self._peek.isalpha(): # type: ignore - self._add(TokenType.NUMBER) + elif self._peek.isidentifier(): # type: ignore + number_text = self._text literal = [] - while self._peek.isalpha(): # type: ignore + while self._peek.isidentifier(): # type: ignore literal.append(self._peek.upper()) # type: ignore self._advance() + literal = "".join(literal) # type: ignore token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal)) # type: ignore + if token_type: + self._add(TokenType.NUMBER, number_text) self._add(TokenType.DCOLON, "::") return self._add(token_type, literal) # type: ignore + elif self.IDENTIFIER_CAN_START_WITH_DIGIT: + return self._add(TokenType.VAR) + + self._add(TokenType.NUMBER, number_text) return self._advance(-len(literal)) else: return self._add(TokenType.NUMBER) diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py index 99949a1..35ff75a 100644 --- a/sqlglot/transforms.py +++ b/sqlglot/transforms.py @@ -82,6 +82,27 @@ def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression: return expression +def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression: + """ + Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. + This transforms removes the precision from parameterized types in expressions. 
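+
+    For example, CAST(x AS VARCHAR(20)) becomes CAST(x AS VARCHAR); type arguments
+    that are themselves DataTypes (e.g. the INT in ARRAY<INT>) are preserved.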
+ """ + return expression.transform( + lambda node: exp.DataType( + **{ + **node.args, + "expressions": [ + node_expression + for node_expression in node.expressions + if isinstance(node_expression, exp.DataType) + ], + } + ) + if isinstance(node, exp.DataType) + else node, + ) + + def preprocess( transforms: t.List[t.Callable[[exp.Expression], exp.Expression]], to_sql: t.Callable[[Generator, exp.Expression], str], @@ -121,3 +142,6 @@ def delegate(attr: str) -> t.Callable: UNALIAS_GROUP = {exp.Group: preprocess([unalias_group], delegate("group_sql"))} ELIMINATE_DISTINCT_ON = {exp.Select: preprocess([eliminate_distinct_on], delegate("select_sql"))} +REMOVE_PRECISION_PARAMETERIZED_TYPES = { + exp.Cast: preprocess([remove_precision_parameterized_types], delegate("cast_sql")) +} diff --git a/sqlglot/trie.py b/sqlglot/trie.py index fa2aaf1..f3b1c38 100644 --- a/sqlglot/trie.py +++ b/sqlglot/trie.py @@ -52,7 +52,7 @@ def in_trie(trie: t.Dict, key: key) -> t.Tuple[int, t.Dict]: Returns: A pair `(value, subtrie)`, where `subtrie` is the sub-trie we get at the point where the search stops, and `value` - is either 0 (search was unsuccessfull), 1 (`value` is a prefix of a keyword in `trie`) or 2 (`key is in `trie`). + is either 0 (search was unsuccessful), 1 (`value` is a prefix of a keyword in `trie`) or 2 (`key is in `trie`). """ if not key: return (0, trie) -- cgit v1.2.3
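
A minimal sketch of a round trip through the new sqlglot.serde module added above,
assuming sqlglot 10.5.2 is installed; the query string and variable names are
illustrative:

    import json

    import sqlglot
    from sqlglot import expressions as exp, serde

    # Parse a statement, dump the AST to JSON-serializable primitives, and load
    # it back into an equivalent expression tree.
    ast = sqlglot.parse_one("SELECT a, b FROM t WHERE a > 1")
    payload = serde.dump(ast)
    restored = serde.load(json.loads(json.dumps(payload)))

    assert isinstance(restored, exp.Expression)
    assert restored.sql() == ast.sql()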