Diffstat (limited to 'sqlglot')
-rw-r--r-- | sqlglot/dialects/__init__.py   |   1
-rw-r--r-- | sqlglot/dialects/clickhouse.py |  13
-rw-r--r-- | sqlglot/dialects/dialect.py    |  25
-rw-r--r-- | sqlglot/dialects/doris.py      |  65
-rw-r--r-- | sqlglot/dialects/duckdb.py     |  26
-rw-r--r-- | sqlglot/dialects/hive.py       |  14
-rw-r--r-- | sqlglot/dialects/mysql.py      |   8
-rw-r--r-- | sqlglot/dialects/postgres.py   |   5
-rw-r--r-- | sqlglot/dialects/presto.py     |  10
-rw-r--r-- | sqlglot/dialects/redshift.py   |   7
-rw-r--r-- | sqlglot/dialects/spark.py      |   3
-rw-r--r-- | sqlglot/dialects/starrocks.py  |   5
-rw-r--r-- | sqlglot/executor/__init__.py   |  10
-rw-r--r-- | sqlglot/executor/table.py      |  34
-rw-r--r-- | sqlglot/expressions.py         |  58
-rw-r--r-- | sqlglot/generator.py           |  28
-rw-r--r-- | sqlglot/optimizer/simplify.py  | 108
-rw-r--r-- | sqlglot/parser.py              | 118
-rw-r--r-- | sqlglot/schema.py              |  41
-rw-r--r-- | sqlglot/tokens.py              |  13
20 files changed, 465 insertions, 127 deletions
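The headline change in this commit is a new Doris dialect (sqlglot/dialects/doris.py) built on top of MySQL, alongside tokenizer, parser, generator, optimizer and executor updates in the files listed above. As a quick orientation before the diff, here is a minimal usage sketch. It is not part of the commit; the query, table and column names are made up, and the expected outputs are inferred from the TRANSFORMS and TYPE_MAPPING tables added in doris.py, so the exact rendering may differ.

```python
# Illustrative sketch only: expected outputs are derived from the Doris
# TRANSFORMS / TYPE_MAPPING added in this commit, not from its test suite.
import sqlglot

# Doris reuses MySQL's parser, so MySQL SQL can be read and re-emitted with
# Doris-specific function names (exp.Coalesce -> NVL, exp.ArrayAgg -> COLLECT_LIST, ...).
print(sqlglot.transpile("SELECT COALESCE(a, b) FROM t", read="mysql", write="doris")[0])
# expected: SELECT NVL(a, b) FROM t

# TEXT maps to STRING, and TIMESTAMP / TIMESTAMPTZ map to DATETIME, in the Doris generator.
print(sqlglot.transpile("SELECT CAST(x AS TEXT)", read="postgres", write="doris")[0])
# expected: SELECT CAST(x AS STRING)
```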
diff --git a/sqlglot/dialects/__init__.py b/sqlglot/dialects/__init__.py index fc34262..8212669 100644 --- a/sqlglot/dialects/__init__.py +++ b/sqlglot/dialects/__init__.py @@ -60,6 +60,7 @@ from sqlglot.dialects.bigquery import BigQuery from sqlglot.dialects.clickhouse import ClickHouse from sqlglot.dialects.databricks import Databricks from sqlglot.dialects.dialect import Dialect, Dialects +from sqlglot.dialects.doris import Doris from sqlglot.dialects.drill import Drill from sqlglot.dialects.duckdb import DuckDB from sqlglot.dialects.hive import Hive diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index e6b7743..cfde5fd 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -37,17 +37,22 @@ class ClickHouse(Dialect): "ATTACH": TokenType.COMMAND, "DATETIME64": TokenType.DATETIME64, "DICTIONARY": TokenType.DICTIONARY, + "ENUM": TokenType.ENUM, + "ENUM8": TokenType.ENUM8, + "ENUM16": TokenType.ENUM16, "FINAL": TokenType.FINAL, + "FIXEDSTRING": TokenType.FIXEDSTRING, "FLOAT32": TokenType.FLOAT, "FLOAT64": TokenType.DOUBLE, "GLOBAL": TokenType.GLOBAL, - "INT128": TokenType.INT128, "INT16": TokenType.SMALLINT, "INT256": TokenType.INT256, "INT32": TokenType.INT, "INT64": TokenType.BIGINT, "INT8": TokenType.TINYINT, + "LOWCARDINALITY": TokenType.LOWCARDINALITY, "MAP": TokenType.MAP, + "NESTED": TokenType.NESTED, "TUPLE": TokenType.STRUCT, "UINT128": TokenType.UINT128, "UINT16": TokenType.USMALLINT, @@ -294,11 +299,17 @@ class ClickHouse(Dialect): exp.DataType.Type.BIGINT: "Int64", exp.DataType.Type.DATETIME64: "DateTime64", exp.DataType.Type.DOUBLE: "Float64", + exp.DataType.Type.ENUM: "Enum", + exp.DataType.Type.ENUM8: "Enum8", + exp.DataType.Type.ENUM16: "Enum16", + exp.DataType.Type.FIXEDSTRING: "FixedString", exp.DataType.Type.FLOAT: "Float32", exp.DataType.Type.INT: "Int32", exp.DataType.Type.INT128: "Int128", exp.DataType.Type.INT256: "Int256", + exp.DataType.Type.LOWCARDINALITY: "LowCardinality", exp.DataType.Type.MAP: "Map", + exp.DataType.Type.NESTED: "Nested", exp.DataType.Type.NULLABLE: "Nullable", exp.DataType.Type.SMALLINT: "Int16", exp.DataType.Type.STRUCT: "Tuple", diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 1d0584c..132496f 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -39,6 +39,7 @@ class Dialects(str, Enum): TERADATA = "teradata" TRINO = "trino" TSQL = "tsql" + Doris = "doris" class _Dialect(type): @@ -121,7 +122,7 @@ class _Dialect(type): if hasattr(subclass, name): setattr(subclass, name, value) - if not klass.STRICT_STRING_CONCAT: + if not klass.STRICT_STRING_CONCAT and klass.DPIPE_IS_STRING_CONCAT: klass.parser_class.BITWISE[TokenType.DPIPE] = exp.SafeDPipe klass.generator_class.can_identify = klass.can_identify @@ -146,6 +147,9 @@ class Dialect(metaclass=_Dialect): # Determines whether or not an unquoted identifier can start with a digit IDENTIFIERS_CAN_START_WITH_DIGIT = False + # Determines whether or not the DPIPE token ('||') is a string concatenation operator + DPIPE_IS_STRING_CONCAT = True + # Determines whether or not CONCAT's arguments must be strings STRICT_STRING_CONCAT = False @@ -460,6 +464,20 @@ def format_time_lambda( return _format_time +def time_format( + dialect: DialectType = None, +) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: + def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: + """ + Returns the time format for a given expression, unless it's equivalent 
+ to the default time format of the dialect of interest. + """ + time_format = self.format_time(expression) + return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None + + return _time_format + + def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str: """ In Hive and Spark, the PARTITIONED BY property acts as an extension of a table's schema. When the @@ -699,3 +717,8 @@ def simplify_literal(expression: E) -> E: def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]: return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1)) + + +# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects +def parse_timestamp_trunc(args: t.List) -> exp.TimestampTrunc: + return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0)) diff --git a/sqlglot/dialects/doris.py b/sqlglot/dialects/doris.py new file mode 100644 index 0000000..160c23c --- /dev/null +++ b/sqlglot/dialects/doris.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +from sqlglot import exp +from sqlglot.dialects.dialect import ( + approx_count_distinct_sql, + arrow_json_extract_sql, + parse_timestamp_trunc, + rename_func, + time_format, +) +from sqlglot.dialects.mysql import MySQL + + +class Doris(MySQL): + DATE_FORMAT = "'yyyy-MM-dd'" + DATEINT_FORMAT = "'yyyyMMdd'" + TIME_FORMAT = "'yyyy-MM-dd HH:mm:ss'" + + class Parser(MySQL.Parser): + FUNCTIONS = { + **MySQL.Parser.FUNCTIONS, + "DATE_TRUNC": parse_timestamp_trunc, + "REGEXP": exp.RegexpLike.from_arg_list, + } + + class Generator(MySQL.Generator): + CAST_MAPPING = {} + + TYPE_MAPPING = { + **MySQL.Generator.TYPE_MAPPING, + exp.DataType.Type.TEXT: "STRING", + exp.DataType.Type.TIMESTAMP: "DATETIME", + exp.DataType.Type.TIMESTAMPTZ: "DATETIME", + } + + TRANSFORMS = { + **MySQL.Generator.TRANSFORMS, + exp.ApproxDistinct: approx_count_distinct_sql, + exp.ArrayAgg: rename_func("COLLECT_LIST"), + exp.Coalesce: rename_func("NVL"), + exp.CurrentTimestamp: lambda *_: "NOW()", + exp.DateTrunc: lambda self, e: self.func( + "DATE_TRUNC", e.this, "'" + e.text("unit") + "'" + ), + exp.JSONExtractScalar: arrow_json_extract_sql, + exp.JSONExtract: arrow_json_extract_sql, + exp.RegexpLike: rename_func("REGEXP"), + exp.RegexpSplit: rename_func("SPLIT_BY_STRING"), + exp.SetAgg: rename_func("COLLECT_SET"), + exp.StrToUnix: lambda self, e: f"UNIX_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.Split: rename_func("SPLIT_BY_STRING"), + exp.TimeStrToDate: rename_func("TO_DATE"), + exp.ToChar: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})", + exp.TsOrDsAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})", # Only for day level + exp.TsOrDsToDate: lambda self, e: self.func("TO_DATE", e.this), + exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"), + exp.TimestampTrunc: lambda self, e: self.func( + "DATE_TRUNC", e.this, "'" + e.text("unit") + "'" + ), + exp.UnixToStr: lambda self, e: self.func( + "FROM_UNIXTIME", e.this, time_format("doris")(self, e) + ), + exp.UnixToTime: rename_func("FROM_UNIXTIME"), + exp.Map: rename_func("ARRAY_MAP"), + } diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index 5428e86..8253b52 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -89,6 +89,11 @@ def _struct_sql(self: generator.Generator, expression: exp.Struct) -> str: def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str: if expression.is_type("array"): return 
f"{self.expressions(expression, flat=True)}[]" + + # Type TIMESTAMP / TIME WITH TIME ZONE does not support any modifiers + if expression.is_type("timestamptz", "timetz"): + return expression.this.value + return self.datatype_sql(expression) @@ -110,14 +115,14 @@ class DuckDB(Dialect): "//": TokenType.DIV, "ATTACH": TokenType.COMMAND, "BINARY": TokenType.VARBINARY, - "BPCHAR": TokenType.TEXT, "BITSTRING": TokenType.BIT, + "BPCHAR": TokenType.TEXT, "CHAR": TokenType.TEXT, "CHARACTER VARYING": TokenType.TEXT, "EXCLUDE": TokenType.EXCEPT, + "HUGEINT": TokenType.INT128, "INT1": TokenType.TINYINT, "LOGICAL": TokenType.BOOLEAN, - "NUMERIC": TokenType.DOUBLE, "PIVOT_WIDER": TokenType.PIVOT, "SIGNED": TokenType.INT, "STRING": TokenType.VARCHAR, @@ -186,6 +191,22 @@ class DuckDB(Dialect): TokenType.UTINYINT, } + def _parse_types( + self, check_func: bool = False, schema: bool = False + ) -> t.Optional[exp.Expression]: + this = super()._parse_types(check_func=check_func, schema=schema) + + # DuckDB treats NUMERIC and DECIMAL without precision as DECIMAL(18, 3) + # See: https://duckdb.org/docs/sql/data_types/numeric + if ( + isinstance(this, exp.DataType) + and this.is_type("numeric", "decimal") + and not this.expressions + ): + return exp.DataType.build("DECIMAL(18, 3)") + + return this + def _pivot_column_names(self, aggregations: t.List[exp.Expression]) -> t.List[str]: if len(aggregations) == 1: return super()._pivot_column_names(aggregations) @@ -231,6 +252,7 @@ class DuckDB(Dialect): exp.Encode: lambda self, e: encode_decode_sql(self, e, "ENCODE", replace=False), exp.Explode: rename_func("UNNEST"), exp.IntDiv: lambda self, e: self.binary(e, "//"), + exp.IsNan: rename_func("ISNAN"), exp.JSONExtract: arrow_json_extract_sql, exp.JSONExtractScalar: arrow_json_extract_scalar_sql, exp.JSONFormat: _json_format_sql, diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index aa4d845..584acc6 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -23,6 +23,7 @@ from sqlglot.dialects.dialect import ( right_to_substring_sql, strposition_to_locate_sql, struct_extract_sql, + time_format, timestrtotime_sql, var_map_sql, ) @@ -113,7 +114,7 @@ def _property_sql(self: generator.Generator, expression: exp.Property) -> str: def _str_to_unix_sql(self: generator.Generator, expression: exp.StrToUnix) -> str: - return self.func("UNIX_TIMESTAMP", expression.this, _time_format(self, expression)) + return self.func("UNIX_TIMESTAMP", expression.this, time_format("hive")(self, expression)) def _str_to_date_sql(self: generator.Generator, expression: exp.StrToDate) -> str: @@ -132,15 +133,6 @@ def _str_to_time_sql(self: generator.Generator, expression: exp.StrToTime) -> st return f"CAST({this} AS TIMESTAMP)" -def _time_format( - self: generator.Generator, expression: exp.UnixToStr | exp.StrToUnix -) -> t.Optional[str]: - time_format = self.format_time(expression) - if time_format == Hive.TIME_FORMAT: - return None - return time_format - - def _time_to_str(self: generator.Generator, expression: exp.TimeToStr) -> str: this = self.sql(expression, "this") time_format = self.format_time(expression) @@ -439,7 +431,7 @@ class Hive(Dialect): exp.TsOrDsToDate: _to_date_sql, exp.TryCast: no_trycast_sql, exp.UnixToStr: lambda self, e: self.func( - "FROM_UNIXTIME", e.this, _time_format(self, e) + "FROM_UNIXTIME", e.this, time_format("hive")(self, e) ), exp.UnixToTime: rename_func("FROM_UNIXTIME"), exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"), diff --git a/sqlglot/dialects/mysql.py 
b/sqlglot/dialects/mysql.py index 3cd99e7..9ab4ce8 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -94,6 +94,7 @@ def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | e class MySQL(Dialect): TIME_FORMAT = "'%Y-%m-%d %T'" + DPIPE_IS_STRING_CONCAT = False # https://prestodb.io/docs/current/functions/datetime.html#mysql-date-functions TIME_MAPPING = { @@ -103,7 +104,6 @@ class MySQL(Dialect): "%h": "%I", "%i": "%M", "%s": "%S", - "%S": "%S", "%u": "%W", "%k": "%-H", "%l": "%-I", @@ -196,8 +196,14 @@ class MySQL(Dialect): **parser.Parser.CONJUNCTION, TokenType.DAMP: exp.And, TokenType.XOR: exp.Xor, + TokenType.DPIPE: exp.Or, } + # MySQL uses || as a synonym to the logical OR operator + # https://dev.mysql.com/doc/refman/8.0/en/logical-operators.html#operator_or + BITWISE = parser.Parser.BITWISE.copy() + BITWISE.pop(TokenType.DPIPE) + TABLE_ALIAS_TOKENS = ( parser.Parser.TABLE_ALIAS_TOKENS - parser.Parser.TABLE_INDEX_HINT_TOKENS ) diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index ca44b70..73ca4e5 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -16,6 +16,7 @@ from sqlglot.dialects.dialect import ( no_pivot_sql, no_tablesample_sql, no_trycast_sql, + parse_timestamp_trunc, rename_func, simplify_literal, str_position_sql, @@ -286,9 +287,7 @@ class Postgres(Dialect): FUNCTIONS = { **parser.Parser.FUNCTIONS, - "DATE_TRUNC": lambda args: exp.TimestampTrunc( - this=seq_get(args, 1), unit=seq_get(args, 0) - ), + "DATE_TRUNC": parse_timestamp_trunc, "GENERATE_SERIES": _generate_series, "NOW": exp.CurrentTimestamp.from_arg_list, "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"), diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 291b478..078da0b 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -32,13 +32,6 @@ def _approx_distinct_sql(self: generator.Generator, expression: exp.ApproxDistin return f"APPROX_DISTINCT({self.sql(expression, 'this')}{accuracy})" -def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str: - sql = self.datatype_sql(expression) - if expression.is_type("timestamptz"): - sql = f"{sql} WITH TIME ZONE" - return sql - - def _explode_to_unnest_sql(self: generator.Generator, expression: exp.Lateral) -> str: if isinstance(expression.this, (exp.Explode, exp.Posexplode)): expression = expression.copy() @@ -231,6 +224,7 @@ class Presto(Dialect): TABLE_HINTS = False QUERY_HINTS = False IS_BOOL_ALLOWED = False + TZ_TO_WITH_TIME_ZONE = True STRUCT_DELIMITER = ("(", ")") PROPERTIES_LOCATION = { @@ -245,6 +239,7 @@ class Presto(Dialect): exp.DataType.Type.FLOAT: "REAL", exp.DataType.Type.BINARY: "VARBINARY", exp.DataType.Type.TEXT: "VARCHAR", + exp.DataType.Type.TIMETZ: "TIME", exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP", exp.DataType.Type.STRUCT: "ROW", } @@ -265,7 +260,6 @@ class Presto(Dialect): exp.BitwiseXor: lambda self, e: f"BITWISE_XOR({self.sql(e, 'this')}, {self.sql(e, 'expression')})", exp.Cast: transforms.preprocess([transforms.epoch_cast_to_ts]), exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", - exp.DataType: _datatype_sql, exp.DateAdd: lambda self, e: self.func( "DATE_ADD", exp.Literal.string(e.text("unit") or "day"), e.expression, e.this ), diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index cdb8d0d..30731e1 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -85,8 +85,6 @@ class Redshift(Postgres): "HLLSKETCH": TokenType.HLLSKETCH, 
"SUPER": TokenType.SUPER, "SYSDATE": TokenType.CURRENT_TIMESTAMP, - "TIME": TokenType.TIMESTAMP, - "TIMETZ": TokenType.TIMESTAMPTZ, "TOP": TokenType.TOP, "UNLOAD": TokenType.COMMAND, "VARBYTE": TokenType.VARBINARY, @@ -101,12 +99,15 @@ class Redshift(Postgres): RENAME_TABLE_WITH_DB = False QUERY_HINTS = False VALUES_AS_TABLE = False + TZ_TO_WITH_TIME_ZONE = True TYPE_MAPPING = { **Postgres.Generator.TYPE_MAPPING, exp.DataType.Type.BINARY: "VARBYTE", - exp.DataType.Type.VARBINARY: "VARBYTE", exp.DataType.Type.INT: "INTEGER", + exp.DataType.Type.TIMETZ: "TIME", + exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP", + exp.DataType.Type.VARBINARY: "VARBYTE", } PROPERTIES_LOCATION = { diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index b9aaa66..7c8982b 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -52,6 +52,9 @@ class Spark(Spark2): TRANSFORMS = { **Spark2.Generator.TRANSFORMS, exp.StartsWith: rename_func("STARTSWITH"), + exp.TimestampAdd: lambda self, e: self.func( + "DATEADD", e.args.get("unit") or "DAY", e.expression, e.this + ), } TRANSFORMS.pop(exp.DateDiff) TRANSFORMS.pop(exp.Group) diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py index 4f6183c..2dba1c1 100644 --- a/sqlglot/dialects/starrocks.py +++ b/sqlglot/dialects/starrocks.py @@ -4,6 +4,7 @@ from sqlglot import exp from sqlglot.dialects.dialect import ( approx_count_distinct_sql, arrow_json_extract_sql, + parse_timestamp_trunc, rename_func, ) from sqlglot.dialects.mysql import MySQL @@ -14,9 +15,7 @@ class StarRocks(MySQL): class Parser(MySQL.Parser): FUNCTIONS = { **MySQL.Parser.FUNCTIONS, - "DATE_TRUNC": lambda args: exp.TimestampTrunc( - this=seq_get(args, 1), unit=seq_get(args, 0) - ), + "DATE_TRUNC": parse_timestamp_trunc, "DATEDIFF": lambda args: exp.DateDiff( this=seq_get(args, 0), expression=seq_get(args, 1), unit=exp.Literal.string("DAY") ), diff --git a/sqlglot/executor/__init__.py b/sqlglot/executor/__init__.py index 017d5bc..304981b 100644 --- a/sqlglot/executor/__init__.py +++ b/sqlglot/executor/__init__.py @@ -28,6 +28,11 @@ if t.TYPE_CHECKING: from sqlglot.schema import Schema +PYTHON_TYPE_TO_SQLGLOT = { + "dict": "MAP", +} + + def execute( sql: str | Expression, schema: t.Optional[t.Dict | Schema] = None, @@ -50,7 +55,7 @@ def execute( Returns: Simple columnar data structure. 
""" - tables_ = ensure_tables(tables) + tables_ = ensure_tables(tables, dialect=read) if not schema: schema = {} @@ -61,7 +66,8 @@ def execute( assert table is not None for column in table.columns: - nested_set(schema, [*keys, column], type(table[0][column]).__name__) + py_type = type(table[0][column]).__name__ + nested_set(schema, [*keys, column], PYTHON_TYPE_TO_SQLGLOT.get(py_type) or py_type) schema = ensure_schema(schema, dialect=read) diff --git a/sqlglot/executor/table.py b/sqlglot/executor/table.py index 27e3e5e..74b9b7c 100644 --- a/sqlglot/executor/table.py +++ b/sqlglot/executor/table.py @@ -2,8 +2,9 @@ from __future__ import annotations import typing as t +from sqlglot.dialects.dialect import DialectType from sqlglot.helper import dict_depth -from sqlglot.schema import AbstractMappingSchema +from sqlglot.schema import AbstractMappingSchema, normalize_name class Table: @@ -108,26 +109,37 @@ class Tables(AbstractMappingSchema[Table]): pass -def ensure_tables(d: t.Optional[t.Dict]) -> Tables: - return Tables(_ensure_tables(d)) +def ensure_tables(d: t.Optional[t.Dict], dialect: DialectType = None) -> Tables: + return Tables(_ensure_tables(d, dialect=dialect)) -def _ensure_tables(d: t.Optional[t.Dict]) -> t.Dict: +def _ensure_tables(d: t.Optional[t.Dict], dialect: DialectType = None) -> t.Dict: if not d: return {} depth = dict_depth(d) - if depth > 1: - return {k: _ensure_tables(v) for k, v in d.items()} + return { + normalize_name(k, dialect=dialect, is_table=True): _ensure_tables(v, dialect=dialect) + for k, v in d.items() + } result = {} - for name, table in d.items(): + for table_name, table in d.items(): + table_name = normalize_name(table_name, dialect=dialect) + if isinstance(table, Table): - result[name] = table + result[table_name] = table else: - columns = tuple(table[0]) if table else () - rows = [tuple(row[c] for c in columns) for row in table] - result[name] = Table(columns=columns, rows=rows) + table = [ + { + normalize_name(column_name, dialect=dialect): value + for column_name, value in row.items() + } + for row in table + ] + column_names = tuple(column_name for column_name in table[0]) if table else () + rows = [tuple(row[name] for name in column_names) for row in table] + result[table_name] = Table(columns=column_names, rows=rows) return result diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index c207751..57b8bfa 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -3309,6 +3309,7 @@ class Pivot(Expression): "using": False, "group": False, "columns": False, + "include_nulls": False, } @@ -3397,23 +3398,16 @@ class DataType(Expression): BOOLEAN = auto() CHAR = auto() DATE = auto() + DATEMULTIRANGE = auto() + DATERANGE = auto() DATETIME = auto() DATETIME64 = auto() - ENUM = auto() - INT4RANGE = auto() - INT4MULTIRANGE = auto() - INT8RANGE = auto() - INT8MULTIRANGE = auto() - NUMRANGE = auto() - NUMMULTIRANGE = auto() - TSRANGE = auto() - TSMULTIRANGE = auto() - TSTZRANGE = auto() - TSTZMULTIRANGE = auto() - DATERANGE = auto() - DATEMULTIRANGE = auto() DECIMAL = auto() DOUBLE = auto() + ENUM = auto() + ENUM8 = auto() + ENUM16 = auto() + FIXEDSTRING = auto() FLOAT = auto() GEOGRAPHY = auto() GEOMETRY = auto() @@ -3421,23 +3415,31 @@ class DataType(Expression): HSTORE = auto() IMAGE = auto() INET = auto() - IPADDRESS = auto() - IPPREFIX = auto() INT = auto() INT128 = auto() INT256 = auto() + INT4MULTIRANGE = auto() + INT4RANGE = auto() + INT8MULTIRANGE = auto() + INT8RANGE = auto() INTERVAL = auto() + IPADDRESS = auto() + IPPREFIX = auto() 
JSON = auto() JSONB = auto() LONGBLOB = auto() LONGTEXT = auto() + LOWCARDINALITY = auto() MAP = auto() MEDIUMBLOB = auto() MEDIUMTEXT = auto() MONEY = auto() NCHAR = auto() + NESTED = auto() NULL = auto() NULLABLE = auto() + NUMMULTIRANGE = auto() + NUMRANGE = auto() NVARCHAR = auto() OBJECT = auto() ROWVERSION = auto() @@ -3450,19 +3452,24 @@ class DataType(Expression): SUPER = auto() TEXT = auto() TIME = auto() + TIMETZ = auto() TIMESTAMP = auto() - TIMESTAMPTZ = auto() TIMESTAMPLTZ = auto() + TIMESTAMPTZ = auto() TINYINT = auto() + TSMULTIRANGE = auto() + TSRANGE = auto() + TSTZMULTIRANGE = auto() + TSTZRANGE = auto() UBIGINT = auto() UINT = auto() - USMALLINT = auto() - UTINYINT = auto() - UNKNOWN = auto() # Sentinel value, useful for type annotation UINT128 = auto() UINT256 = auto() UNIQUEIDENTIFIER = auto() + UNKNOWN = auto() # Sentinel value, useful for type annotation USERDEFINED = "USER-DEFINED" + USMALLINT = auto() + UTINYINT = auto() UUID = auto() VARBINARY = auto() VARCHAR = auto() @@ -3495,6 +3502,7 @@ class DataType(Expression): TEMPORAL_TYPES = { Type.TIME, + Type.TIMETZ, Type.TIMESTAMP, Type.TIMESTAMPTZ, Type.TIMESTAMPLTZ, @@ -3858,6 +3866,18 @@ class TimeUnit(Expression): super().__init__(**args) +# https://www.oracletutorial.com/oracle-basics/oracle-interval/ +# https://trino.io/docs/current/language/types.html#interval-year-to-month +class IntervalYearToMonthSpan(Expression): + arg_types = {} + + +# https://www.oracletutorial.com/oracle-basics/oracle-interval/ +# https://trino.io/docs/current/language/types.html#interval-day-to-second +class IntervalDayToSecondSpan(Expression): + arg_types = {} + + class Interval(TimeUnit): arg_types = {"this": False, "unit": False} diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 95db795..f8d7d68 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -71,6 +71,8 @@ class Generator: exp.ExternalProperty: lambda self, e: "EXTERNAL", exp.HeapProperty: lambda self, e: "HEAP", exp.InlineLengthColumnConstraint: lambda self, e: f"INLINE LENGTH {self.sql(e, 'this')}", + exp.IntervalDayToSecondSpan: "DAY TO SECOND", + exp.IntervalYearToMonthSpan: "YEAR TO MONTH", exp.LanguageProperty: lambda self, e: self.naked_property(e), exp.LocationProperty: lambda self, e: self.naked_property(e), exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG", @@ -166,6 +168,9 @@ class Generator: # Whether or not to generate an unquoted value for EXTRACT's date part argument EXTRACT_ALLOWS_QUOTES = True + # Whether or not TIMETZ / TIMESTAMPTZ will be generated using the "WITH TIME ZONE" syntax + TZ_TO_WITH_TIME_ZONE = False + # https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax SELECT_KINDS: t.Tuple[str, ...] = ("STRUCT", "VALUE") @@ -271,10 +276,12 @@ class Generator: # Expressions whose comments are separated from them for better formatting WITH_SEPARATED_COMMENTS: t.Tuple[t.Type[exp.Expression], ...] 
= ( + exp.Create, exp.Delete, exp.Drop, exp.From, exp.Insert, + exp.Join, exp.Select, exp.Update, exp.Where, @@ -831,14 +838,17 @@ class Generator: def datatype_sql(self, expression: exp.DataType) -> str: type_value = expression.this + type_sql = ( self.TYPE_MAPPING.get(type_value, type_value.value) if isinstance(type_value, exp.DataType.Type) else type_value ) + nested = "" interior = self.expressions(expression, flat=True) values = "" + if interior: if expression.args.get("nested"): nested = f"{self.STRUCT_DELIMITER[0]}{interior}{self.STRUCT_DELIMITER[1]}" @@ -846,10 +856,19 @@ class Generator: delimiters = ("[", "]") if type_value == exp.DataType.Type.ARRAY else ("(", ")") values = self.expressions(expression, key="values", flat=True) values = f"{delimiters[0]}{values}{delimiters[1]}" + elif type_value == exp.DataType.Type.INTERVAL: + nested = f" {interior}" else: nested = f"({interior})" - return f"{type_sql}{nested}{values}" + type_sql = f"{type_sql}{nested}{values}" + if self.TZ_TO_WITH_TIME_ZONE and type_value in ( + exp.DataType.Type.TIMETZ, + exp.DataType.Type.TIMESTAMPTZ, + ): + type_sql = f"{type_sql} WITH TIME ZONE" + + return type_sql def directory_sql(self, expression: exp.Directory) -> str: local = "LOCAL " if expression.args.get("local") else "" @@ -1288,7 +1307,12 @@ class Generator: unpivot = expression.args.get("unpivot") direction = "UNPIVOT" if unpivot else "PIVOT" field = self.sql(expression, "field") - return f"{direction}({expressions} FOR {field}){alias}" + include_nulls = expression.args.get("include_nulls") + if include_nulls is not None: + nulls = " INCLUDE NULLS " if include_nulls else " EXCLUDE NULLS " + else: + nulls = "" + return f"{direction}{nulls}({expressions} FOR {field}){alias}" def tuple_sql(self, expression: exp.Tuple) -> str: return f"({self.expressions(expression, flat=True)})" diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index e247f58..e550603 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -54,11 +54,17 @@ def simplify(expression): def _simplify(expression, root=True): if expression.meta.get(FINAL): return expression + + # Pre-order transformations node = expression node = rewrite_between(node) node = uniq_sort(node, generate, root) node = absorb_and_eliminate(node, root) + node = simplify_concat(node) + exp.replace_children(node, lambda e: _simplify(e, False)) + + # Post-order transformations node = simplify_not(node) node = flatten(node) node = simplify_connectors(node, root) @@ -66,8 +72,11 @@ def simplify(expression): node.parent = expression.parent node = simplify_literals(node, root) node = simplify_parens(node) + node = simplify_coalesce(node) + if root: expression.replace(node) + return node expression = while_changing(expression, _simplify) @@ -184,6 +193,7 @@ COMPARISONS = ( *GT_GTE, exp.EQ, exp.NEQ, + exp.Is, ) INVERSE_COMPARISONS = { @@ -430,6 +440,103 @@ def simplify_parens(expression): return expression +CONSTANTS = ( + exp.Literal, + exp.Boolean, + exp.Null, +) + + +def simplify_coalesce(expression): + # COALESCE(x) -> x + if ( + isinstance(expression, exp.Coalesce) + and not expression.expressions + # COALESCE is also used as a Spark partitioning hint + and not isinstance(expression.parent, exp.Hint) + ): + return expression.this + + if not isinstance(expression, COMPARISONS): + return expression + + if isinstance(expression.left, exp.Coalesce): + coalesce = expression.left + other = expression.right + elif isinstance(expression.right, exp.Coalesce): + coalesce = 
expression.right + other = expression.left + else: + return expression + + # This transformation is valid for non-constants, + # but it really only does anything if they are both constants. + if not isinstance(other, CONSTANTS): + return expression + + # Find the first constant arg + for arg_index, arg in enumerate(coalesce.expressions): + if isinstance(arg, CONSTANTS): + break + else: + return expression + + coalesce.set("expressions", coalesce.expressions[:arg_index]) + + # Remove the COALESCE function. This is an optimization, skipping a simplify iteration, + # since we already remove COALESCE at the top of this function. + coalesce = coalesce if coalesce.expressions else coalesce.this + + # This expression is more complex than when we started, but it will get simplified further + return exp.or_( + exp.and_( + coalesce.is_(exp.null()).not_(copy=False), + expression.copy(), + copy=False, + ), + exp.and_( + coalesce.is_(exp.null()), + type(expression)(this=arg.copy(), expression=other.copy()), + copy=False, + ), + copy=False, + ) + + +CONCATS = (exp.Concat, exp.DPipe) +SAFE_CONCATS = (exp.SafeConcat, exp.SafeDPipe) + + +def simplify_concat(expression): + """Reduces all groups that contain string literals by concatenating them.""" + if not isinstance(expression, CONCATS) or isinstance(expression, exp.ConcatWs): + return expression + + new_args = [] + for is_string_group, group in itertools.groupby( + expression.expressions or expression.flatten(), lambda e: e.is_string + ): + if is_string_group: + new_args.append(exp.Literal.string("".join(string.name for string in group))) + else: + new_args.extend(group) + + # Ensures we preserve the right concat type, i.e. whether it's "safe" or not + concat_type = exp.SafeConcat if isinstance(expression, SAFE_CONCATS) else exp.Concat + return new_args[0] if len(new_args) == 1 else concat_type(expressions=new_args) + + +# CROSS joins result in an empty table if the right table is empty. +# So we can only simplify certain types of joins to CROSS. 
+# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x +JOINS = { + ("", ""), + ("", "INNER"), + ("RIGHT", ""), + ("RIGHT", "OUTER"), +} + + def remove_where_true(expression): for where in expression.find_all(exp.Where): if always_true(where.this): @@ -439,6 +546,7 @@ def remove_where_true(expression): always_true(join.args.get("on")) and not join.args.get("using") and not join.args.get("method") + and (join.side, join.kind) in JOINS ): join.set("on", None) join.set("side", None) diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 35a1744..3db4453 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -102,15 +102,23 @@ class Parser(metaclass=_Parser): TokenType.CURRENT_USER: exp.CurrentUser, } + STRUCT_TYPE_TOKENS = { + TokenType.NESTED, + TokenType.STRUCT, + } + NESTED_TYPE_TOKENS = { TokenType.ARRAY, + TokenType.LOWCARDINALITY, TokenType.MAP, TokenType.NULLABLE, - TokenType.STRUCT, + *STRUCT_TYPE_TOKENS, } ENUM_TYPE_TOKENS = { TokenType.ENUM, + TokenType.ENUM8, + TokenType.ENUM16, } TYPE_TOKENS = { @@ -128,6 +136,7 @@ class Parser(metaclass=_Parser): TokenType.UINT128, TokenType.INT256, TokenType.UINT256, + TokenType.FIXEDSTRING, TokenType.FLOAT, TokenType.DOUBLE, TokenType.CHAR, @@ -145,6 +154,7 @@ class Parser(metaclass=_Parser): TokenType.JSONB, TokenType.INTERVAL, TokenType.TIME, + TokenType.TIMETZ, TokenType.TIMESTAMP, TokenType.TIMESTAMPTZ, TokenType.TIMESTAMPLTZ, @@ -187,7 +197,7 @@ class Parser(metaclass=_Parser): TokenType.INET, TokenType.IPADDRESS, TokenType.IPPREFIX, - TokenType.ENUM, + *ENUM_TYPE_TOKENS, *NESTED_TYPE_TOKENS, } @@ -384,11 +394,16 @@ class Parser(metaclass=_Parser): TokenType.STAR: exp.Mul, } - TIMESTAMPS = { + TIMES = { TokenType.TIME, + TokenType.TIMETZ, + } + + TIMESTAMPS = { TokenType.TIMESTAMP, TokenType.TIMESTAMPTZ, TokenType.TIMESTAMPLTZ, + *TIMES, } SET_OPERATIONS = { @@ -1165,6 +1180,8 @@ class Parser(metaclass=_Parser): def _parse_create(self) -> exp.Create | exp.Command: # Note: this can't be None because we've matched a statement parser start = self._prev + comments = self._prev_comments + replace = start.text.upper() == "REPLACE" or self._match_pair( TokenType.OR, TokenType.REPLACE ) @@ -1273,6 +1290,7 @@ class Parser(metaclass=_Parser): return self.expression( exp.Create, + comments=comments, this=this, kind=create_token.text, replace=replace, @@ -2338,7 +2356,8 @@ class Parser(metaclass=_Parser): kwargs["this"].set("joins", joins) - return self.expression(exp.Join, **kwargs) + comments = [c for token in (method, side, kind) if token for c in token.comments] + return self.expression(exp.Join, comments=comments, **kwargs) def _parse_index( self, @@ -2619,11 +2638,18 @@ class Parser(metaclass=_Parser): def _parse_pivot(self) -> t.Optional[exp.Pivot]: index = self._index + include_nulls = None if self._match(TokenType.PIVOT): unpivot = False elif self._match(TokenType.UNPIVOT): unpivot = True + + # https://docs.databricks.com/en/sql/language-manual/sql-ref-syntax-qry-select-unpivot.html#syntax + if self._match_text_seq("INCLUDE", "NULLS"): + include_nulls = True + elif self._match_text_seq("EXCLUDE", "NULLS"): + include_nulls = False else: return None @@ -2654,7 +2680,13 @@ class Parser(metaclass=_Parser): self._match_r_paren() - pivot = self.expression(exp.Pivot, expressions=expressions, field=field, unpivot=unpivot) + pivot = self.expression( + exp.Pivot, + expressions=expressions, + field=field, + unpivot=unpivot, + include_nulls=include_nulls, + ) if not self._match_set((TokenType.PIVOT, TokenType.UNPIVOT), advance=False): 
pivot.set("alias", self._parse_table_alias()) @@ -3096,7 +3128,7 @@ class Parser(metaclass=_Parser): return self.expression(exp.PseudoType, this=self._prev.text) nested = type_token in self.NESTED_TYPE_TOKENS - is_struct = type_token == TokenType.STRUCT + is_struct = type_token in self.STRUCT_TYPE_TOKENS expressions = None maybe_func = False @@ -3108,7 +3140,7 @@ class Parser(metaclass=_Parser): lambda: self._parse_types(check_func=check_func, schema=schema) ) elif type_token in self.ENUM_TYPE_TOKENS: - expressions = self._parse_csv(self._parse_primary) + expressions = self._parse_csv(self._parse_equality) else: expressions = self._parse_csv(self._parse_type_size) @@ -3118,29 +3150,9 @@ class Parser(metaclass=_Parser): maybe_func = True - if self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET): - this = exp.DataType( - this=exp.DataType.Type.ARRAY, - expressions=[ - exp.DataType( - this=exp.DataType.Type[type_token.value], - expressions=expressions, - nested=nested, - ) - ], - nested=True, - ) - - while self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET): - this = exp.DataType(this=exp.DataType.Type.ARRAY, expressions=[this], nested=True) - - return this - - if self._match(TokenType.L_BRACKET): - self._retreat(index) - return None - + this: t.Optional[exp.Expression] = None values: t.Optional[t.List[t.Optional[exp.Expression]]] = None + if nested and self._match(TokenType.LT): if is_struct: expressions = self._parse_csv(self._parse_struct_types) @@ -3156,23 +3168,35 @@ class Parser(metaclass=_Parser): values = self._parse_csv(self._parse_conjunction) self._match_set((TokenType.R_BRACKET, TokenType.R_PAREN)) - value: t.Optional[exp.Expression] = None if type_token in self.TIMESTAMPS: if self._match_text_seq("WITH", "TIME", "ZONE"): maybe_func = False - value = exp.DataType(this=exp.DataType.Type.TIMESTAMPTZ, expressions=expressions) + tz_type = ( + exp.DataType.Type.TIMETZ + if type_token in self.TIMES + else exp.DataType.Type.TIMESTAMPTZ + ) + this = exp.DataType(this=tz_type, expressions=expressions) elif self._match_text_seq("WITH", "LOCAL", "TIME", "ZONE"): maybe_func = False - value = exp.DataType(this=exp.DataType.Type.TIMESTAMPLTZ, expressions=expressions) + this = exp.DataType(this=exp.DataType.Type.TIMESTAMPLTZ, expressions=expressions) elif self._match_text_seq("WITHOUT", "TIME", "ZONE"): maybe_func = False elif type_token == TokenType.INTERVAL: - unit = self._parse_var() + if self._match_text_seq("YEAR", "TO", "MONTH"): + span: t.Optional[t.List[exp.Expression]] = [exp.IntervalYearToMonthSpan()] + elif self._match_text_seq("DAY", "TO", "SECOND"): + span = [exp.IntervalDayToSecondSpan()] + else: + span = None + unit = not span and self._parse_var() if not unit: - value = self.expression(exp.DataType, this=exp.DataType.Type.INTERVAL) + this = self.expression( + exp.DataType, this=exp.DataType.Type.INTERVAL, expressions=span + ) else: - value = self.expression(exp.Interval, unit=unit) + this = self.expression(exp.Interval, unit=unit) if maybe_func and check_func: index2 = self._index @@ -3184,16 +3208,19 @@ class Parser(metaclass=_Parser): self._retreat(index2) - if value: - return value + if not this: + this = exp.DataType( + this=exp.DataType.Type[type_token.value], + expressions=expressions, + nested=nested, + values=values, + prefix=prefix, + ) - return exp.DataType( - this=exp.DataType.Type[type_token.value], - expressions=expressions, - nested=nested, - values=values, - prefix=prefix, - ) + while self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET): + this = 
exp.DataType(this=exp.DataType.Type.ARRAY, expressions=[this], nested=True) + + return this def _parse_struct_types(self) -> t.Optional[exp.Expression]: this = self._parse_type() or self._parse_id_var() @@ -3738,6 +3765,7 @@ class Parser(metaclass=_Parser): ifs = [] default = None + comments = self._prev_comments expression = self._parse_conjunction() while self._match(TokenType.WHEN): @@ -3753,7 +3781,7 @@ class Parser(metaclass=_Parser): self.raise_error("Expected END after CASE", self._prev) return self._parse_window( - self.expression(exp.Case, this=expression, ifs=ifs, default=default) + self.expression(exp.Case, comments=comments, this=expression, ifs=ifs, default=default) ) def _parse_if(self) -> t.Optional[exp.Expression]: diff --git a/sqlglot/schema.py b/sqlglot/schema.py index 7a3c88b..f028f5a 100644 --- a/sqlglot/schema.py +++ b/sqlglot/schema.py @@ -372,21 +372,12 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): is_table: bool = False, normalize: t.Optional[bool] = None, ) -> str: - dialect = dialect or self.dialect - normalize = self.normalize if normalize is None else normalize - - try: - identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier) - except ParseError: - return name if isinstance(name, str) else name.name - - name = identifier.name - if not normalize: - return name - - # This can be useful for normalize_identifier - identifier.meta["is_table"] = is_table - return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name + return normalize_name( + name, + dialect=dialect or self.dialect, + is_table=is_table, + normalize=self.normalize if normalize is None else normalize, + ) def depth(self) -> int: if not self.empty and not self._depth: @@ -418,6 +409,26 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): return self._type_mapping_cache[schema_type] +def normalize_name( + name: str | exp.Identifier, + dialect: DialectType = None, + is_table: bool = False, + normalize: t.Optional[bool] = True, +) -> str: + try: + identifier = sqlglot.maybe_parse(name, dialect=dialect, into=exp.Identifier) + except ParseError: + return name if isinstance(name, str) else name.name + + name = identifier.name + if not normalize: + return name + + # This can be useful for normalize_identifier + identifier.meta["is_table"] = is_table + return Dialect.get_or_raise(dialect).normalize_identifier(identifier).name + + def ensure_schema(schema: Schema | t.Optional[t.Dict], **kwargs: t.Any) -> Schema: if isinstance(schema, Schema): return schema diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index 81bcc0b..d278dbf 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -110,6 +110,7 @@ class TokenType(AutoName): JSON = auto() JSONB = auto() TIME = auto() + TIMETZ = auto() TIMESTAMP = auto() TIMESTAMPTZ = auto() TIMESTAMPLTZ = auto() @@ -151,6 +152,11 @@ class TokenType(AutoName): IPADDRESS = auto() IPPREFIX = auto() ENUM = auto() + ENUM8 = auto() + ENUM16 = auto() + FIXEDSTRING = auto() + LOWCARDINALITY = auto() + NESTED = auto() # keywords ALIAS = auto() @@ -659,6 +665,7 @@ class Tokenizer(metaclass=_Tokenizer): "TINYINT": TokenType.TINYINT, "SHORT": TokenType.SMALLINT, "SMALLINT": TokenType.SMALLINT, + "INT128": TokenType.INT128, "INT2": TokenType.SMALLINT, "INTEGER": TokenType.INT, "INT": TokenType.INT, @@ -699,6 +706,7 @@ class Tokenizer(metaclass=_Tokenizer): "BYTEA": TokenType.VARBINARY, "VARBINARY": TokenType.VARBINARY, "TIME": TokenType.TIME, + "TIMETZ": TokenType.TIMETZ, "TIMESTAMP": TokenType.TIMESTAMP, 
"TIMESTAMPTZ": TokenType.TIMESTAMPTZ, "TIMESTAMPLTZ": TokenType.TIMESTAMPLTZ, @@ -879,6 +887,11 @@ class Tokenizer(metaclass=_Tokenizer): def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None: self._prev_token_line = self._line + + if self._comments and token_type == TokenType.SEMICOLON and self.tokens: + self.tokens[-1].comments.extend(self._comments) + self._comments = [] + self.tokens.append( Token( token_type, |