author     Daniel Baumann <daniel.baumann@progress-linux.org>  2023-01-17 10:32:16 +0000
committer  Daniel Baumann <daniel.baumann@progress-linux.org>  2023-01-17 10:32:16 +0000
commit     d3bb537b2b73788ba06bf4158f473ecc5bb556cc (patch)
tree       6c1b280de128c7bf77baaa258560a1f39a4e15c7 /sqlglot
parent     Releasing debian version 10.4.2-1. (diff)
Merging upstream version 10.5.2.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
-rw-r--r--  sqlglot/__init__.py                          13
-rw-r--r--  sqlglot/dialects/bigquery.py                  7
-rw-r--r--  sqlglot/dialects/clickhouse.py               35
-rw-r--r--  sqlglot/dialects/dialect.py                  17
-rw-r--r--  sqlglot/dialects/hive.py                     23
-rw-r--r--  sqlglot/dialects/oracle.py                    3
-rw-r--r--  sqlglot/dialects/postgres.py                 21
-rw-r--r--  sqlglot/dialects/snowflake.py                 8
-rw-r--r--  sqlglot/dialects/tsql.py                     22
-rw-r--r--  sqlglot/expressions.py                      117
-rw-r--r--  sqlglot/generator.py                         69
-rw-r--r--  sqlglot/helper.py                            20
-rw-r--r--  sqlglot/optimizer/annotate_types.py           2
-rw-r--r--  sqlglot/optimizer/eliminate_joins.py          4
-rw-r--r--  sqlglot/optimizer/merge_subqueries.py        54
-rw-r--r--  sqlglot/optimizer/optimizer.py                6
-rw-r--r--  sqlglot/optimizer/pushdown_projections.py     4
-rw-r--r--  sqlglot/optimizer/qualify_columns.py          4
-rw-r--r--  sqlglot/optimizer/simplify.py                19
-rw-r--r--  sqlglot/optimizer/unnest_subqueries.py       38
-rw-r--r--  sqlglot/parser.py                           652
-rw-r--r--  sqlglot/schema.py                            45
-rw-r--r--  sqlglot/serde.py                             67
-rw-r--r--  sqlglot/tokens.py                            19
-rw-r--r--  sqlglot/transforms.py                        24
-rw-r--r--  sqlglot/trie.py                               2
26 files changed, 984 insertions, 311 deletions
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py
index 04c3195..87fa081 100644
--- a/sqlglot/__init__.py
+++ b/sqlglot/__init__.py
@@ -32,7 +32,7 @@ from sqlglot.parser import Parser
from sqlglot.schema import MappingSchema
from sqlglot.tokens import Tokenizer, TokenType
-__version__ = "10.4.2"
+__version__ = "10.5.2"
pretty = False
@@ -60,9 +60,9 @@ def parse(
def parse_one(
sql: str,
read: t.Optional[str | Dialect] = None,
- into: t.Optional[Expression | str] = None,
+ into: t.Optional[t.Type[Expression] | str] = None,
**opts,
-) -> t.Optional[Expression]:
+) -> Expression:
"""
Parses the given SQL string and returns a syntax tree for the first parsed SQL statement.
@@ -83,7 +83,12 @@ def parse_one(
else:
result = dialect.parse(sql, **opts)
- return result[0] if result else None
+ for expression in result:
+ if not expression:
+ raise ParseError(f"No expression was parsed from '{sql}'")
+ return expression
+ else:
+ raise ParseError(f"No expression was parsed from '{sql}'")
def transpile(
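
A quick sketch of the behavioral change above, assuming the 10.5.2 API exactly as shown in this hunk:

    import sqlglot
    from sqlglot.errors import ParseError

    # parse_one is now typed to always return an Expression:
    expression = sqlglot.parse_one("SELECT a FROM tbl")
    print(expression.sql())  # SELECT a FROM tbl

    # Input that yields no expression now raises instead of returning None:
    try:
        sqlglot.parse_one("")
    except ParseError as exc:
        print(exc)  # No expression was parsed from ''
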
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index d10cc54..f0089e1 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -2,7 +2,7 @@
from __future__ import annotations
-from sqlglot import exp, generator, parser, tokens
+from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
datestrtodate_sql,
@@ -46,8 +46,9 @@ def _date_add_sql(data_type, kind):
def _derived_table_values_to_unnest(self, expression):
if not isinstance(expression.unnest().parent, exp.From):
+ expression = transforms.remove_precision_parameterized_types(expression)
return self.values_sql(expression)
- rows = [list(tuple_exp.find_all(exp.Literal)) for tuple_exp in expression.find_all(exp.Tuple)]
+ rows = [tuple_exp.expressions for tuple_exp in expression.find_all(exp.Tuple)]
structs = []
for row in rows:
aliases = [
@@ -118,6 +119,7 @@ class BigQuery(Dialect):
"BEGIN TRANSACTION": TokenType.BEGIN,
"CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
"CURRENT_TIME": TokenType.CURRENT_TIME,
+ "DECLARE": TokenType.COMMAND,
"GEOGRAPHY": TokenType.GEOGRAPHY,
"FLOAT64": TokenType.DOUBLE,
"INT64": TokenType.BIGINT,
@@ -166,6 +168,7 @@ class BigQuery(Dialect):
class Generator(generator.Generator):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
+ **transforms.REMOVE_PRECISION_PARAMETERIZED_TYPES, # type: ignore
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.DateAdd: _date_add_sql("DATE", "ADD"),
exp.DateSub: _date_add_sql("DATE", "SUB"),
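
A hedged illustration of the precision-stripping transform wired into the BigQuery generator above (the exact output shape is assumed, not verified here):

    import sqlglot

    # BigQuery rejects precision parameters in several contexts, so the
    # REMOVE_PRECISION_PARAMETERIZED_TYPES transform should drop them:
    print(sqlglot.transpile("SELECT CAST(x AS VARCHAR(20))", write="bigquery")[0])
    # expected roughly: SELECT CAST(x AS STRING)
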
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index 7136340..04d46d2 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -1,5 +1,7 @@
from __future__ import annotations
+import typing as t
+
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql
from sqlglot.parser import parse_var_map
@@ -22,6 +24,7 @@ class ClickHouse(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"ASOF": TokenType.ASOF,
+ "GLOBAL": TokenType.GLOBAL,
"DATETIME64": TokenType.DATETIME,
"FINAL": TokenType.FINAL,
"FLOAT32": TokenType.FLOAT,
@@ -37,14 +40,32 @@ class ClickHouse(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
"MAP": parse_var_map,
+ "QUANTILE": lambda params, args: exp.Quantile(this=args, quantile=params),
+ "QUANTILES": lambda params, args: exp.Quantiles(parameters=params, expressions=args),
+ "QUANTILEIF": lambda params, args: exp.QuantileIf(parameters=params, expressions=args),
+ }
+
+ RANGE_PARSERS = {
+ **parser.Parser.RANGE_PARSERS,
+ TokenType.GLOBAL: lambda self, this: self._match(TokenType.IN)
+ and self._parse_in(this, is_global=True),
}
JOIN_KINDS = {*parser.Parser.JOIN_KINDS, TokenType.ANY, TokenType.ASOF} # type: ignore
TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY} # type: ignore
- def _parse_table(self, schema=False):
- this = super()._parse_table(schema)
+ def _parse_in(
+ self, this: t.Optional[exp.Expression], is_global: bool = False
+ ) -> exp.Expression:
+ this = super()._parse_in(this)
+ this.set("is_global", is_global)
+ return this
+
+ def _parse_table(
+ self, schema: bool = False, alias_tokens: t.Optional[t.Collection[TokenType]] = None
+ ) -> t.Optional[exp.Expression]:
+ this = super()._parse_table(schema=schema, alias_tokens=alias_tokens)
if self._match(TokenType.FINAL):
this = self.expression(exp.Final, this=this)
@@ -76,6 +97,16 @@ class ClickHouse(Dialect):
exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
exp.Map: lambda self, e: _lower_func(var_map_sql(self, e)),
exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
+ exp.Quantile: lambda self, e: f"quantile{self._param_args_sql(e, 'quantile', 'this')}",
+ exp.Quantiles: lambda self, e: f"quantiles{self._param_args_sql(e, 'parameters', 'expressions')}",
+ exp.QuantileIf: lambda self, e: f"quantileIf{self._param_args_sql(e, 'parameters', 'expressions')}",
}
EXPLICIT_UNION = True
+
+ def _param_args_sql(
+ self, expression: exp.Expression, params_name: str, args_name: str
+ ) -> str:
+ params = self.format_args(self.expressions(expression, params_name))
+ args = self.format_args(self.expressions(expression, args_name))
+ return f"({params})({args})"
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index e788852..1c840da 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -381,3 +381,20 @@ def timestrtotime_sql(self, expression: exp.TimeStrToTime) -> str:
def datestrtodate_sql(self, expression: exp.DateStrToDate) -> str:
return f"CAST({self.sql(expression, 'this')} AS DATE)"
+
+
+def trim_sql(self, expression):
+ target = self.sql(expression, "this")
+ trim_type = self.sql(expression, "position")
+ remove_chars = self.sql(expression, "expression")
+ collation = self.sql(expression, "collation")
+
+ # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
+ if not remove_chars and not collation:
+ return self.trim_sql(expression)
+
+ trim_type = f"{trim_type} " if trim_type else ""
+ remove_chars = f"{remove_chars} " if remove_chars else ""
+ from_part = "FROM " if trim_type or remove_chars else ""
+ collation = f" COLLATE {collation}" if collation else ""
+ return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index 088555c..ead13b1 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -175,14 +175,6 @@ class Hive(Dialect):
ESCAPES = ["\\"]
ENCODE = "utf-8"
- NUMERIC_LITERALS = {
- "L": "BIGINT",
- "S": "SMALLINT",
- "Y": "TINYINT",
- "D": "DOUBLE",
- "F": "FLOAT",
- "BD": "DECIMAL",
- }
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"ADD ARCHIVE": TokenType.COMMAND,
@@ -191,9 +183,21 @@ class Hive(Dialect):
"ADD FILES": TokenType.COMMAND,
"ADD JAR": TokenType.COMMAND,
"ADD JARS": TokenType.COMMAND,
+ "MSCK REPAIR": TokenType.COMMAND,
"WITH SERDEPROPERTIES": TokenType.SERDE_PROPERTIES,
}
+ NUMERIC_LITERALS = {
+ "L": "BIGINT",
+ "S": "SMALLINT",
+ "Y": "TINYINT",
+ "D": "DOUBLE",
+ "F": "FLOAT",
+ "BD": "DECIMAL",
+ }
+
+ IDENTIFIER_CAN_START_WITH_DIGIT = True
+
class Parser(parser.Parser):
STRICT_CAST = False
@@ -315,6 +319,7 @@ class Hive(Dialect):
exp.RowFormatSerdeProperty: lambda self, e: f"ROW FORMAT SERDE {self.sql(e, 'this')}",
exp.SerdeProperties: lambda self, e: self.properties(e, prefix="WITH SERDEPROPERTIES"),
exp.NumberToStr: rename_func("FORMAT_NUMBER"),
+ exp.LastDateOfMonth: rename_func("LAST_DAY"),
}
WITH_PROPERTIES = {exp.Property}
@@ -342,4 +347,6 @@ class Hive(Dialect):
and not expression.expressions
):
expression = exp.DataType.build("text")
+ elif expression.this in exp.DataType.TEMPORAL_TYPES:
+ expression = exp.DataType.build(expression.this)
return super().datatype_sql(expression)
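
A sketch of the new temporal-type branch in Hive's datatype_sql (assumed behavior: the precision parameter is dropped, since Hive doesn't support it):

    import sqlglot

    print(sqlglot.transpile("SELECT CAST(x AS TIMESTAMP(6))", write="hive")[0])
    # expected roughly: SELECT CAST(x AS TIMESTAMP)
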
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index af3d353..86caa6b 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -1,7 +1,7 @@
from __future__ import annotations
from sqlglot import exp, generator, parser, tokens, transforms
-from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func
+from sqlglot.dialects.dialect import Dialect, no_ilike_sql, rename_func, trim_sql
from sqlglot.helper import csv
from sqlglot.tokens import TokenType
@@ -64,6 +64,7 @@ class Oracle(Dialect):
**transforms.UNALIAS_GROUP, # type: ignore
exp.ILike: no_ilike_sql,
exp.Limit: _limit_sql,
+ exp.Trim: trim_sql,
exp.Matches: rename_func("DECODE"),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index a092cad..f3fec31 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import (
no_tablesample_sql,
no_trycast_sql,
str_position_sql,
+ trim_sql,
)
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@@ -81,23 +82,6 @@ def _substring_sql(self, expression):
return f"SUBSTRING({this}{from_part}{for_part})"
-def _trim_sql(self, expression):
- target = self.sql(expression, "this")
- trim_type = self.sql(expression, "position")
- remove_chars = self.sql(expression, "expression")
- collation = self.sql(expression, "collation")
-
- # Use TRIM/LTRIM/RTRIM syntax if the expression isn't postgres-specific
- if not remove_chars and not collation:
- return self.trim_sql(expression)
-
- trim_type = f"{trim_type} " if trim_type else ""
- remove_chars = f"{remove_chars} " if remove_chars else ""
- from_part = "FROM " if trim_type or remove_chars else ""
- collation = f" COLLATE {collation}" if collation else ""
- return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
-
-
def _string_agg_sql(self, expression):
expression = expression.copy()
separator = expression.args.get("separator") or exp.Literal.string(",")
@@ -248,7 +232,6 @@ class Postgres(Dialect):
"COMMENT ON": TokenType.COMMAND,
"DECLARE": TokenType.COMMAND,
"DO": TokenType.COMMAND,
- "DOUBLE PRECISION": TokenType.DOUBLE,
"GENERATED": TokenType.GENERATED,
"GRANT": TokenType.COMMAND,
"HSTORE": TokenType.HSTORE,
@@ -318,7 +301,7 @@ class Postgres(Dialect):
exp.Substring: _substring_sql,
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TableSample: no_tablesample_sql,
- exp.Trim: _trim_sql,
+ exp.Trim: trim_sql,
exp.TryCast: no_trycast_sql,
exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})",
exp.DataType: _datatype_sql,
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 77b09e9..24d3bdf 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -195,7 +195,6 @@ class Snowflake(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"QUALIFY": TokenType.QUALIFY,
- "DOUBLE PRECISION": TokenType.DOUBLE,
"TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,
"TIMESTAMP_NTZ": TokenType.TIMESTAMP,
"TIMESTAMP_TZ": TokenType.TIMESTAMPTZ,
@@ -294,3 +293,10 @@ class Snowflake(Dialect):
)
return self.no_identify(lambda: super(self.__class__, self).select_sql(expression))
return super().select_sql(expression)
+
+ def describe_sql(self, expression: exp.Describe) -> str:
+ # Default to table if kind is unknown
+ kind_value = expression.args.get("kind") or "TABLE"
+ kind = f" {kind_value}" if kind_value else ""
+ this = f" {self.sql(expression, 'this')}"
+ return f"DESCRIBE{kind}{this}"
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index 7f0f2d7..465f534 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -75,6 +75,20 @@ def _parse_format(args):
)
+def _parse_eomonth(args):
+ date = seq_get(args, 0)
+ month_lag = seq_get(args, 1)
+ unit = DATE_DELTA_INTERVAL.get("month")
+
+ if month_lag is None:
+ return exp.LastDateOfMonth(this=date)
+
+ # Remove the month lag argument in the parser, as it's compared with the number of arguments of the resulting class
+ args.remove(month_lag)
+
+ return exp.LastDateOfMonth(this=exp.DateAdd(this=date, expression=month_lag, unit=unit))
+
+
def generate_date_delta_with_unit_sql(self, e):
func = "DATEADD" if isinstance(e, exp.DateAdd) else "DATEDIFF"
return f"{func}({self.format_args(e.text('unit'), e.expression, e.this)})"
@@ -256,12 +270,14 @@ class TSQL(Dialect):
"DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
"DATENAME": _format_time_lambda(exp.TimeToStr, full_format_mapping=True),
"DATEPART": _format_time_lambda(exp.TimeToStr),
- "GETDATE": exp.CurrentDate.from_arg_list,
+ "GETDATE": exp.CurrentTimestamp.from_arg_list,
+ "SYSDATETIME": exp.CurrentTimestamp.from_arg_list,
"IIF": exp.If.from_arg_list,
"LEN": exp.Length.from_arg_list,
"REPLICATE": exp.Repeat.from_arg_list,
"JSON_VALUE": exp.JSONExtractScalar.from_arg_list,
"FORMAT": _parse_format,
+ "EOMONTH": _parse_eomonth,
}
VAR_LENGTH_DATATYPES = {
@@ -271,6 +287,9 @@ class TSQL(Dialect):
DataType.Type.NCHAR,
}
+ # https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-tables-temporary#create-a-temporary-table
+ TABLE_PREFIX_TOKENS = {TokenType.HASH}
+
def _parse_convert(self, strict):
to = self._parse_types()
self._match(TokenType.COMMA)
@@ -323,6 +342,7 @@ class TSQL(Dialect):
exp.DateAdd: generate_date_delta_with_unit_sql,
exp.DateDiff: generate_date_delta_with_unit_sql,
exp.CurrentDate: rename_func("GETDATE"),
+ exp.CurrentTimestamp: rename_func("GETDATE"),
exp.If: rename_func("IIF"),
exp.NumberToStr: _format_sql,
exp.TimeToStr: _format_sql,
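
A sketch of how the new EOMONTH parser shapes the AST (names taken from _parse_eomonth above):

    import sqlglot
    from sqlglot import exp

    ast = sqlglot.parse_one("SELECT EOMONTH(d, 1)", read="tsql")
    # _parse_eomonth wraps the month lag in a DateAdd under LastDateOfMonth:
    print(ast.find(exp.LastDateOfMonth) is not None)  # True
    print(ast.find(exp.DateAdd) is not None)          # True
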
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 711ec4b..d093e29 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -22,6 +22,7 @@ from sqlglot.helper import (
split_num_words,
subclasses,
)
+from sqlglot.tokens import Token
if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import Dialect
@@ -457,6 +458,23 @@ class Expression(metaclass=_Expression):
assert isinstance(self, type_)
return self
+ def dump(self):
+ """
+ Dump this Expression to a JSON-serializable dict.
+ """
+ from sqlglot.serde import dump
+
+ return dump(self)
+
+ @classmethod
+ def load(cls, obj):
+ """
+ Load a dict (as returned by `Expression.dump`) into an Expression instance.
+ """
+ from sqlglot.serde import load
+
+ return load(obj)
+
class Condition(Expression):
def and_(self, *expressions, dialect=None, **opts):
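
A round-trip sketch of the new dump/load hooks into sqlglot.serde (assumed usage):

    import sqlglot
    from sqlglot.expressions import Expression

    ast = sqlglot.parse_one("SELECT a FROM tbl")
    payload = ast.dump()               # JSON-serializable dict
    restored = Expression.load(payload)
    assert restored.sql() == ast.sql()
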
@@ -631,11 +649,15 @@ class Create(Expression):
"replace": False,
"unique": False,
"materialized": False,
+ "data": False,
+ "statistics": False,
+ "no_primary_index": False,
+ "indexes": False,
}
class Describe(Expression):
- pass
+ arg_types = {"this": True, "kind": False}
class Set(Expression):
@@ -731,7 +753,7 @@ class Column(Condition):
class ColumnDef(Expression):
arg_types = {
"this": True,
- "kind": True,
+ "kind": False,
"constraints": False,
"exists": False,
}
@@ -879,7 +901,15 @@ class Identifier(Expression):
class Index(Expression):
- arg_types = {"this": False, "table": False, "where": False, "columns": False}
+ arg_types = {
+ "this": False,
+ "table": False,
+ "where": False,
+ "columns": False,
+ "unique": False,
+ "primary": False,
+ "amp": False, # teradata
+ }
class Insert(Expression):
@@ -1361,6 +1391,7 @@ class Table(Expression):
"laterals": False,
"joins": False,
"pivots": False,
+ "hints": False,
}
@@ -1818,7 +1849,12 @@ class Select(Subqueryable):
join.this.replace(join.this.subquery())
if join_type:
+ natural: t.Optional[Token]
+ side: t.Optional[Token]
+ kind: t.Optional[Token]
+
natural, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args) # type: ignore
+
if natural:
join.set("natural", True)
if side:
@@ -2111,6 +2147,7 @@ class DataType(Expression):
JSON = auto()
JSONB = auto()
INTERVAL = auto()
+ TIME = auto()
TIMESTAMP = auto()
TIMESTAMPTZ = auto()
TIMESTAMPLTZ = auto()
@@ -2171,11 +2208,24 @@ class DataType(Expression):
}
@classmethod
- def build(cls, dtype, **kwargs) -> DataType:
- return DataType(
- this=dtype if isinstance(dtype, DataType.Type) else DataType.Type[dtype.upper()],
- **kwargs,
- )
+ def build(
+ cls, dtype: str | DataType.Type, dialect: t.Optional[str | Dialect] = None, **kwargs
+ ) -> DataType:
+ from sqlglot import parse_one
+
+ if isinstance(dtype, str):
+ data_type_exp: t.Optional[Expression]
+ if dtype.upper() in cls.Type.__members__:
+ data_type_exp = DataType(this=DataType.Type[dtype.upper()])
+ else:
+ data_type_exp = parse_one(dtype, read=dialect, into=DataType)
+ if data_type_exp is None:
+ raise ValueError(f"Unparsable data type value: {dtype}")
+ elif isinstance(dtype, DataType.Type):
+ data_type_exp = DataType(this=dtype)
+ else:
+ raise ValueError(f"Invalid data type: {type(dtype)}. Expected str or DataType.Type")
+ return DataType(**{**data_type_exp.args, **kwargs})
# https://www.postgresql.org/docs/15/datatype-pseudo.html
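
A sketch of the extended DataType.build, which now parses full type strings (optionally per dialect) instead of only enum member names:

    from sqlglot import exp

    print(exp.DataType.build("ARRAY<INT>", dialect="bigquery").sql())           # e.g. ARRAY<INT>
    print(exp.DataType.build("timestamp").this is exp.DataType.Type.TIMESTAMP)  # True
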
@@ -2429,6 +2479,7 @@ class In(Predicate):
"query": False,
"unnest": False,
"field": False,
+ "is_global": False,
}
@@ -2678,6 +2729,10 @@ class DatetimeTrunc(Func, TimeUnit):
arg_types = {"this": True, "unit": True, "zone": False}
+class LastDateOfMonth(Func):
+ pass
+
+
class Extract(Func):
arg_types = {"this": True, "expression": True}
@@ -2815,7 +2870,13 @@ class Length(Func):
class Levenshtein(Func):
- arg_types = {"this": True, "expression": False}
+ arg_types = {
+ "this": True,
+ "expression": False,
+ "ins_cost": False,
+ "del_cost": False,
+ "sub_cost": False,
+ }
class Ln(Func):
@@ -2890,6 +2951,16 @@ class Quantile(AggFunc):
arg_types = {"this": True, "quantile": True}
+# Clickhouse-specific:
+# https://clickhouse.com/docs/en/sql-reference/aggregate-functions/reference/quantiles/#quantiles
+class Quantiles(AggFunc):
+ arg_types = {"parameters": True, "expressions": True}
+
+
+class QuantileIf(AggFunc):
+ arg_types = {"parameters": True, "expressions": True}
+
+
class ApproxQuantile(Quantile):
arg_types = {"this": True, "quantile": True, "accuracy": False}
@@ -2962,8 +3033,10 @@ class StrToTime(Func):
arg_types = {"this": True, "format": True}
+# Spark allows unix_timestamp()
+# https://spark.apache.org/docs/3.1.3/api/python/reference/api/pyspark.sql.functions.unix_timestamp.html
class StrToUnix(Func):
- arg_types = {"this": True, "format": True}
+ arg_types = {"this": False, "format": False}
class NumberToStr(Func):
@@ -3131,7 +3204,7 @@ def maybe_parse(
dialect=None,
prefix=None,
**opts,
-) -> t.Optional[Expression]:
+) -> Expression:
"""Gracefully handle a possible string or expression.
Example:
@@ -3627,11 +3700,11 @@ def to_table(sql_path: t.Optional[str | Table], **kwargs) -> t.Optional[Table]:
if not isinstance(sql_path, str):
raise ValueError(f"Invalid type provided for a table: {type(sql_path)}")
- catalog, db, table_name = [to_identifier(x) for x in split_num_words(sql_path, ".", 3)]
+ catalog, db, table_name = (to_identifier(x) for x in split_num_words(sql_path, ".", 3))
return Table(this=table_name, db=db, catalog=catalog, **kwargs)
-def to_column(sql_path: str, **kwargs) -> Column:
+def to_column(sql_path: str | Column, **kwargs) -> Column:
"""
Create a column from a `[table].[column]` sql path. Schema is optional.
@@ -3646,7 +3719,7 @@ def to_column(sql_path: str, **kwargs) -> Column:
return sql_path
if not isinstance(sql_path, str):
raise ValueError(f"Invalid type provided for column: {type(sql_path)}")
- table_name, column_name = [to_identifier(x) for x in split_num_words(sql_path, ".", 2)]
+ table_name, column_name = (to_identifier(x) for x in split_num_words(sql_path, ".", 2))
return Column(this=column_name, table=table_name, **kwargs)
@@ -3748,7 +3821,7 @@ def table_(table, db=None, catalog=None, quoted=None, alias=None) -> Table:
def values(
values: t.Iterable[t.Tuple[t.Any, ...]],
alias: t.Optional[str] = None,
- columns: t.Optional[t.Iterable[str]] = None,
+ columns: t.Optional[t.Iterable[str] | t.Dict[str, DataType]] = None,
) -> Values:
"""Build VALUES statement.
@@ -3759,7 +3832,10 @@ def values(
Args:
values: values statements that will be converted to SQL
alias: optional alias
- columns: Optional list of ordered column names. An alias is required when providing column names.
+ columns: Optional list of ordered column names or ordered dictionary of column names to types.
+ If either is provided then an alias is also required.
+ If a dictionary is provided then the first column of the values will be cast to the expected type
+ in order to help with type inference.
Returns:
Values: the Values expression object
@@ -3771,8 +3847,15 @@ def values(
if columns
else TableAlias(this=to_identifier(alias) if alias else None)
)
+ expressions = [convert(tup) for tup in values]
+ if columns and isinstance(columns, dict):
+ types = list(columns.values())
+ expressions[0].set(
+ "expressions",
+ [Cast(this=x, to=types[i]) for i, x in enumerate(expressions[0].expressions)],
+ )
return Values(
- expressions=[convert(tup) for tup in values],
+ expressions=expressions,
alias=table_alias,
)
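
A sketch of values() with the new dict form of columns (output shape assumed):

    from sqlglot import exp

    # The first tuple's values get explicit casts to help type inference;
    # an alias is required whenever columns are given:
    v = exp.values(
        [(1, "a"), (2, "b")],
        alias="t",
        columns={"id": exp.DataType.build("bigint"), "name": exp.DataType.build("text")},
    )
    print(v.sql())
    # expected roughly: VALUES (CAST(1 AS BIGINT), CAST('a' AS TEXT)), (2, 'b') AS t(id, name)
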
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index 0c1578a..3935133 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -50,7 +50,7 @@ class Generator:
The default is on the smaller end because the length only represents a segment and not the true
line length.
Default: 80
- comments: Whether or not to preserve comments in the ouput SQL code.
+ comments: Whether or not to preserve comments in the output SQL code.
Default: True
"""
@@ -236,7 +236,10 @@ class Generator:
return sql
sep = "\n" if self.pretty else " "
- comments = sep.join(f"/*{self.pad_comment(comment)}*/" for comment in comments)
+ comments = sep.join(f"/*{self.pad_comment(comment)}*/" for comment in comments if comment)
+
+ if not comments:
+ return sql
if isinstance(expression, self.WITH_SEPARATED_COMMENTS):
return f"{comments}{self.sep()}{sql}"
@@ -362,10 +365,10 @@ class Generator:
kind = self.sql(expression, "kind")
constraints = self.expressions(expression, key="constraints", sep=" ", flat=True)
exists = "IF NOT EXISTS " if expression.args.get("exists") else ""
+ kind = f" {kind}" if kind else ""
+ constraints = f" {constraints}" if constraints else ""
- if not constraints:
- return f"{exists}{column} {kind}"
- return f"{exists}{column} {kind} {constraints}"
+ return f"{exists}{column}{kind}{constraints}"
def columnconstraint_sql(self, expression: exp.ColumnConstraint) -> str:
this = self.sql(expression, "this")
@@ -416,7 +419,7 @@ class Generator:
this = self.sql(expression, "this")
kind = self.sql(expression, "kind").upper()
expression_sql = self.sql(expression, "expression")
- expression_sql = f"AS{self.sep()}{expression_sql}" if expression_sql else ""
+ expression_sql = f" AS{self.sep()}{expression_sql}" if expression_sql else ""
temporary = " TEMPORARY" if expression.args.get("temporary") else ""
transient = (
" TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else ""
@@ -427,6 +430,40 @@ class Generator:
unique = " UNIQUE" if expression.args.get("unique") else ""
materialized = " MATERIALIZED" if expression.args.get("materialized") else ""
properties = self.sql(expression, "properties")
+ data = expression.args.get("data")
+ if data is None:
+ data = ""
+ elif data:
+ data = " WITH DATA"
+ else:
+ data = " WITH NO DATA"
+ statistics = expression.args.get("statistics")
+ if statistics is None:
+ statistics = ""
+ elif statistics:
+ statistics = " AND STATISTICS"
+ else:
+ statistics = " AND NO STATISTICS"
+ no_primary_index = " NO PRIMARY INDEX" if expression.args.get("no_primary_index") else ""
+
+ indexes = expression.args.get("indexes")
+ index_sql = ""
+ if indexes is not None:
+ indexes_sql = []
+ for index in indexes:
+ ind_unique = " UNIQUE" if index.args.get("unique") else ""
+ ind_primary = " PRIMARY" if index.args.get("primary") else ""
+ ind_amp = " AMP" if index.args.get("amp") else ""
+ ind_name = f" {index.name}" if index.name else ""
+ ind_columns = (
+ f' ({self.expressions(index, key="columns", flat=True)})'
+ if index.args.get("columns")
+ else ""
+ )
+ indexes_sql.append(
+ f"{ind_unique}{ind_primary}{ind_amp} INDEX{ind_name}{ind_columns}"
+ )
+ index_sql = "".join(indexes_sql)
modifiers = "".join(
(
@@ -438,7 +475,10 @@ class Generator:
materialized,
)
)
- expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties} {expression_sql}"
+
+ post_expression_modifiers = "".join((data, statistics, no_primary_index))
+
+ expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties}{expression_sql}{post_expression_modifiers}{index_sql}"
return self.prepend_ctes(expression, expression_sql)
def describe_sql(self, expression: exp.Describe) -> str:
@@ -668,6 +708,8 @@ class Generator:
alias = self.sql(expression, "alias")
alias = f"{sep}{alias}" if alias else ""
+ hints = self.expressions(expression, key="hints", sep=", ", flat=True)
+ hints = f" WITH ({hints})" if hints else ""
laterals = self.expressions(expression, key="laterals", sep="")
joins = self.expressions(expression, key="joins", sep="")
pivots = self.expressions(expression, key="pivots", sep="")
@@ -676,7 +718,7 @@ class Generator:
pivots = f"{pivots}{alias}"
alias = ""
- return f"{table}{alias}{laterals}{joins}{pivots}"
+ return f"{table}{alias}{hints}{laterals}{joins}{pivots}"
def tablesample_sql(self, expression: exp.TableSample) -> str:
if self.alias_post_tablesample and expression.this.alias:
@@ -1020,7 +1062,9 @@ class Generator:
if not partition and not order and not spec and alias:
return f"{this} {alias}"
- return f"{this} ({alias}{partition_sql}{order_sql}{spec_sql})"
+ window_args = alias + partition_sql + order_sql + spec_sql
+
+ return f"{this} ({window_args.strip()})"
def window_spec_sql(self, expression: exp.WindowSpec) -> str:
kind = self.sql(expression, "kind")
@@ -1130,6 +1174,8 @@ class Generator:
query = expression.args.get("query")
unnest = expression.args.get("unnest")
field = expression.args.get("field")
+ is_global = " GLOBAL" if expression.args.get("is_global") else ""
+
if query:
in_sql = self.wrap(query)
elif unnest:
@@ -1138,7 +1184,8 @@ class Generator:
in_sql = self.sql(field)
else:
in_sql = f"({self.expressions(expression, flat=True)})"
- return f"{self.sql(expression, 'this')} IN {in_sql}"
+
+ return f"{self.sql(expression, 'this')}{is_global} IN {in_sql}"
def in_unnest_op(self, unnest: exp.Unnest) -> str:
return f"(SELECT {self.sql(unnest)})"
@@ -1433,7 +1480,7 @@ class Generator:
result_sqls = []
for i, e in enumerate(expressions):
sql = self.sql(e, comment=False)
- comments = self.maybe_comment("", e)
+ comments = self.maybe_comment("", e) if isinstance(e, exp.Expression) else ""
if self.pretty:
if self._leading_comma:
diff --git a/sqlglot/helper.py b/sqlglot/helper.py
index ed37e6c..5a0f2ac 100644
--- a/sqlglot/helper.py
+++ b/sqlglot/helper.py
@@ -131,7 +131,7 @@ def subclasses(
]
-def apply_index_offset(expressions: t.List[E], offset: int) -> t.List[E]:
+def apply_index_offset(expressions: t.List[t.Optional[E]], offset: int) -> t.List[t.Optional[E]]:
"""
Applies an offset to a given integer literal expression.
@@ -148,10 +148,10 @@ def apply_index_offset(expressions: t.List[E], offset: int) -> t.List[E]:
expression = expressions[0]
- if expression.is_int:
+ if expression and expression.is_int:
expression = expression.copy()
logger.warning("Applying array index offset (%s)", offset)
- expression.args["this"] = str(int(expression.this) + offset)
+ expression.args["this"] = str(int(expression.this) + offset) # type: ignore
return [expression]
return expressions
@@ -225,7 +225,7 @@ def open_file(file_name: str) -> t.TextIO:
return gzip.open(file_name, "rt", newline="")
- return open(file_name, "rt", encoding="utf-8", newline="")
+ return open(file_name, encoding="utf-8", newline="")
@contextmanager
@@ -256,7 +256,7 @@ def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
file.close()
-def find_new_name(taken: t.Sequence[str], base: str) -> str:
+def find_new_name(taken: t.Collection[str], base: str) -> str:
"""
Searches for a new name.
@@ -356,6 +356,15 @@ def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Generator[t.Any,
yield value
+def count_params(function: t.Callable) -> int:
+ """
+ Returns the number of formal parameters expected by a function, without counting "self"
+ and "cls", in case of instance and class methods, respectively.
+ """
+ count = function.__code__.co_argcount
+ return count - 1 if inspect.ismethod(function) else count
+
+
def dict_depth(d: t.Dict) -> int:
"""
Get the nesting depth of a dictionary.
@@ -374,6 +383,7 @@ def dict_depth(d: t.Dict) -> int:
Args:
d (dict): dictionary
+
Returns:
int: depth
"""
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index be17f15..bfb2bb8 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -43,7 +43,7 @@ class TypeAnnotator:
},
exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
- exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr),
+ exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.copy()),
exp.Alias: lambda self, expr: self._annotate_unary(expr),
exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
diff --git a/sqlglot/optimizer/eliminate_joins.py b/sqlglot/optimizer/eliminate_joins.py
index 3b40710..8e6a520 100644
--- a/sqlglot/optimizer/eliminate_joins.py
+++ b/sqlglot/optimizer/eliminate_joins.py
@@ -57,7 +57,7 @@ def _join_is_used(scope, join, alias):
# But columns in the ON clause shouldn't count.
on = join.args.get("on")
if on:
- on_clause_columns = set(id(column) for column in on.find_all(exp.Column))
+ on_clause_columns = {id(column) for column in on.find_all(exp.Column)}
else:
on_clause_columns = set()
return any(
@@ -71,7 +71,7 @@ def _is_joined_on_all_unique_outputs(scope, join):
return False
_, join_keys, _ = join_condition(join)
- remaining_unique_outputs = unique_outputs - set(c.name for c in join_keys)
+ remaining_unique_outputs = unique_outputs - {c.name for c in join_keys}
return not remaining_unique_outputs
diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py
index 9ae4966..16aaf17 100644
--- a/sqlglot/optimizer/merge_subqueries.py
+++ b/sqlglot/optimizer/merge_subqueries.py
@@ -67,11 +67,9 @@ def merge_ctes(expression, leave_tables_isolated=False):
singular_cte_selections = [v[0] for k, v in cte_selections.items() if len(v) == 1]
for outer_scope, inner_scope, table in singular_cte_selections:
- inner_select = inner_scope.expression.unnest()
from_or_join = table.find_ancestor(exp.From, exp.Join)
- if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
+ if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
alias = table.alias_or_name
-
_rename_inner_sources(outer_scope, inner_scope, alias)
_merge_from(outer_scope, inner_scope, table, alias)
_merge_expressions(outer_scope, inner_scope, alias)
@@ -80,18 +78,17 @@ def merge_ctes(expression, leave_tables_isolated=False):
_merge_order(outer_scope, inner_scope)
_merge_hints(outer_scope, inner_scope)
_pop_cte(inner_scope)
+ outer_scope.clear_cache()
return expression
def merge_derived_tables(expression, leave_tables_isolated=False):
for outer_scope in traverse_scope(expression):
for subquery in outer_scope.derived_tables:
- inner_select = subquery.unnest()
from_or_join = subquery.find_ancestor(exp.From, exp.Join)
- if _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
- alias = subquery.alias_or_name
- inner_scope = outer_scope.sources[alias]
-
+ alias = subquery.alias_or_name
+ inner_scope = outer_scope.sources[alias]
+ if _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
_rename_inner_sources(outer_scope, inner_scope, alias)
_merge_from(outer_scope, inner_scope, subquery, alias)
_merge_expressions(outer_scope, inner_scope, alias)
@@ -99,21 +96,23 @@ def merge_derived_tables(expression, leave_tables_isolated=False):
_merge_where(outer_scope, inner_scope, from_or_join)
_merge_order(outer_scope, inner_scope)
_merge_hints(outer_scope, inner_scope)
+ outer_scope.clear_cache()
return expression
-def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
+def _mergeable(outer_scope, inner_scope, leave_tables_isolated, from_or_join):
"""
Return True if `inner_select` can be merged into outer query.
Args:
outer_scope (Scope)
- inner_select (exp.Select)
+ inner_scope (Scope)
leave_tables_isolated (bool)
from_or_join (exp.From|exp.Join)
Returns:
bool: True if can be merged
"""
+ inner_select = inner_scope.expression.unnest()
def _is_a_window_expression_in_unmergable_operation():
window_expressions = inner_select.find_all(exp.Window)
@@ -133,10 +132,40 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
]
return any(window_expressions_in_unmergable)
+ def _outer_select_joins_on_inner_select_join():
+ """
+ All columns from the inner select in the ON clause must be from the first FROM table.
+
+ That is, this can be merged:
+ SELECT * FROM x JOIN (SELECT y.a AS a FROM y JOIN z) AS q ON x.a = q.a
+ ^^^ ^
+ But this can't:
+ SELECT * FROM x JOIN (SELECT z.a AS a FROM y JOIN z) AS q ON x.a = q.a
+ ^^^ ^
+ """
+ if not isinstance(from_or_join, exp.Join):
+ return False
+
+ alias = from_or_join.this.alias_or_name
+
+ on = from_or_join.args.get("on")
+ if not on:
+ return False
+ selections = [c.name for c in on.find_all(exp.Column) if c.table == alias]
+ inner_from = inner_scope.expression.args.get("from")
+ if not inner_from:
+ return False
+ inner_from_table = inner_from.expressions[0].alias_or_name
+ inner_projections = {s.alias_or_name: s for s in inner_scope.selects}
+ return any(
+ col.table != inner_from_table
+ for selection in selections
+ for col in inner_projections[selection].find_all(exp.Column)
+ )
+
return (
isinstance(outer_scope.expression, exp.Select)
and isinstance(inner_select, exp.Select)
- and isinstance(inner_select, exp.Select)
and not any(inner_select.args.get(arg) for arg in UNMERGABLE_ARGS)
and inner_select.args.get("from")
and not any(e.find(exp.AggFunc, exp.Select) for e in inner_select.expressions)
@@ -153,6 +182,7 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", [])
)
)
+ and not _outer_select_joins_on_inner_select_join()
and not _is_a_window_expression_in_unmergable_operation()
)
@@ -168,7 +198,7 @@ def _rename_inner_sources(outer_scope, inner_scope, alias):
"""
taken = set(outer_scope.selected_sources)
conflicts = taken.intersection(set(inner_scope.selected_sources))
- conflicts = conflicts - {alias}
+ conflicts -= {alias}
for conflict in conflicts:
new_name = find_new_name(taken, conflict)
diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py
index 72e67d4..46b6b30 100644
--- a/sqlglot/optimizer/optimizer.py
+++ b/sqlglot/optimizer/optimizer.py
@@ -15,6 +15,7 @@ from sqlglot.optimizer.pushdown_projections import pushdown_projections
from sqlglot.optimizer.qualify_columns import qualify_columns
from sqlglot.optimizer.qualify_tables import qualify_tables
from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
+from sqlglot.schema import ensure_schema
RULES = (
lower_identities,
@@ -51,12 +52,13 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar
If no schema is provided then the default schema defined at `sqlgot.schema` will be used
db (str): specify the default database, as might be set by a `USE DATABASE db` statement
catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement
- rules (list): sequence of optimizer rules to use
+ rules (sequence): sequence of optimizer rules to use
**kwargs: If a rule has a keyword argument with a same name in **kwargs, it will be passed in.
Returns:
sqlglot.Expression: optimized expression
"""
- possible_kwargs = {"db": db, "catalog": catalog, "schema": schema or sqlglot.schema, **kwargs}
+ schema = ensure_schema(schema or sqlglot.schema)
+ possible_kwargs = {"db": db, "catalog": catalog, "schema": schema, **kwargs}
expression = expression.copy()
for rule in rules:
diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py
index 49789ac..a73647c 100644
--- a/sqlglot/optimizer/pushdown_projections.py
+++ b/sqlglot/optimizer/pushdown_projections.py
@@ -79,6 +79,7 @@ def _remove_unused_selections(scope, parent_selections):
order_refs = set()
new_selections = []
+ removed = False
for i, selection in enumerate(scope.selects):
if (
SELECT_ALL in parent_selections
@@ -88,12 +89,15 @@ def _remove_unused_selections(scope, parent_selections):
new_selections.append(selection)
else:
removed_indexes.append(i)
+ removed = True
# If there are no remaining selections, just select a single constant
if not new_selections:
new_selections.append(DEFAULT_SELECTION.copy())
scope.expression.set("expressions", new_selections)
+ if removed:
+ scope.clear_cache()
return removed_indexes
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index e16a635..f4568c2 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -365,9 +365,9 @@ class _Resolver:
def all_columns(self):
"""All available columns of all sources in this scope"""
if self._all_columns is None:
- self._all_columns = set(
+ self._all_columns = {
column for columns in self._get_all_source_columns().values() for column in columns
- )
+ }
return self._all_columns
def get_source_columns(self, name, only_visible=False):
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index c0719f2..f560760 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -361,7 +361,7 @@ def _simplify_binary(expression, a, b):
return boolean
elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval):
a, b = extract_date(a), extract_interval(b)
- if b:
+ if a and b:
if isinstance(expression, exp.Add):
return date_literal(a + b)
if isinstance(expression, exp.Sub):
@@ -369,7 +369,7 @@ def _simplify_binary(expression, a, b):
elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast):
a, b = extract_interval(a), extract_date(b)
# you cannot subtract a date from an interval
- if a and isinstance(expression, exp.Add):
+ if a and b and isinstance(expression, exp.Add):
return date_literal(a + b)
return None
@@ -424,9 +424,15 @@ def eval_boolean(expression, a, b):
def extract_date(cast):
- if cast.args["to"].this == exp.DataType.Type.DATE:
- return datetime.date.fromisoformat(cast.name)
- return None
+ # The "fromisoformat" conversion could fail if the cast is used on an identifier,
+ # so in that case we can't extract the date.
+ try:
+ if cast.args["to"].this == exp.DataType.Type.DATE:
+ return datetime.date.fromisoformat(cast.name)
+ if cast.args["to"].this == exp.DataType.Type.DATETIME:
+ return datetime.datetime.fromisoformat(cast.name)
+ except ValueError:
+ return None
def extract_interval(interval):
@@ -450,7 +456,8 @@ def extract_interval(interval):
def date_literal(date):
- return exp.Cast(this=exp.Literal.string(date), to=exp.DataType.build("DATE"))
+ expr_type = exp.DataType.build("DATETIME" if isinstance(date, datetime.datetime) else "DATE")
+ return exp.Cast(this=exp.Literal.string(date), to=expr_type)
def boolean_literal(condition):
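
A sketch of the hardened date folding (DATETIME support plus the new literal guards), assuming the simplify entry point behaves as the hunks suggest:

    import sqlglot
    from sqlglot.optimizer.simplify import simplify

    ast = sqlglot.parse_one("SELECT CAST('2023-01-31' AS DATE) + INTERVAL '1' DAY")
    print(simplify(ast).sql())
    # expected roughly: SELECT CAST('2023-02-01' AS DATE)
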
diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py
index 8d78294..a515489 100644
--- a/sqlglot/optimizer/unnest_subqueries.py
+++ b/sqlglot/optimizer/unnest_subqueries.py
@@ -15,8 +15,7 @@ def unnest_subqueries(expression):
>>> import sqlglot
>>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
>>> unnest_subqueries(expression).sql()
- 'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a)\
- AS _u_0 ON x.a = _u_0.a WHERE (_u_0.a = 1 AND NOT _u_0.a IS NULL)'
+ 'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1'
Args:
expression (sqlglot.Expression): expression to unnest
@@ -173,10 +172,8 @@ def decorrelate(select, parent_select, external_columns, sequence):
other = _other_operand(parent_predicate)
if isinstance(parent_predicate, exp.Exists):
- if value.this in group_by:
- parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
- else:
- parent_predicate = _replace(parent_predicate, "TRUE")
+ alias = exp.column(list(key_aliases.values())[0], table_alias)
+ parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
elif isinstance(parent_predicate, exp.All):
parent_predicate = _replace(
parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
@@ -197,6 +194,23 @@ def decorrelate(select, parent_select, external_columns, sequence):
else:
if is_subquery_projection:
alias = exp.alias_(alias, select.parent.alias)
+
+ # COUNT always returns 0 on empty datasets, so we need to take that into consideration here
+ # by transforming all counts into 0 and using that as the coalesced value
+ if value.find(exp.Count):
+
+ def remove_aggs(node):
+ if isinstance(node, exp.Count):
+ return exp.Literal.number(0)
+ elif isinstance(node, exp.AggFunc):
+ return exp.null()
+ return node
+
+ alias = exp.Coalesce(
+ this=alias,
+ expressions=[value.this.transform(remove_aggs)],
+ )
+
select.parent.replace(alias)
for key, column, predicate in keys:
@@ -209,9 +223,6 @@ def decorrelate(select, parent_select, external_columns, sequence):
if key in group_by:
key.replace(nested)
- parent_predicate = _replace(
- parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)"
- )
elif isinstance(predicate, exp.EQ):
parent_predicate = _replace(
parent_predicate,
@@ -245,7 +256,14 @@ def _other_operand(expression):
if isinstance(expression, exp.In):
return expression.this
+ if isinstance(expression, (exp.Any, exp.All)):
+ return _other_operand(expression.parent)
+
if isinstance(expression, exp.Binary):
- return expression.right if expression.arg_key == "this" else expression.left
+ return (
+ expression.right
+ if isinstance(expression.left, (exp.Subquery, exp.Any, exp.Exists, exp.All))
+ else expression.left
+ )
return None
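
The updated doctest above already shows the simplified output; a runnable sketch for reference:

    import sqlglot
    from sqlglot.optimizer.unnest_subqueries import unnest_subqueries

    ast = sqlglot.parse_one(
        "SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1"
    )
    print(unnest_subqueries(ast).sql())
    # per the docstring: ... LEFT JOIN (...) AS _u_0 ON x.a = _u_0.a WHERE _u_0.a = 1
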
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index 308f363..bd95db8 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -5,7 +5,13 @@ import typing as t
from sqlglot import exp
from sqlglot.errors import ErrorLevel, ParseError, concat_messages, merge_errors
-from sqlglot.helper import apply_index_offset, ensure_collection, ensure_list, seq_get
+from sqlglot.helper import (
+ apply_index_offset,
+ count_params,
+ ensure_collection,
+ ensure_list,
+ seq_get,
+)
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import in_trie, new_trie
@@ -54,7 +60,7 @@ class Parser(metaclass=_Parser):
Default: "nulls_are_small"
"""
- FUNCTIONS = {
+ FUNCTIONS: t.Dict[str, t.Callable] = {
**{name: f.from_arg_list for f in exp.ALL_FUNCTIONS for name in f.sql_names()},
"DATE_TO_DATE_STR": lambda args: exp.Cast(
this=seq_get(args, 0),
@@ -106,6 +112,7 @@ class Parser(metaclass=_Parser):
TokenType.JSON,
TokenType.JSONB,
TokenType.INTERVAL,
+ TokenType.TIME,
TokenType.TIMESTAMP,
TokenType.TIMESTAMPTZ,
TokenType.TIMESTAMPLTZ,
@@ -164,6 +171,7 @@ class Parser(metaclass=_Parser):
TokenType.DELETE,
TokenType.DESCRIBE,
TokenType.DETERMINISTIC,
+ TokenType.DIV,
TokenType.DISTKEY,
TokenType.DISTSTYLE,
TokenType.EXECUTE,
@@ -252,6 +260,7 @@ class Parser(metaclass=_Parser):
TokenType.FIRST,
TokenType.FORMAT,
TokenType.IDENTIFIER,
+ TokenType.INDEX,
TokenType.ISNULL,
TokenType.MERGE,
TokenType.OFFSET,
@@ -312,6 +321,7 @@ class Parser(metaclass=_Parser):
}
TIMESTAMPS = {
+ TokenType.TIME,
TokenType.TIMESTAMP,
TokenType.TIMESTAMPTZ,
TokenType.TIMESTAMPLTZ,
@@ -387,6 +397,7 @@ class Parser(metaclass=_Parser):
}
EXPRESSION_PARSERS = {
+ exp.Column: lambda self: self._parse_column(),
exp.DataType: lambda self: self._parse_types(),
exp.From: lambda self: self._parse_from(),
exp.Group: lambda self: self._parse_group(),
@@ -419,6 +430,7 @@ class Parser(metaclass=_Parser):
TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(),
TokenType.CREATE: lambda self: self._parse_create(),
TokenType.DELETE: lambda self: self._parse_delete(),
+ TokenType.DESC: lambda self: self._parse_describe(),
TokenType.DESCRIBE: lambda self: self._parse_describe(),
TokenType.DROP: lambda self: self._parse_drop(),
TokenType.END: lambda self: self._parse_commit_or_rollback(),
@@ -583,6 +595,11 @@ class Parser(metaclass=_Parser):
TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"}
+ WINDOW_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.ROWS}
+
+ # allows tables to have special tokens as prefixes
+ TABLE_PREFIX_TOKENS: t.Set[TokenType] = set()
+
STRICT_CAST = True
__slots__ = (
@@ -608,13 +625,13 @@ class Parser(metaclass=_Parser):
def __init__(
self,
- error_level=None,
- error_message_context=100,
- index_offset=0,
- unnest_column_only=False,
- alias_post_tablesample=False,
- max_errors=3,
- null_ordering=None,
+ error_level: t.Optional[ErrorLevel] = None,
+ error_message_context: int = 100,
+ index_offset: int = 0,
+ unnest_column_only: bool = False,
+ alias_post_tablesample: bool = False,
+ max_errors: int = 3,
+ null_ordering: t.Optional[str] = None,
):
self.error_level = error_level or ErrorLevel.IMMEDIATE
self.error_message_context = error_message_context
@@ -636,23 +653,43 @@ class Parser(metaclass=_Parser):
self._prev = None
self._prev_comments = None
- def parse(self, raw_tokens, sql=None):
+ def parse(
+ self, raw_tokens: t.List[Token], sql: t.Optional[str] = None
+ ) -> t.List[t.Optional[exp.Expression]]:
"""
- Parses the given list of tokens and returns a list of syntax trees, one tree
+ Parses a list of tokens and returns a list of syntax trees, one tree
per parsed SQL statement.
- Args
- raw_tokens (list): the list of tokens (:class:`~sqlglot.tokens.Token`).
- sql (str): the original SQL string. Used to produce helpful debug messages.
+ Args:
+ raw_tokens: the list of tokens.
+ sql: the original SQL string, used to produce helpful debug messages.
- Returns
- the list of syntax trees (:class:`~sqlglot.expressions.Expression`).
+ Returns:
+ The list of syntax trees.
"""
return self._parse(
parse_method=self.__class__._parse_statement, raw_tokens=raw_tokens, sql=sql
)
- def parse_into(self, expression_types, raw_tokens, sql=None):
+ def parse_into(
+ self,
+ expression_types: str | exp.Expression | t.Collection[exp.Expression | str],
+ raw_tokens: t.List[Token],
+ sql: t.Optional[str] = None,
+ ) -> t.List[t.Optional[exp.Expression]]:
+ """
+ Parses a list of tokens into a given Expression type. If a collection of Expression
+ types is given instead, this method will try to parse the token list into each one
+ of them, stopping at the first for which the parsing succeeds.
+
+ Args:
+ expression_types: the expression type(s) to try and parse the token list into.
+ raw_tokens: the list of tokens.
+ sql: the original SQL string, used to produce helpful debug messages.
+
+ Returns:
+ The list of parsed syntax trees.
+ """
errors = []
for expression_type in ensure_collection(expression_types):
parser = self.EXPRESSION_PARSERS.get(expression_type)
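
A sketch of parse_into via the public parse_one(..., into=...) path, which the new DataType.build also relies on:

    import sqlglot
    from sqlglot import exp

    # exp.DataType is one of the EXPRESSION_PARSERS targets:
    dt = sqlglot.parse_one("ARRAY<INT>", read="bigquery", into=exp.DataType)
    print(dt.sql())  # expected roughly: ARRAY<INT>
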
@@ -668,7 +705,12 @@ class Parser(metaclass=_Parser):
errors=merge_errors(errors),
) from errors[-1]
- def _parse(self, parse_method, raw_tokens, sql=None):
+ def _parse(
+ self,
+ parse_method: t.Callable[[Parser], t.Optional[exp.Expression]],
+ raw_tokens: t.List[Token],
+ sql: t.Optional[str] = None,
+ ) -> t.List[t.Optional[exp.Expression]]:
self.reset()
self.sql = sql or ""
total = len(raw_tokens)
@@ -686,6 +728,7 @@ class Parser(metaclass=_Parser):
self._index = -1
self._tokens = tokens
self._advance()
+
expressions.append(parse_method(self))
if self._index < len(self._tokens):
@@ -695,7 +738,10 @@ class Parser(metaclass=_Parser):
return expressions
- def check_errors(self):
+ def check_errors(self) -> None:
+ """
+ Logs or raises any found errors, depending on the chosen error level setting.
+ """
if self.error_level == ErrorLevel.WARN:
for error in self.errors:
logger.error(str(error))
@@ -705,13 +751,18 @@ class Parser(metaclass=_Parser):
errors=merge_errors(self.errors),
)
- def raise_error(self, message, token=None):
+ def raise_error(self, message: str, token: t.Optional[Token] = None) -> None:
+ """
+ Appends an error in the list of recorded errors or raises it, depending on the chosen
+ error level setting.
+ """
token = token or self._curr or self._prev or Token.string("")
start = self._find_token(token, self.sql)
end = start + len(token.text)
start_context = self.sql[max(start - self.error_message_context, 0) : start]
highlight = self.sql[start:end]
end_context = self.sql[end : end + self.error_message_context]
+
error = ParseError.new(
f"{message}. Line {token.line}, Col: {token.col}.\n"
f" {start_context}\033[4m{highlight}\033[0m{end_context}",
@@ -722,11 +773,26 @@ class Parser(metaclass=_Parser):
highlight=highlight,
end_context=end_context,
)
+
if self.error_level == ErrorLevel.IMMEDIATE:
raise error
+
self.errors.append(error)
- def expression(self, exp_class, comments=None, **kwargs):
+ def expression(
+ self, exp_class: t.Type[exp.Expression], comments: t.Optional[t.List[str]] = None, **kwargs
+ ) -> exp.Expression:
+ """
+ Creates a new, validated Expression.
+
+ Args:
+ exp_class: the expression class to instantiate.
+ comments: an optional list of comments to attach to the expression.
+ kwargs: the arguments to set for the expression along with their respective values.
+
+ Returns:
+ The target expression.
+ """
instance = exp_class(**kwargs)
if self._prev_comments:
instance.comments = self._prev_comments
@@ -736,7 +802,17 @@ class Parser(metaclass=_Parser):
self.validate_expression(instance)
return instance
- def validate_expression(self, expression, args=None):
+ def validate_expression(
+ self, expression: exp.Expression, args: t.Optional[t.List] = None
+ ) -> None:
+ """
+ Validates an already instantiated expression, making sure that all its mandatory arguments
+ are set.
+
+ Args:
+ expression: the expression to validate.
+ args: an optional list of items that was used to instantiate the expression, if it's a Func.
+ """
if self.error_level == ErrorLevel.IGNORE:
return
@@ -748,13 +824,18 @@ class Parser(metaclass=_Parser):
if mandatory and (v is None or (isinstance(v, list) and not v)):
self.raise_error(f"Required keyword: '{k}' missing for {expression.__class__}")
- if args and len(args) > len(expression.arg_types) and not expression.is_var_len_args:
+ if (
+ args
+ and isinstance(expression, exp.Func)
+ and len(args) > len(expression.arg_types)
+ and not expression.is_var_len_args
+ ):
self.raise_error(
f"The number of provided arguments ({len(args)}) is greater than "
f"the maximum number of supported arguments ({len(expression.arg_types)})"
)
- def _find_token(self, token, sql):
+ def _find_token(self, token: Token, sql: str) -> int:
line = 1
col = 1
index = 0
@@ -769,7 +850,7 @@ class Parser(metaclass=_Parser):
return index
- def _advance(self, times=1):
+ def _advance(self, times: int = 1) -> None:
self._index += times
self._curr = seq_get(self._tokens, self._index)
self._next = seq_get(self._tokens, self._index + 1)
@@ -780,10 +861,10 @@ class Parser(metaclass=_Parser):
self._prev = None
self._prev_comments = None
- def _retreat(self, index):
+ def _retreat(self, index: int) -> None:
self._advance(index - self._index)
- def _parse_statement(self):
+ def _parse_statement(self) -> t.Optional[exp.Expression]:
if self._curr is None:
return None
@@ -803,7 +884,7 @@ class Parser(metaclass=_Parser):
self._parse_query_modifiers(expression)
return expression
- def _parse_drop(self, default_kind=None):
+ def _parse_drop(self, default_kind: t.Optional[str] = None) -> t.Optional[exp.Expression]:
temporary = self._match(TokenType.TEMPORARY)
materialized = self._match(TokenType.MATERIALIZED)
kind = self._match_set(self.CREATABLES) and self._prev.text
@@ -812,7 +893,7 @@ class Parser(metaclass=_Parser):
kind = default_kind
else:
self.raise_error(f"Expected {self.CREATABLES}")
- return
+ return None
return self.expression(
exp.Drop,
@@ -824,14 +905,14 @@ class Parser(metaclass=_Parser):
cascade=self._match(TokenType.CASCADE),
)
- def _parse_exists(self, not_=False):
+ def _parse_exists(self, not_: bool = False) -> t.Optional[bool]:
return (
self._match(TokenType.IF)
and (not not_ or self._match(TokenType.NOT))
and self._match(TokenType.EXISTS)
)
- def _parse_create(self):
+ def _parse_create(self) -> t.Optional[exp.Expression]:
replace = self._match_pair(TokenType.OR, TokenType.REPLACE)
temporary = self._match(TokenType.TEMPORARY)
transient = self._match_text_seq("TRANSIENT")
@@ -846,12 +927,16 @@ class Parser(metaclass=_Parser):
if not create_token:
self.raise_error(f"Expected {self.CREATABLES}")
- return
+ return None
exists = self._parse_exists(not_=True)
this = None
expression = None
properties = None
+ data = None
+ statistics = None
+ no_primary_index = None
+ indexes = None
if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE):
this = self._parse_user_defined_function()
@@ -868,7 +953,28 @@ class Parser(metaclass=_Parser):
this = self._parse_table(schema=True)
properties = self._parse_properties()
if self._match(TokenType.ALIAS):
- expression = self._parse_select(nested=True)
+ expression = self._parse_ddl_select()
+
+ if create_token.token_type == TokenType.TABLE:
+ if self._match_text_seq("WITH", "DATA"):
+ data = True
+ elif self._match_text_seq("WITH", "NO", "DATA"):
+ data = False
+
+ if self._match_text_seq("AND", "STATISTICS"):
+ statistics = True
+ elif self._match_text_seq("AND", "NO", "STATISTICS"):
+ statistics = False
+
+ no_primary_index = self._match_text_seq("NO", "PRIMARY", "INDEX")
+
+ indexes = []
+ while True:
+ index = self._parse_create_table_index()
+ if not index:
+ break
+ else:
+ indexes.append(index)
return self.expression(
exp.Create,
@@ -883,9 +989,13 @@ class Parser(metaclass=_Parser):
replace=replace,
unique=unique,
materialized=materialized,
+ data=data,
+ statistics=statistics,
+ no_primary_index=no_primary_index,
+ indexes=indexes,
)
- def _parse_property(self):
+ def _parse_property(self) -> t.Optional[exp.Expression]:
if self._match_set(self.PROPERTY_PARSERS):
return self.PROPERTY_PARSERS[self._prev.token_type](self)
@@ -906,7 +1016,7 @@ class Parser(metaclass=_Parser):
return None
- def _parse_property_assignment(self, exp_class):
+ def _parse_property_assignment(self, exp_class: t.Type[exp.Expression]) -> exp.Expression:
self._match(TokenType.EQ)
self._match(TokenType.ALIAS)
return self.expression(
@@ -914,42 +1024,50 @@ class Parser(metaclass=_Parser):
this=self._parse_var_or_string() or self._parse_number() or self._parse_id_var(),
)
- def _parse_partitioned_by(self):
+ def _parse_partitioned_by(self) -> exp.Expression:
self._match(TokenType.EQ)
return self.expression(
exp.PartitionedByProperty,
this=self._parse_schema() or self._parse_bracket(self._parse_field()),
)
- def _parse_distkey(self):
+ def _parse_distkey(self) -> exp.Expression:
return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_id_var))
- def _parse_create_like(self):
+ def _parse_create_like(self) -> t.Optional[exp.Expression]:
table = self._parse_table(schema=True)
options = []
while self._match_texts(("INCLUDING", "EXCLUDING")):
+ this = self._prev.text.upper()
+ id_var = self._parse_id_var()
+
+ if not id_var:
+ return None
+
options.append(
self.expression(
exp.Property,
- this=self._prev.text.upper(),
- value=exp.Var(this=self._parse_id_var().this.upper()),
+ this=this,
+ value=exp.Var(this=id_var.this.upper()),
)
)
return self.expression(exp.LikeProperty, this=table, expressions=options)
- def _parse_sortkey(self, compound=False):
+ def _parse_sortkey(self, compound: bool = False) -> exp.Expression:
return self.expression(
exp.SortKeyProperty, this=self._parse_wrapped_csv(self._parse_id_var), compound=compound
)
- def _parse_character_set(self, default=False):
+ def _parse_character_set(self, default: bool = False) -> exp.Expression:
self._match(TokenType.EQ)
return self.expression(
exp.CharacterSetProperty, this=self._parse_var_or_string(), default=default
)
- def _parse_returns(self):
+ def _parse_returns(self) -> exp.Expression:
+ value: t.Optional[exp.Expression]
is_table = self._match(TokenType.TABLE)
+
if is_table:
if self._match(TokenType.LT):
value = self.expression(
@@ -960,13 +1078,13 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.GT):
self.raise_error("Expecting >")
else:
- value = self._parse_schema("TABLE")
+ value = self._parse_schema(exp.Literal.string("TABLE"))
else:
value = self._parse_types()
return self.expression(exp.ReturnsProperty, this=value, is_table=is_table)
- def _parse_properties(self):
+ def _parse_properties(self) -> t.Optional[exp.Expression]:
properties = []
while True:
@@ -978,15 +1096,21 @@ class Parser(metaclass=_Parser):
if properties:
return self.expression(exp.Properties, expressions=properties)
+
return None
- def _parse_describe(self):
- self._match(TokenType.TABLE)
- return self.expression(exp.Describe, this=self._parse_id_var())
+ def _parse_describe(self) -> exp.Expression:
+ kind = self._match_set(self.CREATABLES) and self._prev.text
+ this = self._parse_table()
- def _parse_insert(self):
+ return self.expression(exp.Describe, this=this, kind=kind)
+
+ def _parse_insert(self) -> exp.Expression:
overwrite = self._match(TokenType.OVERWRITE)
local = self._match(TokenType.LOCAL)
+
+ this: t.Optional[exp.Expression]
+
if self._match_text_seq("DIRECTORY"):
this = self.expression(
exp.Directory,
@@ -998,21 +1122,22 @@ class Parser(metaclass=_Parser):
self._match(TokenType.INTO)
self._match(TokenType.TABLE)
this = self._parse_table(schema=True)
+
return self.expression(
exp.Insert,
this=this,
exists=self._parse_exists(),
partition=self._parse_partition(),
- expression=self._parse_select(nested=True),
+ expression=self._parse_ddl_select(),
overwrite=overwrite,
)
- def _parse_row(self):
+ def _parse_row(self) -> t.Optional[exp.Expression]:
if not self._match(TokenType.FORMAT):
return None
return self._parse_row_format()
- def _parse_row_format(self, match_row=False):
+ def _parse_row_format(self, match_row: bool = False) -> t.Optional[exp.Expression]:
if match_row and not self._match_pair(TokenType.ROW, TokenType.FORMAT):
return None
@@ -1035,9 +1160,10 @@ class Parser(metaclass=_Parser):
kwargs["lines"] = self._parse_string()
if self._match_text_seq("NULL", "DEFINED", "AS"):
kwargs["null"] = self._parse_string()
- return self.expression(exp.RowFormatDelimitedProperty, **kwargs)
- def _parse_load_data(self):
+ return self.expression(exp.RowFormatDelimitedProperty, **kwargs) # type: ignore
+
+ def _parse_load_data(self) -> exp.Expression:
local = self._match(TokenType.LOCAL)
self._match_text_seq("INPATH")
inpath = self._parse_string()
@@ -1055,7 +1181,7 @@ class Parser(metaclass=_Parser):
serde=self._match_text_seq("SERDE") and self._parse_string(),
)
- def _parse_delete(self):
+ def _parse_delete(self) -> exp.Expression:
self._match(TokenType.FROM)
return self.expression(
@@ -1065,10 +1191,10 @@ class Parser(metaclass=_Parser):
where=self._parse_where(),
)
- def _parse_update(self):
+ def _parse_update(self) -> exp.Expression:
return self.expression(
exp.Update,
- **{
+ **{ # type: ignore
"this": self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS),
"expressions": self._match(TokenType.SET) and self._parse_csv(self._parse_equality),
"from": self._parse_from(),
@@ -1076,16 +1202,17 @@ class Parser(metaclass=_Parser):
},
)
- def _parse_uncache(self):
+ def _parse_uncache(self) -> exp.Expression:
if not self._match(TokenType.TABLE):
self.raise_error("Expecting TABLE after UNCACHE")
+
return self.expression(
exp.Uncache,
exists=self._parse_exists(),
this=self._parse_table(schema=True),
)
- def _parse_cache(self):
+ def _parse_cache(self) -> exp.Expression:
lazy = self._match(TokenType.LAZY)
self._match(TokenType.TABLE)
table = self._parse_table(schema=True)
@@ -1108,21 +1235,23 @@ class Parser(metaclass=_Parser):
expression=self._parse_select(nested=True),
)
- def _parse_partition(self):
+ def _parse_partition(self) -> t.Optional[exp.Expression]:
if not self._match(TokenType.PARTITION):
return None
- def parse_values():
+ def parse_values() -> exp.Property:
props = self._parse_csv(self._parse_var_or_string, sep=TokenType.EQ)
return exp.Property(this=seq_get(props, 0), value=seq_get(props, 1))
return self.expression(exp.Partition, this=self._parse_wrapped_csv(parse_values))
- def _parse_value(self):
+ def _parse_value(self) -> exp.Expression:
expressions = self._parse_wrapped_csv(self._parse_conjunction)
return self.expression(exp.Tuple, expressions=expressions)
- def _parse_select(self, nested=False, table=False):
+ def _parse_select(
+ self, nested: bool = False, table: bool = False, parse_subquery_alias: bool = True
+ ) -> t.Optional[exp.Expression]:
cte = self._parse_with()
if cte:
this = self._parse_statement()
@@ -1178,10 +1307,11 @@ class Parser(metaclass=_Parser):
self._parse_query_modifiers(this)
this = self._parse_set_operations(this)
self._match_r_paren()
+
# early return so that subquery unions aren't parsed again
# SELECT * FROM (SELECT 1) UNION ALL SELECT 1
    # UNION ALL should be a property of the top select node, not the subquery
- return self._parse_subquery(this)
+ return self._parse_subquery(this, parse_alias=parse_subquery_alias)
elif self._match(TokenType.VALUES):
if self._curr.token_type == TokenType.L_PAREN:
# We don't consume the left paren because it's consumed in _parse_value
@@ -1203,7 +1333,7 @@ class Parser(metaclass=_Parser):
return self._parse_set_operations(this)
- def _parse_with(self, skip_with_token=False):
+ def _parse_with(self, skip_with_token: bool = False) -> t.Optional[exp.Expression]:
if not skip_with_token and not self._match(TokenType.WITH):
return None
@@ -1220,7 +1350,7 @@ class Parser(metaclass=_Parser):
return self.expression(exp.With, expressions=expressions, recursive=recursive)
- def _parse_cte(self):
+ def _parse_cte(self) -> exp.Expression:
alias = self._parse_table_alias()
if not alias or not alias.this:
self.raise_error("Expected CTE to have alias")
@@ -1234,7 +1364,9 @@ class Parser(metaclass=_Parser):
alias=alias,
)
- def _parse_table_alias(self, alias_tokens=None):
+ def _parse_table_alias(
+ self, alias_tokens: t.Optional[t.Collection[TokenType]] = None
+ ) -> t.Optional[exp.Expression]:
any_token = self._match(TokenType.ALIAS)
alias = self._parse_id_var(
any_token=any_token, tokens=alias_tokens or self.TABLE_ALIAS_TOKENS
@@ -1251,15 +1383,17 @@ class Parser(metaclass=_Parser):
return self.expression(exp.TableAlias, this=alias, columns=columns)
- def _parse_subquery(self, this):
+ def _parse_subquery(
+ self, this: t.Optional[exp.Expression], parse_alias: bool = True
+ ) -> exp.Expression:
return self.expression(
exp.Subquery,
this=this,
pivots=self._parse_pivots(),
- alias=self._parse_table_alias(),
+ alias=self._parse_table_alias() if parse_alias else None,
)
- def _parse_query_modifiers(self, this):
+ def _parse_query_modifiers(self, this: t.Optional[exp.Expression]) -> None:
if not isinstance(this, self.MODIFIABLES):
return
@@ -1284,15 +1418,16 @@ class Parser(metaclass=_Parser):
if expression:
this.set(key, expression)
- def _parse_hint(self):
+ def _parse_hint(self) -> t.Optional[exp.Expression]:
if self._match(TokenType.HINT):
hints = self._parse_csv(self._parse_function)
if not self._match_pair(TokenType.STAR, TokenType.SLASH):
self.raise_error("Expected */ after HINT")
return self.expression(exp.Hint, expressions=hints)
+
return None
- def _parse_into(self):
+ def _parse_into(self) -> t.Optional[exp.Expression]:
if not self._match(TokenType.INTO):
return None
@@ -1304,14 +1439,15 @@ class Parser(metaclass=_Parser):
exp.Into, this=self._parse_table(schema=True), temporary=temp, unlogged=unlogged
)
- def _parse_from(self):
+ def _parse_from(self) -> t.Optional[exp.Expression]:
if not self._match(TokenType.FROM):
return None
+
return self.expression(
exp.From, comments=self._prev_comments, expressions=self._parse_csv(self._parse_table)
)
- def _parse_lateral(self):
+ def _parse_lateral(self) -> t.Optional[exp.Expression]:
outer_apply = self._match_pair(TokenType.OUTER, TokenType.APPLY)
cross_apply = self._match_pair(TokenType.CROSS, TokenType.APPLY)
@@ -1334,6 +1470,8 @@ class Parser(metaclass=_Parser):
expression=self._parse_function() or self._parse_id_var(any_token=False),
)
+ table_alias: t.Optional[exp.Expression]
+
if view:
table = self._parse_id_var(any_token=False)
columns = self._parse_csv(self._parse_id_var) if self._match(TokenType.ALIAS) else []
@@ -1354,20 +1492,24 @@ class Parser(metaclass=_Parser):
return expression
- def _parse_join_side_and_kind(self):
+ def _parse_join_side_and_kind(
+ self,
+ ) -> t.Tuple[t.Optional[Token], t.Optional[Token], t.Optional[Token]]:
return (
self._match(TokenType.NATURAL) and self._prev,
self._match_set(self.JOIN_SIDES) and self._prev,
self._match_set(self.JOIN_KINDS) and self._prev,
)
- def _parse_join(self, skip_join_token=False):
+ def _parse_join(self, skip_join_token: bool = False) -> t.Optional[exp.Expression]:
natural, side, kind = self._parse_join_side_and_kind()
if not skip_join_token and not self._match(TokenType.JOIN):
return None
- kwargs = {"this": self._parse_table()}
+ kwargs: t.Dict[
+ str, t.Optional[exp.Expression] | bool | str | t.List[t.Optional[exp.Expression]]
+ ] = {"this": self._parse_table()}
if natural:
kwargs["natural"] = True
@@ -1381,12 +1523,13 @@ class Parser(metaclass=_Parser):
elif self._match(TokenType.USING):
kwargs["using"] = self._parse_wrapped_id_vars()
- return self.expression(exp.Join, **kwargs)
+ return self.expression(exp.Join, **kwargs) # type: ignore
- def _parse_index(self):
+ def _parse_index(self) -> exp.Expression:
index = self._parse_id_var()
self._match(TokenType.ON)
self._match(TokenType.TABLE) # hive
+
return self.expression(
exp.Index,
this=index,
@@ -1394,7 +1537,28 @@ class Parser(metaclass=_Parser):
columns=self._parse_expression(),
)
- def _parse_table(self, schema=False, alias_tokens=None):
+ def _parse_create_table_index(self) -> t.Optional[exp.Expression]:
+ unique = self._match(TokenType.UNIQUE)
+ primary = self._match_text_seq("PRIMARY")
+ amp = self._match_text_seq("AMP")
+ if not self._match(TokenType.INDEX):
+ return None
+ index = self._parse_id_var()
+ columns = None
+ if self._curr and self._curr.token_type == TokenType.L_PAREN:
+ columns = self._parse_wrapped_csv(self._parse_column)
+ return self.expression(
+ exp.Index,
+ this=index,
+ columns=columns,
+ unique=unique,
+ primary=primary,
+ amp=amp,
+ )
+
+ def _parse_table(
+ self, schema: bool = False, alias_tokens: t.Optional[t.Collection[TokenType]] = None
+ ) -> t.Optional[exp.Expression]:
lateral = self._parse_lateral()
if lateral:
@@ -1417,7 +1581,9 @@ class Parser(metaclass=_Parser):
catalog = None
db = None
- table = (not schema and self._parse_function()) or self._parse_id_var(False)
+ table = (not schema and self._parse_function()) or self._parse_id_var(
+ any_token=False, prefix_tokens=self.TABLE_PREFIX_TOKENS
+ )
while self._match(TokenType.DOT):
if catalog:
@@ -1446,6 +1612,14 @@ class Parser(metaclass=_Parser):
if alias:
this.set("alias", alias)
+ if self._match(TokenType.WITH):
+ this.set(
+ "hints",
+ self._parse_wrapped_csv(
+ lambda: self._parse_function() or self._parse_var(any_token=True)
+ ),
+ )
+
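
A sketch of what this WITH (...) clause captures; T-SQL table hints are the assumed target, and NOLOCK is illustrative:

    import sqlglot
    from sqlglot import exp

    table = sqlglot.parse_one("SELECT 1 FROM t WITH (NOLOCK)", read="tsql").find(exp.Table)
    print(table.args.get("hints"))  # a one-element list holding the NOLOCK hint
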
if not self.alias_post_tablesample:
table_sample = self._parse_table_sample()
@@ -1455,7 +1629,7 @@ class Parser(metaclass=_Parser):
return this
- def _parse_unnest(self):
+ def _parse_unnest(self) -> t.Optional[exp.Expression]:
if not self._match(TokenType.UNNEST):
return None
@@ -1473,7 +1647,7 @@ class Parser(metaclass=_Parser):
exp.Unnest, expressions=expressions, ordinality=ordinality, alias=alias
)
- def _parse_derived_table_values(self):
+ def _parse_derived_table_values(self) -> t.Optional[exp.Expression]:
is_derived = self._match_pair(TokenType.L_PAREN, TokenType.VALUES)
if not is_derived and not self._match(TokenType.VALUES):
return None
@@ -1485,7 +1659,7 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Values, expressions=expressions, alias=self._parse_table_alias())
- def _parse_table_sample(self):
+ def _parse_table_sample(self) -> t.Optional[exp.Expression]:
if not self._match(TokenType.TABLE_SAMPLE):
return None
@@ -1533,10 +1707,10 @@ class Parser(metaclass=_Parser):
seed=seed,
)
- def _parse_pivots(self):
+ def _parse_pivots(self) -> t.List[t.Optional[exp.Expression]]:
return list(iter(self._parse_pivot, None))
- def _parse_pivot(self):
+ def _parse_pivot(self) -> t.Optional[exp.Expression]:
index = self._index
if self._match(TokenType.PIVOT):
@@ -1572,16 +1746,18 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Pivot, expressions=expressions, field=field, unpivot=unpivot)
- def _parse_where(self, skip_where_token=False):
+ def _parse_where(self, skip_where_token: bool = False) -> t.Optional[exp.Expression]:
if not skip_where_token and not self._match(TokenType.WHERE):
return None
+
return self.expression(
exp.Where, comments=self._prev_comments, this=self._parse_conjunction()
)
- def _parse_group(self, skip_group_by_token=False):
+ def _parse_group(self, skip_group_by_token: bool = False) -> t.Optional[exp.Expression]:
if not skip_group_by_token and not self._match(TokenType.GROUP_BY):
return None
+
return self.expression(
exp.Group,
expressions=self._parse_csv(self._parse_conjunction),
@@ -1590,29 +1766,33 @@ class Parser(metaclass=_Parser):
rollup=self._match(TokenType.ROLLUP) and self._parse_wrapped_id_vars(),
)
- def _parse_grouping_sets(self):
+ def _parse_grouping_sets(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
if not self._match(TokenType.GROUPING_SETS):
return None
+
return self._parse_wrapped_csv(self._parse_grouping_set)
- def _parse_grouping_set(self):
+ def _parse_grouping_set(self) -> t.Optional[exp.Expression]:
if self._match(TokenType.L_PAREN):
grouping_set = self._parse_csv(self._parse_id_var)
self._match_r_paren()
return self.expression(exp.Tuple, expressions=grouping_set)
+
return self._parse_id_var()
- def _parse_having(self, skip_having_token=False):
+ def _parse_having(self, skip_having_token: bool = False) -> t.Optional[exp.Expression]:
if not skip_having_token and not self._match(TokenType.HAVING):
return None
return self.expression(exp.Having, this=self._parse_conjunction())
- def _parse_qualify(self):
+ def _parse_qualify(self) -> t.Optional[exp.Expression]:
if not self._match(TokenType.QUALIFY):
return None
return self.expression(exp.Qualify, this=self._parse_conjunction())
- def _parse_order(self, this=None, skip_order_token=False):
+ def _parse_order(
+ self, this: t.Optional[exp.Expression] = None, skip_order_token: bool = False
+ ) -> t.Optional[exp.Expression]:
if not skip_order_token and not self._match(TokenType.ORDER_BY):
return this
@@ -1620,12 +1800,14 @@ class Parser(metaclass=_Parser):
exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered)
)
- def _parse_sort(self, token_type, exp_class):
+ def _parse_sort(
+ self, token_type: TokenType, exp_class: t.Type[exp.Expression]
+ ) -> t.Optional[exp.Expression]:
if not self._match(token_type):
return None
return self.expression(exp_class, expressions=self._parse_csv(self._parse_ordered))
- def _parse_ordered(self):
+ def _parse_ordered(self) -> exp.Expression:
this = self._parse_conjunction()
self._match(TokenType.ASC)
is_desc = self._match(TokenType.DESC)
@@ -1647,7 +1829,9 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Ordered, this=this, desc=desc, nulls_first=nulls_first)
- def _parse_limit(self, this=None, top=False):
+ def _parse_limit(
+ self, this: t.Optional[exp.Expression] = None, top: bool = False
+ ) -> t.Optional[exp.Expression]:
if self._match(TokenType.TOP if top else TokenType.LIMIT):
limit_paren = self._match(TokenType.L_PAREN)
limit_exp = self.expression(exp.Limit, this=this, expression=self._parse_number())
@@ -1667,7 +1851,7 @@ class Parser(metaclass=_Parser):
return this
- def _parse_offset(self, this=None):
+ def _parse_offset(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]:
if not self._match_set((TokenType.OFFSET, TokenType.COMMA)):
return this
@@ -1675,7 +1859,7 @@ class Parser(metaclass=_Parser):
self._match_set((TokenType.ROW, TokenType.ROWS))
return self.expression(exp.Offset, this=this, expression=count)
- def _parse_set_operations(self, this):
+ def _parse_set_operations(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
if not self._match_set(self.SET_OPERATIONS):
return this
@@ -1695,19 +1879,19 @@ class Parser(metaclass=_Parser):
expression=self._parse_select(nested=True),
)
- def _parse_expression(self):
+ def _parse_expression(self) -> t.Optional[exp.Expression]:
return self._parse_alias(self._parse_conjunction())
- def _parse_conjunction(self):
+ def _parse_conjunction(self) -> t.Optional[exp.Expression]:
return self._parse_tokens(self._parse_equality, self.CONJUNCTION)
- def _parse_equality(self):
+ def _parse_equality(self) -> t.Optional[exp.Expression]:
return self._parse_tokens(self._parse_comparison, self.EQUALITY)
- def _parse_comparison(self):
+ def _parse_comparison(self) -> t.Optional[exp.Expression]:
return self._parse_tokens(self._parse_range, self.COMPARISON)
- def _parse_range(self):
+ def _parse_range(self) -> t.Optional[exp.Expression]:
this = self._parse_bitwise()
negate = self._match(TokenType.NOT)
@@ -1730,7 +1914,7 @@ class Parser(metaclass=_Parser):
return this
- def _parse_is(self, this):
+ def _parse_is(self, this: t.Optional[exp.Expression]) -> exp.Expression:
negate = self._match(TokenType.NOT)
if self._match(TokenType.DISTINCT_FROM):
klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ
@@ -1743,7 +1927,7 @@ class Parser(metaclass=_Parser):
)
return self.expression(exp.Not, this=this) if negate else this
- def _parse_in(self, this):
+ def _parse_in(self, this: t.Optional[exp.Expression]) -> exp.Expression:
unnest = self._parse_unnest()
if unnest:
this = self.expression(exp.In, this=this, unnest=unnest)
@@ -1761,18 +1945,18 @@ class Parser(metaclass=_Parser):
return this
- def _parse_between(self, this):
+ def _parse_between(self, this: exp.Expression) -> exp.Expression:
low = self._parse_bitwise()
self._match(TokenType.AND)
high = self._parse_bitwise()
return self.expression(exp.Between, this=this, low=low, high=high)
- def _parse_escape(self, this):
+ def _parse_escape(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
if not self._match(TokenType.ESCAPE):
return this
return self.expression(exp.Escape, this=this, expression=self._parse_string())
- def _parse_bitwise(self):
+ def _parse_bitwise(self) -> t.Optional[exp.Expression]:
this = self._parse_term()
while True:
@@ -1795,18 +1979,18 @@ class Parser(metaclass=_Parser):
return this
- def _parse_term(self):
+ def _parse_term(self) -> t.Optional[exp.Expression]:
return self._parse_tokens(self._parse_factor, self.TERM)
- def _parse_factor(self):
+ def _parse_factor(self) -> t.Optional[exp.Expression]:
return self._parse_tokens(self._parse_unary, self.FACTOR)
- def _parse_unary(self):
+ def _parse_unary(self) -> t.Optional[exp.Expression]:
if self._match_set(self.UNARY_PARSERS):
return self.UNARY_PARSERS[self._prev.token_type](self)
return self._parse_at_time_zone(self._parse_type())
- def _parse_type(self):
+ def _parse_type(self) -> t.Optional[exp.Expression]:
if self._match(TokenType.INTERVAL):
return self.expression(exp.Interval, this=self._parse_term(), unit=self._parse_var())
@@ -1824,7 +2008,7 @@ class Parser(metaclass=_Parser):
return this
- def _parse_types(self, check_func=False):
+ def _parse_types(self, check_func: bool = False) -> t.Optional[exp.Expression]:
index = self._index
if not self._match_set(self.TYPE_TOKENS):
@@ -1875,7 +2059,7 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.GT):
self.raise_error("Expecting >")
- value = None
+ value: t.Optional[exp.Expression] = None
if type_token in self.TIMESTAMPS:
if self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ:
value = exp.DataType(this=exp.DataType.Type.TIMESTAMPTZ, expressions=expressions)
@@ -1884,7 +2068,10 @@ class Parser(metaclass=_Parser):
):
value = exp.DataType(this=exp.DataType.Type.TIMESTAMPLTZ, expressions=expressions)
elif self._match(TokenType.WITHOUT_TIME_ZONE):
- value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions)
+ if type_token == TokenType.TIME:
+ value = exp.DataType(this=exp.DataType.Type.TIME, expressions=expressions)
+ else:
+ value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions)
maybe_func = maybe_func and value is None
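
With this branch, TIME WITHOUT TIME ZONE resolves to the new TIME type instead of falling through to TIMESTAMP, e.g.:

    import sqlglot

    print(sqlglot.parse_one("CAST(x AS TIME WITHOUT TIME ZONE)").sql())
    # CAST(x AS TIME)
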
@@ -1912,7 +2099,7 @@ class Parser(metaclass=_Parser):
nested=nested,
)
- def _parse_struct_kwargs(self):
+ def _parse_struct_kwargs(self) -> t.Optional[exp.Expression]:
this = self._parse_id_var()
self._match(TokenType.COLON)
data_type = self._parse_types()
@@ -1921,12 +2108,12 @@ class Parser(metaclass=_Parser):
return None
return self.expression(exp.StructKwarg, this=this, expression=data_type)
- def _parse_at_time_zone(self, this):
+ def _parse_at_time_zone(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
if not self._match(TokenType.AT_TIME_ZONE):
return this
return self.expression(exp.AtTimeZone, this=this, zone=self._parse_unary())
- def _parse_column(self):
+ def _parse_column(self) -> t.Optional[exp.Expression]:
this = self._parse_field()
if isinstance(this, exp.Identifier):
this = self.expression(exp.Column, this=this)
@@ -1943,7 +2130,8 @@ class Parser(metaclass=_Parser):
if not field:
self.raise_error("Expected type")
elif op:
- field = exp.Literal.string(self._advance() or self._prev.text)
+ self._advance()
+ field = exp.Literal.string(self._prev.text)
else:
field = self._parse_star() or self._parse_function() or self._parse_id_var()
@@ -1963,7 +2151,7 @@ class Parser(metaclass=_Parser):
return this
- def _parse_primary(self):
+ def _parse_primary(self) -> t.Optional[exp.Expression]:
if self._match_set(self.PRIMARY_PARSERS):
token_type = self._prev.token_type
primary = self.PRIMARY_PARSERS[token_type](self, self._prev)
@@ -1995,21 +2183,27 @@ class Parser(metaclass=_Parser):
self._match_r_paren()
if isinstance(this, exp.Subqueryable):
- this = self._parse_set_operations(self._parse_subquery(this))
+ this = self._parse_set_operations(
+ self._parse_subquery(this=this, parse_alias=False)
+ )
elif len(expressions) > 1:
this = self.expression(exp.Tuple, expressions=expressions)
else:
this = self.expression(exp.Paren, this=this)
- if comments:
+
+ if this and comments:
this.comments = comments
+
return this
return None
- def _parse_field(self, any_token=False):
+ def _parse_field(self, any_token: bool = False) -> t.Optional[exp.Expression]:
return self._parse_primary() or self._parse_function() or self._parse_id_var(any_token)
- def _parse_function(self, functions=None):
+ def _parse_function(
+ self, functions: t.Optional[t.Dict[str, t.Callable]] = None
+ ) -> t.Optional[exp.Expression]:
if not self._curr:
return None
@@ -2020,7 +2214,9 @@ class Parser(metaclass=_Parser):
if not self._next or self._next.token_type != TokenType.L_PAREN:
if token_type in self.NO_PAREN_FUNCTIONS:
- return self.expression(self._advance() or self.NO_PAREN_FUNCTIONS[token_type])
+ self._advance()
+ return self.expression(self.NO_PAREN_FUNCTIONS[token_type])
+
return None
if token_type not in self.FUNC_TOKENS:
@@ -2049,7 +2245,18 @@ class Parser(metaclass=_Parser):
args = self._parse_csv(self._parse_lambda)
if function:
- this = function(args)
+
+            # ClickHouse supports function calls like foo(x, y)(z), so for these we need to also
+            # parse the second parameter list (i.e. "(z)") and pass both argument lists to the
+            # corresponding function.
+ if count_params(function) == 2:
+ params = None
+ if self._match_pair(TokenType.R_PAREN, TokenType.L_PAREN):
+ params = self._parse_csv(self._parse_lambda)
+
+ this = function(args, params)
+ else:
+ this = function(args)
+
self.validate_expression(this, args)
else:
this = self.expression(exp.Anonymous, this=this, expressions=args)
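
A sketch of the call shape this enables, assuming the ClickHouse dialect registers quantile as one of these two-parameter-list functions:

    import sqlglot

    sql = "SELECT quantile(0.5)(latency) FROM t"
    print(sqlglot.transpile(sql, read="clickhouse", write="clickhouse")[0])
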
@@ -2057,7 +2264,7 @@ class Parser(metaclass=_Parser):
self._match_r_paren(this)
return self._parse_window(this)
- def _parse_user_defined_function(self):
+ def _parse_user_defined_function(self) -> t.Optional[exp.Expression]:
this = self._parse_id_var()
while self._match(TokenType.DOT):
@@ -2070,27 +2277,27 @@ class Parser(metaclass=_Parser):
self._match_r_paren()
return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions)
- def _parse_introducer(self, token):
+ def _parse_introducer(self, token: Token) -> t.Optional[exp.Expression]:
literal = self._parse_primary()
if literal:
return self.expression(exp.Introducer, this=token.text, expression=literal)
return self.expression(exp.Identifier, this=token.text)
- def _parse_national(self, token):
+ def _parse_national(self, token: Token) -> exp.Expression:
return self.expression(exp.National, this=exp.Literal.string(token.text))
- def _parse_session_parameter(self):
+ def _parse_session_parameter(self) -> exp.Expression:
kind = None
this = self._parse_id_var() or self._parse_primary()
- if self._match(TokenType.DOT):
+ if this and self._match(TokenType.DOT):
kind = this.name
this = self._parse_var() or self._parse_primary()
return self.expression(exp.SessionParameter, this=this, kind=kind)
- def _parse_udf_kwarg(self):
+ def _parse_udf_kwarg(self) -> t.Optional[exp.Expression]:
this = self._parse_id_var()
kind = self._parse_types()
@@ -2099,7 +2306,7 @@ class Parser(metaclass=_Parser):
return self.expression(exp.UserDefinedFunctionKwarg, this=this, kind=kind)
- def _parse_lambda(self):
+ def _parse_lambda(self) -> t.Optional[exp.Expression]:
index = self._index
if self._match(TokenType.L_PAREN):
@@ -2115,6 +2322,8 @@ class Parser(metaclass=_Parser):
self._retreat(index)
+ this: t.Optional[exp.Expression]
+
if self._match(TokenType.DISTINCT):
this = self.expression(
exp.Distinct, expressions=self._parse_csv(self._parse_conjunction)
@@ -2129,7 +2338,7 @@ class Parser(metaclass=_Parser):
return self._parse_limit(self._parse_order(this))
- def _parse_schema(self, this=None):
+ def _parse_schema(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]:
index = self._index
if not self._match(TokenType.L_PAREN) or self._match(TokenType.SELECT):
self._retreat(index)
@@ -2140,14 +2349,15 @@ class Parser(metaclass=_Parser):
or self._parse_column_def(self._parse_field(any_token=True))
)
self._match_r_paren()
+
+ if isinstance(this, exp.Literal):
+ this = this.name
+
return self.expression(exp.Schema, this=this, expressions=args)
- def _parse_column_def(self, this):
+ def _parse_column_def(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
kind = self._parse_types()
- if not kind:
- return this
-
constraints = []
while True:
constraint = self._parse_column_constraint()
@@ -2155,9 +2365,12 @@ class Parser(metaclass=_Parser):
break
constraints.append(constraint)
+ if not kind and not constraints:
+ return this
+
return self.expression(exp.ColumnDef, this=this, kind=kind, constraints=constraints)
- def _parse_column_constraint(self):
+ def _parse_column_constraint(self) -> t.Optional[exp.Expression]:
this = self._parse_references()
if this:
@@ -2166,6 +2379,8 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.CONSTRAINT):
this = self._parse_id_var()
+ kind: exp.Expression
+
if self._match(TokenType.AUTO_INCREMENT):
kind = exp.AutoIncrementColumnConstraint()
elif self._match(TokenType.CHECK):
@@ -2202,7 +2417,7 @@ class Parser(metaclass=_Parser):
return self.expression(exp.ColumnConstraint, this=this, kind=kind)
- def _parse_constraint(self):
+ def _parse_constraint(self) -> t.Optional[exp.Expression]:
if not self._match(TokenType.CONSTRAINT):
return self._parse_unnamed_constraint()
@@ -2217,24 +2432,25 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Constraint, this=this, expressions=expressions)
- def _parse_unnamed_constraint(self):
+ def _parse_unnamed_constraint(self) -> t.Optional[exp.Expression]:
if not self._match_set(self.CONSTRAINT_PARSERS):
return None
return self.CONSTRAINT_PARSERS[self._prev.token_type](self)
- def _parse_unique(self):
+ def _parse_unique(self) -> exp.Expression:
return self.expression(exp.Unique, expressions=self._parse_wrapped_id_vars())
- def _parse_references(self):
+ def _parse_references(self) -> t.Optional[exp.Expression]:
if not self._match(TokenType.REFERENCES):
return None
+
return self.expression(
exp.Reference,
this=self._parse_id_var(),
expressions=self._parse_wrapped_id_vars(),
)
- def _parse_foreign_key(self):
+ def _parse_foreign_key(self) -> exp.Expression:
expressions = self._parse_wrapped_id_vars()
reference = self._parse_references()
options = {}
@@ -2260,13 +2476,15 @@ class Parser(metaclass=_Parser):
exp.ForeignKey,
expressions=expressions,
reference=reference,
- **options,
+ **options, # type: ignore
)
- def _parse_bracket(self, this):
+ def _parse_bracket(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
if not self._match(TokenType.L_BRACKET):
return this
+ expressions: t.List[t.Optional[exp.Expression]]
+
if self._match(TokenType.COLON):
expressions = [self.expression(exp.Slice, expression=self._parse_conjunction())]
else:
@@ -2284,12 +2502,12 @@ class Parser(metaclass=_Parser):
this.comments = self._prev_comments
return self._parse_bracket(this)
- def _parse_slice(self, this):
+ def _parse_slice(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
if self._match(TokenType.COLON):
return self.expression(exp.Slice, this=this, expression=self._parse_conjunction())
return this
- def _parse_case(self):
+ def _parse_case(self) -> t.Optional[exp.Expression]:
ifs = []
default = None
@@ -2311,7 +2529,7 @@ class Parser(metaclass=_Parser):
self.expression(exp.Case, this=expression, ifs=ifs, default=default)
)
- def _parse_if(self):
+ def _parse_if(self) -> t.Optional[exp.Expression]:
if self._match(TokenType.L_PAREN):
args = self._parse_csv(self._parse_conjunction)
this = exp.If.from_arg_list(args)
@@ -2324,9 +2542,10 @@ class Parser(metaclass=_Parser):
false = self._parse_conjunction() if self._match(TokenType.ELSE) else None
self._match(TokenType.END)
this = self.expression(exp.If, this=condition, true=true, false=false)
+
return self._parse_window(this)
- def _parse_extract(self):
+ def _parse_extract(self) -> exp.Expression:
this = self._parse_function() or self._parse_var() or self._parse_type()
if self._match(TokenType.FROM):
@@ -2337,7 +2556,7 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Extract, this=this, expression=self._parse_bitwise())
- def _parse_cast(self, strict):
+ def _parse_cast(self, strict: bool) -> exp.Expression:
this = self._parse_conjunction()
if not self._match(TokenType.ALIAS):
@@ -2353,7 +2572,9 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
- def _parse_string_agg(self):
+ def _parse_string_agg(self) -> exp.Expression:
+ expression: t.Optional[exp.Expression]
+
if self._match(TokenType.DISTINCT):
args = self._parse_csv(self._parse_conjunction)
expression = self.expression(exp.Distinct, expressions=[seq_get(args, 0)])
@@ -2380,8 +2601,10 @@ class Parser(metaclass=_Parser):
order = self._parse_order(this=expression)
return self.expression(exp.GroupConcat, this=order, separator=seq_get(args, 1))
- def _parse_convert(self, strict):
+ def _parse_convert(self, strict: bool) -> exp.Expression:
+ to: t.Optional[exp.Expression]
this = self._parse_column()
+
if self._match(TokenType.USING):
to = self.expression(exp.CharacterSet, this=self._parse_var())
elif self._match(TokenType.COMMA):
@@ -2390,7 +2613,7 @@ class Parser(metaclass=_Parser):
to = None
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
- def _parse_position(self):
+ def _parse_position(self) -> exp.Expression:
args = self._parse_csv(self._parse_bitwise)
if self._match(TokenType.IN):
@@ -2402,11 +2625,11 @@ class Parser(metaclass=_Parser):
return this
- def _parse_join_hint(self, func_name):
+ def _parse_join_hint(self, func_name: str) -> exp.Expression:
args = self._parse_csv(self._parse_table)
return exp.JoinHint(this=func_name.upper(), expressions=args)
- def _parse_substring(self):
+ def _parse_substring(self) -> exp.Expression:
# Postgres supports the form: substring(string [from int] [for int])
# https://www.postgresql.org/docs/9.1/functions-string.html @ Table 9-6
@@ -2422,7 +2645,7 @@ class Parser(metaclass=_Parser):
return this
- def _parse_trim(self):
+ def _parse_trim(self) -> exp.Expression:
# https://www.w3resource.com/sql/character-functions/trim.php
# https://docs.oracle.com/javadb/10.8.3.0/ref/rreftrimfunc.html
@@ -2450,13 +2673,15 @@ class Parser(metaclass=_Parser):
collation=collation,
)
- def _parse_window_clause(self):
+ def _parse_window_clause(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
return self._match(TokenType.WINDOW) and self._parse_csv(self._parse_named_window)
- def _parse_named_window(self):
+ def _parse_named_window(self) -> t.Optional[exp.Expression]:
return self._parse_window(self._parse_id_var(), alias=True)
- def _parse_window(self, this, alias=False):
+ def _parse_window(
+ self, this: t.Optional[exp.Expression], alias: bool = False
+ ) -> t.Optional[exp.Expression]:
if self._match(TokenType.FILTER):
where = self._parse_wrapped(self._parse_where)
this = self.expression(exp.Filter, this=this, expression=where)
@@ -2495,7 +2720,7 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.L_PAREN):
return self.expression(exp.Window, this=this, alias=self._parse_id_var(False))
- alias = self._parse_id_var(False)
+ window_alias = self._parse_id_var(any_token=False, tokens=self.WINDOW_ALIAS_TOKENS)
partition = None
if self._match(TokenType.PARTITION_BY):
@@ -2529,10 +2754,10 @@ class Parser(metaclass=_Parser):
partition_by=partition,
order=order,
spec=spec,
- alias=alias,
+ alias=window_alias,
)
- def _parse_window_spec(self):
+ def _parse_window_spec(self) -> t.Dict[str, t.Optional[str | exp.Expression]]:
self._match(TokenType.BETWEEN)
return {
@@ -2543,7 +2768,9 @@ class Parser(metaclass=_Parser):
"side": self._match_set((TokenType.PRECEDING, TokenType.FOLLOWING)) and self._prev.text,
}
- def _parse_alias(self, this, explicit=False):
+ def _parse_alias(
+ self, this: t.Optional[exp.Expression], explicit: bool = False
+ ) -> t.Optional[exp.Expression]:
any_token = self._match(TokenType.ALIAS)
if explicit and not any_token:
@@ -2565,63 +2792,74 @@ class Parser(metaclass=_Parser):
return this
- def _parse_id_var(self, any_token=True, tokens=None):
+ def _parse_id_var(
+ self,
+ any_token: bool = True,
+ tokens: t.Optional[t.Collection[TokenType]] = None,
+ prefix_tokens: t.Optional[t.Collection[TokenType]] = None,
+ ) -> t.Optional[exp.Expression]:
identifier = self._parse_identifier()
if identifier:
return identifier
+ prefix = ""
+
+ if prefix_tokens:
+ while self._match_set(prefix_tokens):
+ prefix += self._prev.text
+
if (any_token and self._advance_any()) or self._match_set(tokens or self.ID_VAR_TOKENS):
- return exp.Identifier(this=self._prev.text, quoted=False)
+ return exp.Identifier(this=prefix + self._prev.text, quoted=False)
return None
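
The prefix_tokens hook lets a dialect glue leading punctuation onto an identifier; assuming T-SQL wires TokenType.HASH into TABLE_PREFIX_TOKENS for temp tables, #temp survives a round trip:

    import sqlglot

    print(sqlglot.parse_one("SELECT * FROM #temp", read="tsql").sql("tsql"))
    # SELECT * FROM #temp
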
- def _parse_string(self):
+ def _parse_string(self) -> t.Optional[exp.Expression]:
if self._match(TokenType.STRING):
return self.PRIMARY_PARSERS[TokenType.STRING](self, self._prev)
return self._parse_placeholder()
- def _parse_number(self):
+ def _parse_number(self) -> t.Optional[exp.Expression]:
if self._match(TokenType.NUMBER):
return self.PRIMARY_PARSERS[TokenType.NUMBER](self, self._prev)
return self._parse_placeholder()
- def _parse_identifier(self):
+ def _parse_identifier(self) -> t.Optional[exp.Expression]:
if self._match(TokenType.IDENTIFIER):
return self.expression(exp.Identifier, this=self._prev.text, quoted=True)
return self._parse_placeholder()
- def _parse_var(self, any_token=False):
+ def _parse_var(self, any_token: bool = False) -> t.Optional[exp.Expression]:
if (any_token and self._advance_any()) or self._match(TokenType.VAR):
return self.expression(exp.Var, this=self._prev.text)
return self._parse_placeholder()
- def _advance_any(self):
+ def _advance_any(self) -> t.Optional[Token]:
if self._curr and self._curr.token_type not in self.RESERVED_KEYWORDS:
self._advance()
return self._prev
return None
- def _parse_var_or_string(self):
+ def _parse_var_or_string(self) -> t.Optional[exp.Expression]:
return self._parse_var() or self._parse_string()
- def _parse_null(self):
+ def _parse_null(self) -> t.Optional[exp.Expression]:
if self._match(TokenType.NULL):
return self.PRIMARY_PARSERS[TokenType.NULL](self, self._prev)
return None
- def _parse_boolean(self):
+ def _parse_boolean(self) -> t.Optional[exp.Expression]:
if self._match(TokenType.TRUE):
return self.PRIMARY_PARSERS[TokenType.TRUE](self, self._prev)
if self._match(TokenType.FALSE):
return self.PRIMARY_PARSERS[TokenType.FALSE](self, self._prev)
return None
- def _parse_star(self):
+ def _parse_star(self) -> t.Optional[exp.Expression]:
if self._match(TokenType.STAR):
return self.PRIMARY_PARSERS[TokenType.STAR](self, self._prev)
return None
- def _parse_placeholder(self):
+ def _parse_placeholder(self) -> t.Optional[exp.Expression]:
if self._match(TokenType.PLACEHOLDER):
return self.expression(exp.Placeholder)
elif self._match(TokenType.COLON):
@@ -2630,18 +2868,20 @@ class Parser(metaclass=_Parser):
self._advance(-1)
return None
- def _parse_except(self):
+ def _parse_except(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
if not self._match(TokenType.EXCEPT):
return None
return self._parse_wrapped_id_vars()
- def _parse_replace(self):
+ def _parse_replace(self) -> t.Optional[t.List[t.Optional[exp.Expression]]]:
if not self._match(TokenType.REPLACE):
return None
return self._parse_wrapped_csv(lambda: self._parse_alias(self._parse_expression()))
- def _parse_csv(self, parse_method, sep=TokenType.COMMA):
+ def _parse_csv(
+ self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA
+ ) -> t.List[t.Optional[exp.Expression]]:
parse_result = parse_method()
items = [parse_result] if parse_result is not None else []
@@ -2655,7 +2895,9 @@ class Parser(metaclass=_Parser):
return items
- def _parse_tokens(self, parse_method, expressions):
+ def _parse_tokens(
+ self, parse_method: t.Callable, expressions: t.Dict
+ ) -> t.Optional[exp.Expression]:
this = parse_method()
while self._match_set(expressions):
@@ -2668,22 +2910,29 @@ class Parser(metaclass=_Parser):
return this
- def _parse_wrapped_id_vars(self):
+ def _parse_wrapped_id_vars(self) -> t.List[t.Optional[exp.Expression]]:
return self._parse_wrapped_csv(self._parse_id_var)
- def _parse_wrapped_csv(self, parse_method, sep=TokenType.COMMA):
+ def _parse_wrapped_csv(
+ self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA
+ ) -> t.List[t.Optional[exp.Expression]]:
return self._parse_wrapped(lambda: self._parse_csv(parse_method, sep=sep))
- def _parse_wrapped(self, parse_method):
+ def _parse_wrapped(self, parse_method: t.Callable) -> t.Any:
self._match_l_paren()
parse_result = parse_method()
self._match_r_paren()
return parse_result
- def _parse_select_or_expression(self):
+ def _parse_select_or_expression(self) -> t.Optional[exp.Expression]:
return self._parse_select() or self._parse_expression()
- def _parse_transaction(self):
+ def _parse_ddl_select(self) -> t.Optional[exp.Expression]:
+ return self._parse_set_operations(
+ self._parse_select(nested=True, parse_subquery_alias=False)
+ )
+
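
Routing DDL selects through _parse_set_operations means a CTAS body containing a union now parses as one statement:

    import sqlglot

    print(sqlglot.parse_one("CREATE TABLE t AS SELECT 1 UNION SELECT 2").sql())
    # CREATE TABLE t AS SELECT 1 UNION SELECT 2
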
+ def _parse_transaction(self) -> exp.Expression:
this = None
if self._match_texts(self.TRANSACTION_KIND):
this = self._prev.text
@@ -2703,7 +2952,7 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Transaction, this=this, modes=modes)
- def _parse_commit_or_rollback(self):
+ def _parse_commit_or_rollback(self) -> exp.Expression:
chain = None
savepoint = None
is_rollback = self._prev.token_type == TokenType.ROLLBACK
@@ -2722,27 +2971,30 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Rollback, savepoint=savepoint)
return self.expression(exp.Commit, chain=chain)
- def _parse_add_column(self):
+ def _parse_add_column(self) -> t.Optional[exp.Expression]:
if not self._match_text_seq("ADD"):
return None
self._match(TokenType.COLUMN)
exists_column = self._parse_exists(not_=True)
expression = self._parse_column_def(self._parse_field(any_token=True))
- expression.set("exists", exists_column)
+
+ if expression:
+ expression.set("exists", exists_column)
+
return expression
- def _parse_drop_column(self):
+ def _parse_drop_column(self) -> t.Optional[exp.Expression]:
return self._match(TokenType.DROP) and self._parse_drop(default_kind="COLUMN")
- def _parse_alter(self):
+ def _parse_alter(self) -> t.Optional[exp.Expression]:
if not self._match(TokenType.TABLE):
return None
exists = self._parse_exists()
this = self._parse_table(schema=True)
- actions = None
+ actions: t.Optional[exp.Expression | t.List[t.Optional[exp.Expression]]] = None
if self._match_text_seq("ADD", advance=False):
actions = self._parse_csv(self._parse_add_column)
elif self._match_text_seq("DROP", advance=False):
@@ -2770,24 +3022,24 @@ class Parser(metaclass=_Parser):
actions = ensure_list(actions)
return self.expression(exp.AlterTable, this=this, exists=exists, actions=actions)
- def _parse_show(self):
- parser = self._find_parser(self.SHOW_PARSERS, self._show_trie)
+ def _parse_show(self) -> t.Optional[exp.Expression]:
+ parser = self._find_parser(self.SHOW_PARSERS, self._show_trie) # type: ignore
if parser:
return parser(self)
self._advance()
return self.expression(exp.Show, this=self._prev.text.upper())
- def _default_parse_set_item(self):
+ def _default_parse_set_item(self) -> exp.Expression:
return self.expression(
exp.SetItem,
this=self._parse_statement(),
)
- def _parse_set_item(self):
- parser = self._find_parser(self.SET_PARSERS, self._set_trie)
+ def _parse_set_item(self) -> t.Optional[exp.Expression]:
+ parser = self._find_parser(self.SET_PARSERS, self._set_trie) # type: ignore
return parser(self) if parser else self._default_parse_set_item()
- def _parse_merge(self):
+ def _parse_merge(self) -> exp.Expression:
self._match(TokenType.INTO)
target = self._parse_table(schema=True)
@@ -2835,10 +3087,12 @@ class Parser(metaclass=_Parser):
expressions=whens,
)
- def _parse_set(self):
+ def _parse_set(self) -> exp.Expression:
return self.expression(exp.Set, expressions=self._parse_csv(self._parse_set_item))
- def _find_parser(self, parsers, trie):
+ def _find_parser(
+ self, parsers: t.Dict[str, t.Callable], trie: t.Dict
+ ) -> t.Optional[t.Callable]:
index = self._index
this = []
while True:
diff --git a/sqlglot/schema.py b/sqlglot/schema.py
index d9a4004..a0d69a7 100644
--- a/sqlglot/schema.py
+++ b/sqlglot/schema.py
@@ -3,6 +3,7 @@ from __future__ import annotations
import abc
import typing as t
+import sqlglot
from sqlglot import expressions as exp
from sqlglot.errors import SchemaError
from sqlglot.helper import dict_depth
@@ -157,10 +158,10 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
visible: t.Optional[t.Dict] = None,
dialect: t.Optional[str] = None,
) -> None:
- super().__init__(schema)
- self.visible = visible or {}
self.dialect = dialect
+ self.visible = visible or {}
self._type_mapping_cache: t.Dict[str, exp.DataType] = {}
+ super().__init__(self._normalize(schema or {}))
@classmethod
def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
@@ -180,6 +181,33 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
}
)
+ def _normalize(self, schema: t.Dict) -> t.Dict:
+ """
+ Converts all identifiers in the schema into lowercase, unless they're quoted.
+
+ Args:
+ schema: the schema to normalize.
+
+ Returns:
+ The normalized schema mapping.
+ """
+ flattened_schema = flatten_schema(schema, depth=dict_depth(schema) - 1)
+
+ normalized_mapping: t.Dict = {}
+ for keys in flattened_schema:
+ columns = _nested_get(schema, *zip(keys, keys))
+ assert columns is not None
+
+ normalized_keys = [self._normalize_name(key) for key in keys]
+ for column_name, column_type in columns.items():
+ _nested_set(
+ normalized_mapping,
+ normalized_keys + [self._normalize_name(column_name)],
+ column_type,
+ )
+
+ return normalized_mapping
+
def add_table(
self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
) -> None:
@@ -204,6 +232,19 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
)
self.mapping_trie = self._build_trie(self.mapping)
+ def _normalize_name(self, name: str) -> str:
+ try:
+ identifier: t.Optional[exp.Expression] = sqlglot.parse_one(
+ name, read=self.dialect, into=exp.Identifier
+ )
+        except Exception:
+ identifier = exp.to_identifier(name)
+ assert isinstance(identifier, exp.Identifier)
+
+ if identifier.quoted:
+ return identifier.name
+ return identifier.name.lower()
+
def _depth(self) -> int:
# The columns themselves are a mapping, but we don't want to include those
return super()._depth() - 1
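
A sketch of the normalization above (table and column names are illustrative): unquoted identifiers are lowercased, quoted ones keep their case:

    from sqlglot.schema import MappingSchema

    schema = MappingSchema({'"Tbl"': {"ID": "INT", '"Name"': "TEXT"}})
    print(schema.mapping)
    # {'Tbl': {'id': 'INT', 'Name': 'TEXT'}}
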
diff --git a/sqlglot/serde.py b/sqlglot/serde.py
new file mode 100644
index 0000000..a47ffdb
--- /dev/null
+++ b/sqlglot/serde.py
@@ -0,0 +1,67 @@
+from __future__ import annotations
+
+import typing as t
+
+from sqlglot import expressions as exp
+
+if t.TYPE_CHECKING:
+ JSON = t.Union[dict, list, str, float, int, bool]
+ Node = t.Union[t.List["Node"], exp.DataType.Type, exp.Expression, JSON]
+
+
+def dump(node: Node) -> JSON:
+ """
+ Recursively dump an AST into a JSON-serializable dict.
+ """
+ if isinstance(node, list):
+ return [dump(i) for i in node]
+ if isinstance(node, exp.DataType.Type):
+ return {
+ "class": "DataType.Type",
+ "value": node.value,
+ }
+ if isinstance(node, exp.Expression):
+ klass = node.__class__.__qualname__
+ if node.__class__.__module__ != exp.__name__:
+ klass = f"{node.__module__}.{klass}"
+ obj = {
+ "class": klass,
+ "args": {k: dump(v) for k, v in node.args.items() if v is not None and v != []},
+ }
+ if node.type:
+ obj["type"] = node.type.sql()
+ if node.comments:
+ obj["comments"] = node.comments
+ return obj
+ return node
+
+
+def load(obj: JSON) -> Node:
+ """
+ Recursively load a dict (as returned by `dump`) into an AST.
+ """
+ if isinstance(obj, list):
+ return [load(i) for i in obj]
+ if isinstance(obj, dict):
+ class_name = obj["class"]
+
+ if class_name == "DataType.Type":
+ return exp.DataType.Type(obj["value"])
+
+ if "." in class_name:
+ module_path, class_name = class_name.rsplit(".", maxsplit=1)
+ module = __import__(module_path, fromlist=[class_name])
+ else:
+ module = exp
+
+ klass = getattr(module, class_name)
+
+ expression = klass(**{k: load(v) for k, v in obj["args"].items()})
+ type_ = obj.get("type")
+ if type_:
+ expression.type = exp.DataType.build(type_)
+ comments = obj.get("comments")
+ if comments:
+ expression.comments = load(comments)
+ return expression
+ return obj
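
A quick round-trip sketch for the new module, using the dump and load functions defined above:

    import json
    import sqlglot
    from sqlglot import serde

    ast = sqlglot.parse_one("SELECT a + 1 AS b FROM t")
    payload = serde.dump(ast)  # plain dicts/lists, safe for json.dumps
    json.dumps(payload)
    assert serde.load(payload) == ast
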
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index 0efa7d0..8e312a7 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -86,6 +86,7 @@ class TokenType(AutoName):
VARBINARY = auto()
JSON = auto()
JSONB = auto()
+ TIME = auto()
TIMESTAMP = auto()
TIMESTAMPTZ = auto()
TIMESTAMPLTZ = auto()
@@ -181,6 +182,7 @@ class TokenType(AutoName):
FUNCTION = auto()
FROM = auto()
GENERATED = auto()
+ GLOBAL = auto()
GROUP_BY = auto()
GROUPING_SETS = auto()
HAVING = auto()
@@ -656,6 +658,7 @@ class Tokenizer(metaclass=_Tokenizer):
"FLOAT4": TokenType.FLOAT,
"FLOAT8": TokenType.DOUBLE,
"DOUBLE": TokenType.DOUBLE,
+ "DOUBLE PRECISION": TokenType.DOUBLE,
"JSON": TokenType.JSON,
"CHAR": TokenType.CHAR,
"NCHAR": TokenType.NCHAR,
@@ -671,6 +674,7 @@ class Tokenizer(metaclass=_Tokenizer):
"BLOB": TokenType.VARBINARY,
"BYTEA": TokenType.VARBINARY,
"VARBINARY": TokenType.VARBINARY,
+ "TIME": TokenType.TIME,
"TIMESTAMP": TokenType.TIMESTAMP,
"TIMESTAMPTZ": TokenType.TIMESTAMPTZ,
"TIMESTAMPLTZ": TokenType.TIMESTAMPLTZ,
@@ -721,6 +725,8 @@ class Tokenizer(metaclass=_Tokenizer):
COMMENTS = ["--", ("/*", "*/")]
KEYWORD_TRIE = None # autofilled
+ IDENTIFIER_CAN_START_WITH_DIGIT = False
+
__slots__ = (
"sql",
"size",
@@ -938,17 +944,24 @@ class Tokenizer(metaclass=_Tokenizer):
elif self._peek.upper() == "E" and not scientific: # type: ignore
scientific += 1
self._advance()
- elif self._peek.isalpha(): # type: ignore
- self._add(TokenType.NUMBER)
+ elif self._peek.isidentifier(): # type: ignore
+ number_text = self._text
literal = []
- while self._peek.isalpha(): # type: ignore
+ while self._peek.isidentifier(): # type: ignore
literal.append(self._peek.upper()) # type: ignore
self._advance()
+
literal = "".join(literal) # type: ignore
token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal)) # type: ignore
+
if token_type:
+ self._add(TokenType.NUMBER, number_text)
self._add(TokenType.DCOLON, "::")
return self._add(token_type, literal) # type: ignore
+ elif self.IDENTIFIER_CAN_START_WITH_DIGIT:
+ return self._add(TokenType.VAR)
+
+ self._add(TokenType.NUMBER, number_text)
return self._advance(-len(literal))
else:
return self._add(TokenType.NUMBER)
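
The reworked branch emits the number plus a synthetic :: cast for suffixed numeric literals (Hive's NUMERIC_LITERALS maps L to BIGINT, for example), and falls back to a VAR token for dialects whose identifiers may start with a digit:

    import sqlglot

    print(sqlglot.transpile("SELECT 10L", read="hive")[0])
    # SELECT CAST(10 AS BIGINT)
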
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py
index 99949a1..35ff75a 100644
--- a/sqlglot/transforms.py
+++ b/sqlglot/transforms.py
@@ -82,6 +82,27 @@ def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
return expression
+def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
+ """
+    Some dialects only allow the precision of parameterized types to be defined in DDL contexts.
+    This transform removes the precision from parameterized types in other expressions.
+ """
+ return expression.transform(
+ lambda node: exp.DataType(
+ **{
+ **node.args,
+ "expressions": [
+ node_expression
+ for node_expression in node.expressions
+ if isinstance(node_expression, exp.DataType)
+ ],
+ }
+ )
+ if isinstance(node, exp.DataType)
+ else node,
+ )
+
+
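
Illustrative use of the transform on a non-DDL expression:

    import sqlglot
    from sqlglot import transforms

    ast = sqlglot.parse_one("SELECT CAST(x AS DECIMAL(10, 2)) FROM t")
    print(transforms.remove_precision_parameterized_types(ast).sql())
    # SELECT CAST(x AS DECIMAL) FROM t
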
def preprocess(
transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
to_sql: t.Callable[[Generator, exp.Expression], str],
@@ -121,3 +142,6 @@ def delegate(attr: str) -> t.Callable:
UNALIAS_GROUP = {exp.Group: preprocess([unalias_group], delegate("group_sql"))}
ELIMINATE_DISTINCT_ON = {exp.Select: preprocess([eliminate_distinct_on], delegate("select_sql"))}
+REMOVE_PRECISION_PARAMETERIZED_TYPES = {
+ exp.Cast: preprocess([remove_precision_parameterized_types], delegate("cast_sql"))
+}
diff --git a/sqlglot/trie.py b/sqlglot/trie.py
index fa2aaf1..f3b1c38 100644
--- a/sqlglot/trie.py
+++ b/sqlglot/trie.py
@@ -52,7 +52,7 @@ def in_trie(trie: t.Dict, key: key) -> t.Tuple[int, t.Dict]:
Returns:
A pair `(value, subtrie)`, where `subtrie` is the sub-trie we get at the point where the search stops, and `value`
- is either 0 (search was unsuccessfull), 1 (`value` is a prefix of a keyword in `trie`) or 2 (`key is in `trie`).
+        is either 0 (search was unsuccessful), 1 (`value` is a prefix of a keyword in `trie`) or 2 (`key` is in `trie`).
"""
if not key:
return (0, trie)