author     Daniel Baumann <daniel.baumann@progress-linux.org>   2022-12-12 15:42:38 +0000
committer  Daniel Baumann <daniel.baumann@progress-linux.org>   2022-12-12 15:42:38 +0000
commit     bea2635be022e272ddac349f5e396ec901fc37e5 (patch)
tree       24dbe11c9d462ff55f9b3af4b4da4cd1ae02e8a3 /sqlglot
parent     Releasing debian version 10.1.3-1. (diff)
download   sqlglot-bea2635be022e272ddac349f5e396ec901fc37e5.tar.xz
           sqlglot-bea2635be022e272ddac349f5e396ec901fc37e5.zip
Merging upstream version 10.2.6.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
-rw-r--r--  sqlglot/__init__.py                  |   2
-rw-r--r--  sqlglot/dataframe/sql/dataframe.py   |   2
-rw-r--r--  sqlglot/dialects/bigquery.py         |  33
-rw-r--r--  sqlglot/dialects/hive.py             |  15
-rw-r--r--  sqlglot/dialects/oracle.py           |   1
-rw-r--r--  sqlglot/dialects/redshift.py         |  10
-rw-r--r--  sqlglot/dialects/snowflake.py        |   1
-rw-r--r--  sqlglot/executor/env.py              |   9
-rw-r--r--  sqlglot/executor/python.py           |   4
-rw-r--r--  sqlglot/expressions.py               | 106
-rw-r--r--  sqlglot/generator.py                 | 452
-rw-r--r--  sqlglot/helper.py                    |   8
-rw-r--r--  sqlglot/optimizer/annotate_types.py  |  37
-rw-r--r--  sqlglot/optimizer/canonicalize.py    |  25
-rw-r--r--  sqlglot/optimizer/simplify.py        | 235
-rw-r--r--  sqlglot/parser.py                    | 136
-rw-r--r--  sqlglot/schema.py                    |  22
-rw-r--r--  sqlglot/tokens.py                    |  19
18 files changed, 747 insertions, 370 deletions
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py
index b027ac7..3733b20 100644
--- a/sqlglot/__init__.py
+++ b/sqlglot/__init__.py
@@ -30,7 +30,7 @@ from sqlglot.parser import Parser
from sqlglot.schema import MappingSchema
from sqlglot.tokens import Tokenizer, TokenType
-__version__ = "10.1.3"
+__version__ = "10.2.6"
pretty = False
diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py
index 548c322..3c45741 100644
--- a/sqlglot/dataframe/sql/dataframe.py
+++ b/sqlglot/dataframe/sql/dataframe.py
@@ -317,7 +317,7 @@ class DataFrame:
sqlglot.schema.add_table(
cache_table_name,
{
- expression.alias_or_name: expression.type.name
+ expression.alias_or_name: expression.type.sql("spark")
for expression in select_expression.expressions
},
)
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 5b44912..6be68ac 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -110,17 +110,17 @@ class BigQuery(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
+ "BEGIN": TokenType.COMMAND,
+ "BEGIN TRANSACTION": TokenType.BEGIN,
"CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
"CURRENT_TIME": TokenType.CURRENT_TIME,
"GEOGRAPHY": TokenType.GEOGRAPHY,
- "INT64": TokenType.BIGINT,
"FLOAT64": TokenType.DOUBLE,
+ "INT64": TokenType.BIGINT,
+ "NOT DETERMINISTIC": TokenType.VOLATILE,
"QUALIFY": TokenType.QUALIFY,
"UNKNOWN": TokenType.NULL,
"WINDOW": TokenType.WINDOW,
- "NOT DETERMINISTIC": TokenType.VOLATILE,
- "BEGIN": TokenType.COMMAND,
- "BEGIN TRANSACTION": TokenType.BEGIN,
}
KEYWORDS.pop("DIV")
@@ -131,6 +131,7 @@ class BigQuery(Dialect):
"DATE_ADD": _date_add(exp.DateAdd),
"DATETIME_ADD": _date_add(exp.DatetimeAdd),
"DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)),
+ "REGEXP_CONTAINS": exp.RegexpLike.from_arg_list,
"TIME_ADD": _date_add(exp.TimeAdd),
"TIMESTAMP_ADD": _date_add(exp.TimestampAdd),
"DATE_SUB": _date_add(exp.DateSub),
@@ -144,6 +145,7 @@ class BigQuery(Dialect):
FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS,
+ "ARRAY": lambda self: self.expression(exp.Array, expressions=[self._parse_statement()]),
}
FUNCTION_PARSERS.pop("TRIM")
@@ -161,7 +163,6 @@ class BigQuery(Dialect):
class Generator(generator.Generator):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
- exp.Array: inline_array_sql,
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.DateAdd: _date_add_sql("DATE", "ADD"),
exp.DateSub: _date_add_sql("DATE", "SUB"),
@@ -183,6 +184,7 @@ class BigQuery(Dialect):
exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC"
if e.name == "IMMUTABLE"
else "NOT DETERMINISTIC",
+ exp.RegexpLike: rename_func("REGEXP_CONTAINS"),
}
TYPE_MAPPING = {
@@ -210,24 +212,31 @@ class BigQuery(Dialect):
EXPLICIT_UNION = True
- def transaction_sql(self, *_):
+ def array_sql(self, expression: exp.Array) -> str:
+ first_arg = seq_get(expression.expressions, 0)
+ if isinstance(first_arg, exp.Subqueryable):
+ return f"ARRAY{self.wrap(self.sql(first_arg))}"
+
+ return inline_array_sql(self, expression)
+
+ def transaction_sql(self, *_) -> str:
return "BEGIN TRANSACTION"
- def commit_sql(self, *_):
+ def commit_sql(self, *_) -> str:
return "COMMIT TRANSACTION"
- def rollback_sql(self, *_):
+ def rollback_sql(self, *_) -> str:
return "ROLLBACK TRANSACTION"
- def in_unnest_op(self, unnest):
- return self.sql(unnest)
+ def in_unnest_op(self, expression: exp.Unnest) -> str:
+ return self.sql(expression)
- def except_op(self, expression):
+ def except_op(self, expression: exp.Except) -> str:
if not expression.args.get("distinct", False):
self.unsupported("EXCEPT without DISTINCT is not supported in BigQuery")
return f"EXCEPT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"
- def intersect_op(self, expression):
+ def intersect_op(self, expression: exp.Intersect) -> str:
if not expression.args.get("distinct", False):
self.unsupported("INTERSECT without DISTINCT is not supported in BigQuery")
return f"INTERSECT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index cbb39c2..70c1c6c 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -190,6 +190,7 @@ class Hive(Dialect):
"ADD FILES": TokenType.COMMAND,
"ADD JAR": TokenType.COMMAND,
"ADD JARS": TokenType.COMMAND,
+ "WITH SERDEPROPERTIES": TokenType.SERDE_PROPERTIES,
}
class Parser(parser.Parser):
@@ -238,6 +239,13 @@ class Hive(Dialect):
"YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)),
}
+ PROPERTY_PARSERS = {
+ **parser.Parser.PROPERTY_PARSERS,
+ TokenType.SERDE_PROPERTIES: lambda self: exp.SerdeProperties(
+ expressions=self._parse_wrapped_csv(self._parse_property)
+ ),
+ }
+
class Generator(generator.Generator):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@@ -297,6 +305,8 @@ class Hive(Dialect):
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"),
exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'this')}",
+ exp.RowFormatSerdeProperty: lambda self, e: f"ROW FORMAT SERDE {self.sql(e, 'this')}",
+ exp.SerdeProperties: lambda self, e: self.properties(e, prefix="WITH SERDEPROPERTIES"),
exp.NumberToStr: rename_func("FORMAT_NUMBER"),
}
@@ -308,12 +318,15 @@ class Hive(Dialect):
exp.SchemaCommentProperty,
exp.LocationProperty,
exp.TableFormatProperty,
+ exp.RowFormatDelimitedProperty,
+ exp.RowFormatSerdeProperty,
+ exp.SerdeProperties,
}
def with_properties(self, properties):
return self.properties(
properties,
- prefix="TBLPROPERTIES",
+ prefix=self.seg("TBLPROPERTIES"),
)
def datatype_sql(self, expression):
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index ceaf9ba..f507513 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -98,6 +98,7 @@ class Oracle(Dialect):
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
+ "MINUS": TokenType.EXCEPT,
"START": TokenType.BEGIN,
"TOP": TokenType.TOP,
"VARCHAR2": TokenType.VARCHAR,
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index cd50979..55ed0a6 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -1,6 +1,7 @@
from __future__ import annotations
from sqlglot import exp, transforms
+from sqlglot.dialects.dialect import rename_func
from sqlglot.dialects.postgres import Postgres
from sqlglot.tokens import TokenType
@@ -13,12 +14,20 @@ class Redshift(Postgres):
"HH": "%H",
}
+ class Parser(Postgres.Parser):
+ FUNCTIONS = {
+ **Postgres.Parser.FUNCTIONS, # type: ignore
+ "DECODE": exp.Matches.from_arg_list,
+ "NVL": exp.Coalesce.from_arg_list,
+ }
+
class Tokenizer(Postgres.Tokenizer):
ESCAPES = ["\\"]
KEYWORDS = {
**Postgres.Tokenizer.KEYWORDS, # type: ignore
"COPY": TokenType.COMMAND,
+ "ENCODE": TokenType.ENCODE,
"GEOMETRY": TokenType.GEOMETRY,
"GEOGRAPHY": TokenType.GEOGRAPHY,
"HLLSKETCH": TokenType.HLLSKETCH,
@@ -50,4 +59,5 @@ class Redshift(Postgres):
exp.DistKeyProperty: lambda self, e: f"DISTKEY({e.name})",
exp.SortKeyProperty: lambda self, e: f"{'COMPOUND ' if e.args['compound'] else ''}SORTKEY({self.format_args(*e.this)})",
exp.DistStyleProperty: lambda self, e: self.naked_property(e),
+ exp.Matches: rename_func("DECODE"),
}
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 46155ff..75dc9dc 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -198,6 +198,7 @@ class Snowflake(Dialect):
"TIMESTAMP_NTZ": TokenType.TIMESTAMP,
"TIMESTAMP_TZ": TokenType.TIMESTAMPTZ,
"TIMESTAMPNTZ": TokenType.TIMESTAMP,
+ "MINUS": TokenType.EXCEPT,
"SAMPLE": TokenType.TABLE_SAMPLE,
}
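Both the Oracle and Snowflake tokenizers now map MINUS to EXCEPT, so the set operation survives transpilation to dialects that only spell it EXCEPT. A small sketch of the expected behavior (illustrative example, not part of the patch):

    import sqlglot

    # MINUS is read as an EXCEPT set operation
    print(sqlglot.transpile("SELECT a FROM x MINUS SELECT a FROM y", read="oracle"))
    # expected: ['SELECT a FROM x EXCEPT SELECT a FROM y']
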
diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py
index e6cfcdd..ad9397e 100644
--- a/sqlglot/executor/env.py
+++ b/sqlglot/executor/env.py
@@ -19,10 +19,13 @@ class reverse_key:
return other.obj < self.obj
-def filter_nulls(func):
+def filter_nulls(func, empty_null=True):
@wraps(func)
def _func(values):
- return func(v for v in values if v is not None)
+ filtered = tuple(v for v in values if v is not None)
+ if not filtered and empty_null:
+ return None
+ return func(filtered)
return _func
@@ -126,7 +129,7 @@ ENV = {
# aggs
"SUM": filter_nulls(sum),
"AVG": filter_nulls(statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean), # type: ignore
- "COUNT": filter_nulls(lambda acc: sum(1 for _ in acc)),
+ "COUNT": filter_nulls(lambda acc: sum(1 for _ in acc), False),
"MAX": filter_nulls(max),
"MIN": filter_nulls(min),
# scalar functions
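The filter_nulls change means an aggregate over an all-NULL input now yields NULL, while COUNT (built with empty_null=False) keeps returning 0. A standalone sketch of that behavior, mirroring the decorator above:

    from functools import wraps

    def filter_nulls(func, empty_null=True):
        @wraps(func)
        def _func(values):
            filtered = tuple(v for v in values if v is not None)
            if not filtered and empty_null:
                return None
            return func(filtered)
        return _func

    assert filter_nulls(sum)([None, None]) is None                             # SUM over all NULLs -> NULL
    assert filter_nulls(lambda acc: sum(1 for _ in acc), False)([None]) == 0   # COUNT over all NULLs -> 0
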
diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py
index 908b80a..9f22c45 100644
--- a/sqlglot/executor/python.py
+++ b/sqlglot/executor/python.py
@@ -310,9 +310,9 @@ class PythonExecutor:
if i == length - 1:
context.set_range(start, end - 1)
add_row()
- elif step.limit > 0:
+ elif step.limit > 0 and not group_by:
context.set_range(0, 0)
- table.append(context.eval_tuple(group_by) + context.eval_tuple(aggregations))
+ table.append(context.eval_tuple(aggregations))
context = self.context({step.name: table, **{name: table for name in context.tables}})
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 96b32f1..7249574 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -43,14 +43,14 @@ class Expression(metaclass=_Expression):
key = "Expression"
arg_types = {"this": True}
- __slots__ = ("args", "parent", "arg_key", "type", "comments")
+ __slots__ = ("args", "parent", "arg_key", "comments", "_type")
def __init__(self, **args):
self.args = args
self.parent = None
self.arg_key = None
- self.type = None
self.comments = None
+ self._type: t.Optional[DataType] = None
for arg_key, value in self.args.items():
self._set_parent(arg_key, value)
@@ -122,6 +122,16 @@ class Expression(metaclass=_Expression):
return "NULL"
return self.alias or self.name
+ @property
+ def type(self) -> t.Optional[DataType]:
+ return self._type
+
+ @type.setter
+ def type(self, dtype: t.Optional[DataType | DataType.Type | str]) -> None:
+ if dtype and not isinstance(dtype, DataType):
+ dtype = DataType.build(dtype)
+ self._type = dtype # type: ignore
+
def __deepcopy__(self, memo):
copy = self.__class__(**deepcopy(self.args))
copy.comments = self.comments
@@ -348,7 +358,7 @@ class Expression(metaclass=_Expression):
indent += "".join([" "] * level)
left = f"({self.key.upper()} "
- args = {
+ args: t.Dict[str, t.Any] = {
k: ", ".join(
v.to_s(hide_missing=hide_missing, level=level + 1) if hasattr(v, "to_s") else str(v)
for v in ensure_collection(vs)
@@ -612,6 +622,7 @@ class Create(Expression):
"properties": False,
"temporary": False,
"transient": False,
+ "external": False,
"replace": False,
"unique": False,
"materialized": False,
@@ -744,13 +755,17 @@ class DefaultColumnConstraint(ColumnConstraintKind):
pass
+class EncodeColumnConstraint(ColumnConstraintKind):
+ pass
+
+
class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind):
# this: True -> ALWAYS, this: False -> BY DEFAULT
arg_types = {"this": True, "expression": False}
class NotNullColumnConstraint(ColumnConstraintKind):
- pass
+ arg_types = {"allow_null": False}
class PrimaryKeyColumnConstraint(ColumnConstraintKind):
@@ -766,7 +781,7 @@ class Constraint(Expression):
class Delete(Expression):
- arg_types = {"with": False, "this": True, "using": False, "where": False}
+ arg_types = {"with": False, "this": False, "using": False, "where": False}
class Drop(Expression):
@@ -850,7 +865,7 @@ class Insert(Expression):
arg_types = {
"with": False,
"this": True,
- "expression": True,
+ "expression": False,
"overwrite": False,
"exists": False,
"partition": False,
@@ -1125,6 +1140,27 @@ class VolatilityProperty(Property):
arg_types = {"this": True}
+class RowFormatDelimitedProperty(Property):
+ # https://cwiki.apache.org/confluence/display/hive/languagemanual+dml
+ arg_types = {
+ "fields": False,
+ "escaped": False,
+ "collection_items": False,
+ "map_keys": False,
+ "lines": False,
+ "null": False,
+ "serde": False,
+ }
+
+
+class RowFormatSerdeProperty(Property):
+ arg_types = {"this": True}
+
+
+class SerdeProperties(Property):
+ arg_types = {"expressions": True}
+
+
class Properties(Expression):
arg_types = {"expressions": True}
@@ -1169,18 +1205,6 @@ class Reference(Expression):
arg_types = {"this": True, "expressions": True}
-class RowFormat(Expression):
- # https://cwiki.apache.org/confluence/display/hive/languagemanual+dml
- arg_types = {
- "fields": False,
- "escaped": False,
- "collection_items": False,
- "map_keys": False,
- "lines": False,
- "null": False,
- }
-
-
class Tuple(Expression):
arg_types = {"expressions": False}
@@ -1208,6 +1232,9 @@ class Subqueryable(Unionable):
alias=TableAlias(this=to_identifier(alias)),
)
+ def limit(self, expression, dialect=None, copy=True, **opts) -> Select:
+ raise NotImplementedError
+
@property
def ctes(self):
with_ = self.args.get("with")
@@ -1320,6 +1347,32 @@ class Union(Subqueryable):
**QUERY_MODIFIERS,
}
+ def limit(self, expression, dialect=None, copy=True, **opts) -> Select:
+ """
+ Set the LIMIT expression.
+
+ Example:
+ >>> select("1").union(select("1")).limit(1).sql()
+ 'SELECT * FROM (SELECT 1 UNION SELECT 1) AS "_l_0" LIMIT 1'
+
+ Args:
+ expression (str | int | Expression): the SQL code string to parse.
+ This can also be an integer.
+ If a `Limit` instance is passed, this is used as-is.
+ If another `Expression` instance is passed, it will be wrapped in a `Limit`.
+ dialect (str): the dialect used to parse the input expression.
+ copy (bool): if `False`, modify this expression instance in-place.
+ opts (kwargs): other options to use to parse the input expressions.
+
+ Returns:
+ Select: The limited subqueryable.
+ """
+ return (
+ select("*")
+ .from_(self.subquery(alias="_l_0", copy=copy))
+ .limit(expression, dialect=dialect, copy=False, **opts)
+ )
+
@property
def named_selects(self):
return self.this.unnest().named_selects
@@ -1356,7 +1409,7 @@ class Unnest(UDTF):
class Update(Expression):
arg_types = {
"with": False,
- "this": True,
+ "this": False,
"expressions": True,
"from": False,
"where": False,
@@ -2057,15 +2110,20 @@ class DataType(Expression):
Type.TEXT,
}
- NUMERIC_TYPES = {
+ INTEGER_TYPES = {
Type.INT,
Type.TINYINT,
Type.SMALLINT,
Type.BIGINT,
+ }
+
+ FLOAT_TYPES = {
Type.FLOAT,
Type.DOUBLE,
}
+ NUMERIC_TYPES = {*INTEGER_TYPES, *FLOAT_TYPES}
+
TEMPORAL_TYPES = {
Type.TIMESTAMP,
Type.TIMESTAMPTZ,
@@ -2968,6 +3026,14 @@ class Use(Expression):
pass
+class Merge(Expression):
+ arg_types = {"this": True, "using": True, "on": True, "expressions": True}
+
+
+class When(Func):
+ arg_types = {"this": True, "then": True}
+
+
def _norm_args(expression):
args = {}
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index 47774fc..beffb91 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -189,12 +189,12 @@ class Generator:
self._max_text_width = max_text_width
self._comments = comments
- def generate(self, expression):
+ def generate(self, expression: t.Optional[exp.Expression]) -> str:
"""
Generates a SQL string by interpreting the given syntax tree.
Args
- expression (Expression): the syntax tree.
+ expression: the syntax tree.
Returns
the SQL string.
@@ -213,23 +213,23 @@ class Generator:
return sql
- def unsupported(self, message):
+ def unsupported(self, message: str) -> None:
if self.unsupported_level == ErrorLevel.IMMEDIATE:
raise UnsupportedError(message)
self.unsupported_messages.append(message)
- def sep(self, sep=" "):
+ def sep(self, sep: str = " ") -> str:
return f"{sep.strip()}\n" if self.pretty else sep
- def seg(self, sql, sep=" "):
+ def seg(self, sql: str, sep: str = " ") -> str:
return f"{self.sep(sep)}{sql}"
- def pad_comment(self, comment):
+ def pad_comment(self, comment: str) -> str:
comment = " " + comment if comment[0].strip() else comment
comment = comment + " " if comment[-1].strip() else comment
return comment
- def maybe_comment(self, sql, expression):
+ def maybe_comment(self, sql: str, expression: exp.Expression) -> str:
comments = expression.comments if self._comments else None
if not comments:
@@ -243,7 +243,7 @@ class Generator:
return f"{sql} {comments}"
- def wrap(self, expression):
+ def wrap(self, expression: exp.Expression | str) -> str:
this_sql = self.indent(
self.sql(expression)
if isinstance(expression, (exp.Select, exp.Union))
@@ -253,21 +253,28 @@ class Generator:
)
return f"({self.sep('')}{this_sql}{self.seg(')', sep='')}"
- def no_identify(self, func):
+ def no_identify(self, func: t.Callable[[], str]) -> str:
original = self.identify
self.identify = False
result = func()
self.identify = original
return result
- def normalize_func(self, name):
+ def normalize_func(self, name: str) -> str:
if self.normalize_functions == "upper":
return name.upper()
if self.normalize_functions == "lower":
return name.lower()
return name
- def indent(self, sql, level=0, pad=None, skip_first=False, skip_last=False):
+ def indent(
+ self,
+ sql: str,
+ level: int = 0,
+ pad: t.Optional[int] = None,
+ skip_first: bool = False,
+ skip_last: bool = False,
+ ) -> str:
if not self.pretty:
return sql
@@ -281,7 +288,12 @@ class Generator:
for i, line in enumerate(lines)
)
- def sql(self, expression, key=None, comment=True):
+ def sql(
+ self,
+ expression: t.Optional[str | exp.Expression],
+ key: t.Optional[str] = None,
+ comment: bool = True,
+ ) -> str:
if not expression:
return ""
@@ -313,12 +325,12 @@ class Generator:
return self.maybe_comment(sql, expression) if self._comments and comment else sql
- def uncache_sql(self, expression):
+ def uncache_sql(self, expression: exp.Uncache) -> str:
table = self.sql(expression, "this")
exists_sql = " IF EXISTS" if expression.args.get("exists") else ""
return f"UNCACHE TABLE{exists_sql} {table}"
- def cache_sql(self, expression):
+ def cache_sql(self, expression: exp.Cache) -> str:
lazy = " LAZY" if expression.args.get("lazy") else ""
table = self.sql(expression, "this")
options = expression.args.get("options")
@@ -328,13 +340,13 @@ class Generator:
sql = f"CACHE{lazy} TABLE {table}{options}{sql}"
return self.prepend_ctes(expression, sql)
- def characterset_sql(self, expression):
+ def characterset_sql(self, expression: exp.CharacterSet) -> str:
if isinstance(expression.parent, exp.Cast):
return f"CHAR CHARACTER SET {self.sql(expression, 'this')}"
default = "DEFAULT " if expression.args.get("default") else ""
return f"{default}CHARACTER SET={self.sql(expression, 'this')}"
- def column_sql(self, expression):
+ def column_sql(self, expression: exp.Column) -> str:
return ".".join(
part
for part in [
@@ -345,7 +357,7 @@ class Generator:
if part
)
- def columndef_sql(self, expression):
+ def columndef_sql(self, expression: exp.ColumnDef) -> str:
column = self.sql(expression, "this")
kind = self.sql(expression, "kind")
constraints = self.expressions(expression, key="constraints", sep=" ", flat=True)
@@ -354,46 +366,52 @@ class Generator:
return f"{column} {kind}"
return f"{column} {kind} {constraints}"
- def columnconstraint_sql(self, expression):
+ def columnconstraint_sql(self, expression: exp.ColumnConstraint) -> str:
this = self.sql(expression, "this")
kind_sql = self.sql(expression, "kind")
return f"CONSTRAINT {this} {kind_sql}" if this else kind_sql
- def autoincrementcolumnconstraint_sql(self, _):
+ def autoincrementcolumnconstraint_sql(self, _) -> str:
return self.token_sql(TokenType.AUTO_INCREMENT)
- def checkcolumnconstraint_sql(self, expression):
+ def checkcolumnconstraint_sql(self, expression: exp.CheckColumnConstraint) -> str:
this = self.sql(expression, "this")
return f"CHECK ({this})"
- def commentcolumnconstraint_sql(self, expression):
+ def commentcolumnconstraint_sql(self, expression: exp.CommentColumnConstraint) -> str:
comment = self.sql(expression, "this")
return f"COMMENT {comment}"
- def collatecolumnconstraint_sql(self, expression):
+ def collatecolumnconstraint_sql(self, expression: exp.CollateColumnConstraint) -> str:
collate = self.sql(expression, "this")
return f"COLLATE {collate}"
- def defaultcolumnconstraint_sql(self, expression):
+ def encodecolumnconstraint_sql(self, expression: exp.EncodeColumnConstraint) -> str:
+ encode = self.sql(expression, "this")
+ return f"ENCODE {encode}"
+
+ def defaultcolumnconstraint_sql(self, expression: exp.DefaultColumnConstraint) -> str:
default = self.sql(expression, "this")
return f"DEFAULT {default}"
- def generatedasidentitycolumnconstraint_sql(self, expression):
+ def generatedasidentitycolumnconstraint_sql(
+ self, expression: exp.GeneratedAsIdentityColumnConstraint
+ ) -> str:
return f"GENERATED {'ALWAYS' if expression.this else 'BY DEFAULT'} AS IDENTITY"
- def notnullcolumnconstraint_sql(self, _):
- return "NOT NULL"
+ def notnullcolumnconstraint_sql(self, expression: exp.NotNullColumnConstraint) -> str:
+ return f"{'' if expression.args.get('allow_null') else 'NOT '}NULL"
- def primarykeycolumnconstraint_sql(self, expression):
+ def primarykeycolumnconstraint_sql(self, expression: exp.PrimaryKeyColumnConstraint) -> str:
desc = expression.args.get("desc")
if desc is not None:
return f"PRIMARY KEY{' DESC' if desc else ' ASC'}"
return f"PRIMARY KEY"
- def uniquecolumnconstraint_sql(self, _):
+ def uniquecolumnconstraint_sql(self, _) -> str:
return "UNIQUE"
- def create_sql(self, expression):
+ def create_sql(self, expression: exp.Create) -> str:
this = self.sql(expression, "this")
kind = self.sql(expression, "kind").upper()
expression_sql = self.sql(expression, "expression")
@@ -402,47 +420,58 @@ class Generator:
transient = (
" TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else ""
)
+ external = " EXTERNAL" if expression.args.get("external") else ""
replace = " OR REPLACE" if expression.args.get("replace") else ""
exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else ""
unique = " UNIQUE" if expression.args.get("unique") else ""
materialized = " MATERIALIZED" if expression.args.get("materialized") else ""
properties = self.sql(expression, "properties")
- expression_sql = f"CREATE{replace}{temporary}{transient}{unique}{materialized} {kind}{exists_sql} {this}{properties} {expression_sql}"
+ modifiers = "".join(
+ (
+ replace,
+ temporary,
+ transient,
+ external,
+ unique,
+ materialized,
+ )
+ )
+ expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties} {expression_sql}"
return self.prepend_ctes(expression, expression_sql)
- def describe_sql(self, expression):
+ def describe_sql(self, expression: exp.Describe) -> str:
return f"DESCRIBE {self.sql(expression, 'this')}"
- def prepend_ctes(self, expression, sql):
+ def prepend_ctes(self, expression: exp.Expression, sql: str) -> str:
with_ = self.sql(expression, "with")
if with_:
sql = f"{with_}{self.sep()}{sql}"
return sql
- def with_sql(self, expression):
+ def with_sql(self, expression: exp.With) -> str:
sql = self.expressions(expression, flat=True)
recursive = "RECURSIVE " if expression.args.get("recursive") else ""
return f"WITH {recursive}{sql}"
- def cte_sql(self, expression):
+ def cte_sql(self, expression: exp.CTE) -> str:
alias = self.sql(expression, "alias")
return f"{alias} AS {self.wrap(expression)}"
- def tablealias_sql(self, expression):
+ def tablealias_sql(self, expression: exp.TableAlias) -> str:
alias = self.sql(expression, "this")
columns = self.expressions(expression, key="columns", flat=True)
columns = f"({columns})" if columns else ""
return f"{alias}{columns}"
- def bitstring_sql(self, expression):
+ def bitstring_sql(self, expression: exp.BitString) -> str:
return self.sql(expression, "this")
- def hexstring_sql(self, expression):
+ def hexstring_sql(self, expression: exp.HexString) -> str:
return self.sql(expression, "this")
- def datatype_sql(self, expression):
+ def datatype_sql(self, expression: exp.DataType) -> str:
type_value = expression.this
type_sql = self.TYPE_MAPPING.get(type_value, type_value.value)
nested = ""
@@ -455,13 +484,13 @@ class Generator:
)
return f"{type_sql}{nested}"
- def directory_sql(self, expression):
+ def directory_sql(self, expression: exp.Directory) -> str:
local = "LOCAL " if expression.args.get("local") else ""
row_format = self.sql(expression, "row_format")
row_format = f" {row_format}" if row_format else ""
return f"{local}DIRECTORY {self.sql(expression, 'this')}{row_format}"
- def delete_sql(self, expression):
+ def delete_sql(self, expression: exp.Delete) -> str:
this = self.sql(expression, "this")
using_sql = (
f" USING {self.expressions(expression, 'using', sep=', USING ')}"
@@ -472,7 +501,7 @@ class Generator:
sql = f"DELETE FROM {this}{using_sql}{where_sql}"
return self.prepend_ctes(expression, sql)
- def drop_sql(self, expression):
+ def drop_sql(self, expression: exp.Drop) -> str:
this = self.sql(expression, "this")
kind = expression.args["kind"]
exists_sql = " IF EXISTS " if expression.args.get("exists") else " "
@@ -481,46 +510,46 @@ class Generator:
cascade = " CASCADE" if expression.args.get("cascade") else ""
return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{cascade}"
- def except_sql(self, expression):
+ def except_sql(self, expression: exp.Except) -> str:
return self.prepend_ctes(
expression,
self.set_operation(expression, self.except_op(expression)),
)
- def except_op(self, expression):
+ def except_op(self, expression: exp.Except) -> str:
return f"EXCEPT{'' if expression.args.get('distinct') else ' ALL'}"
- def fetch_sql(self, expression):
+ def fetch_sql(self, expression: exp.Fetch) -> str:
direction = expression.args.get("direction")
direction = f" {direction.upper()}" if direction else ""
count = expression.args.get("count")
count = f" {count}" if count else ""
return f"{self.seg('FETCH')}{direction}{count} ROWS ONLY"
- def filter_sql(self, expression):
+ def filter_sql(self, expression: exp.Filter) -> str:
this = self.sql(expression, "this")
where = self.sql(expression, "expression")[1:] # where has a leading space
return f"{this} FILTER({where})"
- def hint_sql(self, expression):
+ def hint_sql(self, expression: exp.Hint) -> str:
if self.sql(expression, "this"):
self.unsupported("Hints are not supported")
return ""
- def index_sql(self, expression):
+ def index_sql(self, expression: exp.Index) -> str:
this = self.sql(expression, "this")
table = self.sql(expression, "table")
columns = self.sql(expression, "columns")
return f"{this} ON {table} {columns}"
- def identifier_sql(self, expression):
+ def identifier_sql(self, expression: exp.Identifier) -> str:
text = expression.name
text = text.lower() if self.normalize else text
if expression.args.get("quoted") or self.identify:
text = f"{self.identifier_start}{text}{self.identifier_end}"
return text
- def partition_sql(self, expression):
+ def partition_sql(self, expression: exp.Partition) -> str:
keys = csv(
*[
f"""{prop.name}='{prop.text("value")}'""" if prop.text("value") else prop.name
@@ -529,7 +558,7 @@ class Generator:
)
return f"PARTITION({keys})"
- def properties_sql(self, expression):
+ def properties_sql(self, expression: exp.Properties) -> str:
root_properties = []
with_properties = []
@@ -544,21 +573,21 @@ class Generator:
exp.Properties(expressions=root_properties)
) + self.with_properties(exp.Properties(expressions=with_properties))
- def root_properties(self, properties):
+ def root_properties(self, properties: exp.Properties) -> str:
if properties.expressions:
return self.sep() + self.expressions(properties, indent=False, sep=" ")
return ""
- def properties(self, properties, prefix="", sep=", "):
+ def properties(self, properties: exp.Properties, prefix: str = "", sep: str = ", ") -> str:
if properties.expressions:
expressions = self.expressions(properties, sep=sep, indent=False)
- return f"{self.seg(prefix)}{' ' if prefix else ''}{self.wrap(expressions)}"
+ return f"{prefix}{' ' if prefix else ''}{self.wrap(expressions)}"
return ""
- def with_properties(self, properties):
- return self.properties(properties, prefix="WITH")
+ def with_properties(self, properties: exp.Properties) -> str:
+ return self.properties(properties, prefix=self.seg("WITH"))
- def property_sql(self, expression):
+ def property_sql(self, expression: exp.Property) -> str:
property_cls = expression.__class__
if property_cls == exp.Property:
return f"{expression.name}={self.sql(expression, 'value')}"
@@ -569,12 +598,12 @@ class Generator:
return f"{property_name}={self.sql(expression, 'this')}"
- def likeproperty_sql(self, expression):
+ def likeproperty_sql(self, expression: exp.LikeProperty) -> str:
options = " ".join(f"{e.name} {self.sql(e, 'value')}" for e in expression.expressions)
options = f" {options}" if options else ""
return f"LIKE {self.sql(expression, 'this')}{options}"
- def insert_sql(self, expression):
+ def insert_sql(self, expression: exp.Insert) -> str:
overwrite = expression.args.get("overwrite")
if isinstance(expression.this, exp.Directory):
@@ -592,19 +621,19 @@ class Generator:
sql = f"INSERT {this}{exists}{partition_sql}{sep}{expression_sql}"
return self.prepend_ctes(expression, sql)
- def intersect_sql(self, expression):
+ def intersect_sql(self, expression: exp.Intersect) -> str:
return self.prepend_ctes(
expression,
self.set_operation(expression, self.intersect_op(expression)),
)
- def intersect_op(self, expression):
+ def intersect_op(self, expression: exp.Intersect) -> str:
return f"INTERSECT{'' if expression.args.get('distinct') else ' ALL'}"
- def introducer_sql(self, expression):
+ def introducer_sql(self, expression: exp.Introducer) -> str:
return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}"
- def rowformat_sql(self, expression):
+ def rowformatdelimitedproperty_sql(self, expression: exp.RowFormatDelimitedProperty) -> str:
fields = expression.args.get("fields")
fields = f" FIELDS TERMINATED BY {fields}" if fields else ""
escaped = expression.args.get("escaped")
@@ -619,7 +648,7 @@ class Generator:
null = f" NULL DEFINED AS {null}" if null else ""
return f"ROW FORMAT DELIMITED{fields}{escaped}{items}{keys}{lines}{null}"
- def table_sql(self, expression, sep=" AS "):
+ def table_sql(self, expression: exp.Table, sep: str = " AS ") -> str:
table = ".".join(
part
for part in [
@@ -642,7 +671,7 @@ class Generator:
return f"{table}{alias}{laterals}{joins}{pivots}"
- def tablesample_sql(self, expression):
+ def tablesample_sql(self, expression: exp.TableSample) -> str:
if self.alias_post_tablesample and expression.this.alias:
this = self.sql(expression.this, "this")
alias = f" AS {self.sql(expression.this, 'alias')}"
@@ -665,7 +694,7 @@ class Generator:
seed = f" SEED ({seed})" if seed else ""
return f"{this} TABLESAMPLE{method}({bucket}{percent}{rows}{size}){seed}{alias}"
- def pivot_sql(self, expression):
+ def pivot_sql(self, expression: exp.Pivot) -> str:
this = self.sql(expression, "this")
unpivot = expression.args.get("unpivot")
direction = "UNPIVOT" if unpivot else "PIVOT"
@@ -673,10 +702,10 @@ class Generator:
field = self.sql(expression, "field")
return f"{this} {direction}({expressions} FOR {field})"
- def tuple_sql(self, expression):
+ def tuple_sql(self, expression: exp.Tuple) -> str:
return f"({self.expressions(expression, flat=True)})"
- def update_sql(self, expression):
+ def update_sql(self, expression: exp.Update) -> str:
this = self.sql(expression, "this")
set_sql = self.expressions(expression, flat=True)
from_sql = self.sql(expression, "from")
@@ -684,7 +713,7 @@ class Generator:
sql = f"UPDATE {this} SET {set_sql}{from_sql}{where_sql}"
return self.prepend_ctes(expression, sql)
- def values_sql(self, expression):
+ def values_sql(self, expression: exp.Values) -> str:
alias = self.sql(expression, "alias")
args = self.expressions(expression)
if not alias:
@@ -694,19 +723,19 @@ class Generator:
return f"(VALUES{self.seg('')}{args}){alias}"
return f"VALUES{self.seg('')}{args}{alias}"
- def var_sql(self, expression):
+ def var_sql(self, expression: exp.Var) -> str:
return self.sql(expression, "this")
- def into_sql(self, expression):
+ def into_sql(self, expression: exp.Into) -> str:
temporary = " TEMPORARY" if expression.args.get("temporary") else ""
unlogged = " UNLOGGED" if expression.args.get("unlogged") else ""
return f"{self.seg('INTO')}{temporary or unlogged} {self.sql(expression, 'this')}"
- def from_sql(self, expression):
+ def from_sql(self, expression: exp.From) -> str:
expressions = self.expressions(expression, flat=True)
return f"{self.seg('FROM')} {expressions}"
- def group_sql(self, expression):
+ def group_sql(self, expression: exp.Group) -> str:
group_by = self.op_expressions("GROUP BY", expression)
grouping_sets = self.expressions(expression, key="grouping_sets", indent=False)
grouping_sets = (
@@ -718,11 +747,11 @@ class Generator:
rollup = f"{self.seg('ROLLUP')} {self.wrap(rollup)}" if rollup else ""
return f"{group_by}{grouping_sets}{cube}{rollup}"
- def having_sql(self, expression):
+ def having_sql(self, expression: exp.Having) -> str:
this = self.indent(self.sql(expression, "this"))
return f"{self.seg('HAVING')}{self.sep()}{this}"
- def join_sql(self, expression):
+ def join_sql(self, expression: exp.Join) -> str:
op_sql = self.seg(
" ".join(
op
@@ -753,12 +782,12 @@ class Generator:
this_sql = self.sql(expression, "this")
return f"{expression_sql}{op_sql} {this_sql}{on_sql}"
- def lambda_sql(self, expression, arrow_sep="->"):
+ def lambda_sql(self, expression: exp.Lambda, arrow_sep: str = "->") -> str:
args = self.expressions(expression, flat=True)
args = f"({args})" if len(args.split(",")) > 1 else args
return self.no_identify(lambda: f"{args} {arrow_sep} {self.sql(expression, 'this')}")
- def lateral_sql(self, expression):
+ def lateral_sql(self, expression: exp.Lateral) -> str:
this = self.sql(expression, "this")
if isinstance(expression.this, exp.Subquery):
@@ -776,15 +805,15 @@ class Generator:
return f"LATERAL {this}{table}{columns}"
- def limit_sql(self, expression):
+ def limit_sql(self, expression: exp.Limit) -> str:
this = self.sql(expression, "this")
return f"{this}{self.seg('LIMIT')} {self.sql(expression, 'expression')}"
- def offset_sql(self, expression):
+ def offset_sql(self, expression: exp.Offset) -> str:
this = self.sql(expression, "this")
return f"{this}{self.seg('OFFSET')} {self.sql(expression, 'expression')}"
- def literal_sql(self, expression):
+ def literal_sql(self, expression: exp.Literal) -> str:
text = expression.this or ""
if expression.is_string:
if self._replace_backslash:
@@ -793,7 +822,7 @@ class Generator:
text = f"{self.quote_start}{text}{self.quote_end}"
return text
- def loaddata_sql(self, expression):
+ def loaddata_sql(self, expression: exp.LoadData) -> str:
local = " LOCAL" if expression.args.get("local") else ""
inpath = f" INPATH {self.sql(expression, 'inpath')}"
overwrite = " OVERWRITE" if expression.args.get("overwrite") else ""
@@ -806,27 +835,27 @@ class Generator:
serde = f" SERDE {serde}" if serde else ""
return f"LOAD DATA{local}{inpath}{overwrite}{this}{partition}{input_format}{serde}"
- def null_sql(self, *_):
+ def null_sql(self, *_) -> str:
return "NULL"
- def boolean_sql(self, expression):
+ def boolean_sql(self, expression: exp.Boolean) -> str:
return "TRUE" if expression.this else "FALSE"
- def order_sql(self, expression, flat=False):
+ def order_sql(self, expression: exp.Order, flat: bool = False) -> str:
this = self.sql(expression, "this")
this = f"{this} " if this else this
- return self.op_expressions(f"{this}ORDER BY", expression, flat=this or flat)
+ return self.op_expressions(f"{this}ORDER BY", expression, flat=this or flat) # type: ignore
- def cluster_sql(self, expression):
+ def cluster_sql(self, expression: exp.Cluster) -> str:
return self.op_expressions("CLUSTER BY", expression)
- def distribute_sql(self, expression):
+ def distribute_sql(self, expression: exp.Distribute) -> str:
return self.op_expressions("DISTRIBUTE BY", expression)
- def sort_sql(self, expression):
+ def sort_sql(self, expression: exp.Sort) -> str:
return self.op_expressions("SORT BY", expression)
- def ordered_sql(self, expression):
+ def ordered_sql(self, expression: exp.Ordered) -> str:
desc = expression.args.get("desc")
asc = not desc
@@ -857,7 +886,7 @@ class Generator:
return f"{self.sql(expression, 'this')}{sort_order}{nulls_sort_change}"
- def query_modifiers(self, expression, *sqls):
+ def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str:
return csv(
*sqls,
*[self.sql(sql) for sql in expression.args.get("laterals", [])],
@@ -876,7 +905,7 @@ class Generator:
sep="",
)
- def select_sql(self, expression):
+ def select_sql(self, expression: exp.Select) -> str:
hint = self.sql(expression, "hint")
distinct = self.sql(expression, "distinct")
distinct = f" {distinct}" if distinct else ""
@@ -890,36 +919,36 @@ class Generator:
)
return self.prepend_ctes(expression, sql)
- def schema_sql(self, expression):
+ def schema_sql(self, expression: exp.Schema) -> str:
this = self.sql(expression, "this")
this = f"{this} " if this else ""
sql = f"({self.sep('')}{self.expressions(expression)}{self.seg(')', sep='')}"
return f"{this}{sql}"
- def star_sql(self, expression):
+ def star_sql(self, expression: exp.Star) -> str:
except_ = self.expressions(expression, key="except", flat=True)
except_ = f"{self.seg('EXCEPT')} ({except_})" if except_ else ""
replace = self.expressions(expression, key="replace", flat=True)
replace = f"{self.seg('REPLACE')} ({replace})" if replace else ""
return f"*{except_}{replace}"
- def structkwarg_sql(self, expression):
+ def structkwarg_sql(self, expression: exp.StructKwarg) -> str:
return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}"
- def parameter_sql(self, expression):
+ def parameter_sql(self, expression: exp.Parameter) -> str:
return f"@{self.sql(expression, 'this')}"
- def sessionparameter_sql(self, expression):
+ def sessionparameter_sql(self, expression: exp.SessionParameter) -> str:
this = self.sql(expression, "this")
kind = expression.text("kind")
if kind:
kind = f"{kind}."
return f"@@{kind}{this}"
- def placeholder_sql(self, expression):
+ def placeholder_sql(self, expression: exp.Placeholder) -> str:
return f":{expression.name}" if expression.name else "?"
- def subquery_sql(self, expression):
+ def subquery_sql(self, expression: exp.Subquery) -> str:
alias = self.sql(expression, "alias")
sql = self.query_modifiers(
@@ -931,22 +960,22 @@ class Generator:
return self.prepend_ctes(expression, sql)
- def qualify_sql(self, expression):
+ def qualify_sql(self, expression: exp.Qualify) -> str:
this = self.indent(self.sql(expression, "this"))
return f"{self.seg('QUALIFY')}{self.sep()}{this}"
- def union_sql(self, expression):
+ def union_sql(self, expression: exp.Union) -> str:
return self.prepend_ctes(
expression,
self.set_operation(expression, self.union_op(expression)),
)
- def union_op(self, expression):
+ def union_op(self, expression: exp.Union) -> str:
kind = " DISTINCT" if self.EXPLICIT_UNION else ""
kind = kind if expression.args.get("distinct") else " ALL"
return f"UNION{kind}"
- def unnest_sql(self, expression):
+ def unnest_sql(self, expression: exp.Unnest) -> str:
args = self.expressions(expression, flat=True)
alias = expression.args.get("alias")
if alias and self.unnest_column_only:
@@ -958,11 +987,11 @@ class Generator:
ordinality = " WITH ORDINALITY" if expression.args.get("ordinality") else ""
return f"UNNEST({args}){ordinality}{alias}"
- def where_sql(self, expression):
+ def where_sql(self, expression: exp.Where) -> str:
this = self.indent(self.sql(expression, "this"))
return f"{self.seg('WHERE')}{self.sep()}{this}"
- def window_sql(self, expression):
+ def window_sql(self, expression: exp.Window) -> str:
this = self.sql(expression, "this")
partition = self.expressions(expression, key="partition_by", flat=True)
@@ -988,7 +1017,7 @@ class Generator:
return f"{this} ({alias}{partition_sql}{order_sql}{spec_sql})"
- def window_spec_sql(self, expression):
+ def window_spec_sql(self, expression: exp.WindowSpec) -> str:
kind = self.sql(expression, "kind")
start = csv(self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" ")
end = (
@@ -997,33 +1026,33 @@ class Generator:
)
return f"{kind} BETWEEN {start} AND {end}"
- def withingroup_sql(self, expression):
+ def withingroup_sql(self, expression: exp.WithinGroup) -> str:
this = self.sql(expression, "this")
- expression = self.sql(expression, "expression")[1:] # order has a leading space
- return f"{this} WITHIN GROUP ({expression})"
+ expression_sql = self.sql(expression, "expression")[1:] # order has a leading space
+ return f"{this} WITHIN GROUP ({expression_sql})"
- def between_sql(self, expression):
+ def between_sql(self, expression: exp.Between) -> str:
this = self.sql(expression, "this")
low = self.sql(expression, "low")
high = self.sql(expression, "high")
return f"{this} BETWEEN {low} AND {high}"
- def bracket_sql(self, expression):
+ def bracket_sql(self, expression: exp.Bracket) -> str:
expressions = apply_index_offset(expression.expressions, self.index_offset)
- expressions = ", ".join(self.sql(e) for e in expressions)
+ expressions_sql = ", ".join(self.sql(e) for e in expressions)
- return f"{self.sql(expression, 'this')}[{expressions}]"
+ return f"{self.sql(expression, 'this')}[{expressions_sql}]"
- def all_sql(self, expression):
+ def all_sql(self, expression: exp.All) -> str:
return f"ALL {self.wrap(expression)}"
- def any_sql(self, expression):
+ def any_sql(self, expression: exp.Any) -> str:
return f"ANY {self.wrap(expression)}"
- def exists_sql(self, expression):
+ def exists_sql(self, expression: exp.Exists) -> str:
return f"EXISTS{self.wrap(expression)}"
- def case_sql(self, expression):
+ def case_sql(self, expression: exp.Case) -> str:
this = self.sql(expression, "this")
statements = [f"CASE {this}" if this else "CASE"]
@@ -1043,17 +1072,17 @@ class Generator:
return " ".join(statements)
- def constraint_sql(self, expression):
+ def constraint_sql(self, expression: exp.Constraint) -> str:
this = self.sql(expression, "this")
expressions = self.expressions(expression, flat=True)
return f"CONSTRAINT {this} {expressions}"
- def extract_sql(self, expression):
+ def extract_sql(self, expression: exp.Extract) -> str:
this = self.sql(expression, "this")
expression_sql = self.sql(expression, "expression")
return f"EXTRACT({this} FROM {expression_sql})"
- def trim_sql(self, expression):
+ def trim_sql(self, expression: exp.Trim) -> str:
target = self.sql(expression, "this")
trim_type = self.sql(expression, "position")
@@ -1064,16 +1093,16 @@ class Generator:
else:
return f"TRIM({target})"
- def concat_sql(self, expression):
+ def concat_sql(self, expression: exp.Concat) -> str:
if len(expression.expressions) == 1:
return self.sql(expression.expressions[0])
return self.function_fallback_sql(expression)
- def check_sql(self, expression):
+ def check_sql(self, expression: exp.Check) -> str:
this = self.sql(expression, key="this")
return f"CHECK ({this})"
- def foreignkey_sql(self, expression):
+ def foreignkey_sql(self, expression: exp.ForeignKey) -> str:
expressions = self.expressions(expression, flat=True)
reference = self.sql(expression, "reference")
reference = f" {reference}" if reference else ""
@@ -1083,16 +1112,16 @@ class Generator:
update = f" ON UPDATE {update}" if update else ""
return f"FOREIGN KEY ({expressions}){reference}{delete}{update}"
- def unique_sql(self, expression):
+ def unique_sql(self, expression: exp.Unique) -> str:
columns = self.expressions(expression, key="expressions")
return f"UNIQUE ({columns})"
- def if_sql(self, expression):
+ def if_sql(self, expression: exp.If) -> str:
return self.case_sql(
exp.Case(ifs=[expression.copy()], default=expression.args.get("false"))
)
- def in_sql(self, expression):
+ def in_sql(self, expression: exp.In) -> str:
query = expression.args.get("query")
unnest = expression.args.get("unnest")
field = expression.args.get("field")
@@ -1106,24 +1135,24 @@ class Generator:
in_sql = f"({self.expressions(expression, flat=True)})"
return f"{self.sql(expression, 'this')} IN {in_sql}"
- def in_unnest_op(self, unnest):
+ def in_unnest_op(self, unnest: exp.Unnest) -> str:
return f"(SELECT {self.sql(unnest)})"
- def interval_sql(self, expression):
+ def interval_sql(self, expression: exp.Interval) -> str:
unit = self.sql(expression, "unit")
unit = f" {unit}" if unit else ""
return f"INTERVAL {self.sql(expression, 'this')}{unit}"
- def reference_sql(self, expression):
+ def reference_sql(self, expression: exp.Reference) -> str:
this = self.sql(expression, "this")
expressions = self.expressions(expression, flat=True)
return f"REFERENCES {this}({expressions})"
- def anonymous_sql(self, expression):
+ def anonymous_sql(self, expression: exp.Anonymous) -> str:
args = self.format_args(*expression.expressions)
return f"{self.normalize_func(self.sql(expression, 'this'))}({args})"
- def paren_sql(self, expression):
+ def paren_sql(self, expression: exp.Paren) -> str:
if isinstance(expression.unnest(), exp.Select):
sql = self.wrap(expression)
else:
@@ -1132,35 +1161,35 @@ class Generator:
return self.prepend_ctes(expression, sql)
- def neg_sql(self, expression):
+ def neg_sql(self, expression: exp.Neg) -> str:
# This makes sure we don't convert "- - 5" to "--5", which is a comment
this_sql = self.sql(expression, "this")
sep = " " if this_sql[0] == "-" else ""
return f"-{sep}{this_sql}"
- def not_sql(self, expression):
+ def not_sql(self, expression: exp.Not) -> str:
return f"NOT {self.sql(expression, 'this')}"
- def alias_sql(self, expression):
+ def alias_sql(self, expression: exp.Alias) -> str:
to_sql = self.sql(expression, "alias")
to_sql = f" AS {to_sql}" if to_sql else ""
return f"{self.sql(expression, 'this')}{to_sql}"
- def aliases_sql(self, expression):
+ def aliases_sql(self, expression: exp.Aliases) -> str:
return f"{self.sql(expression, 'this')} AS ({self.expressions(expression, flat=True)})"
- def attimezone_sql(self, expression):
+ def attimezone_sql(self, expression: exp.AtTimeZone) -> str:
this = self.sql(expression, "this")
zone = self.sql(expression, "zone")
return f"{this} AT TIME ZONE {zone}"
- def add_sql(self, expression):
+ def add_sql(self, expression: exp.Add) -> str:
return self.binary(expression, "+")
- def and_sql(self, expression):
+ def and_sql(self, expression: exp.And) -> str:
return self.connector_sql(expression, "AND")
- def connector_sql(self, expression, op):
+ def connector_sql(self, expression: exp.Connector, op: str) -> str:
if not self.pretty:
return self.binary(expression, op)
@@ -1168,53 +1197,53 @@ class Generator:
sep = "\n" if self.text_width(sqls) > self._max_text_width else " "
return f"{sep}{op} ".join(sqls)
- def bitwiseand_sql(self, expression):
+ def bitwiseand_sql(self, expression: exp.BitwiseAnd) -> str:
return self.binary(expression, "&")
- def bitwiseleftshift_sql(self, expression):
+ def bitwiseleftshift_sql(self, expression: exp.BitwiseLeftShift) -> str:
return self.binary(expression, "<<")
- def bitwisenot_sql(self, expression):
+ def bitwisenot_sql(self, expression: exp.BitwiseNot) -> str:
return f"~{self.sql(expression, 'this')}"
- def bitwiseor_sql(self, expression):
+ def bitwiseor_sql(self, expression: exp.BitwiseOr) -> str:
return self.binary(expression, "|")
- def bitwiserightshift_sql(self, expression):
+ def bitwiserightshift_sql(self, expression: exp.BitwiseRightShift) -> str:
return self.binary(expression, ">>")
- def bitwisexor_sql(self, expression):
+ def bitwisexor_sql(self, expression: exp.BitwiseXor) -> str:
return self.binary(expression, "^")
- def cast_sql(self, expression):
+ def cast_sql(self, expression: exp.Cast) -> str:
return f"CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})"
- def currentdate_sql(self, expression):
+ def currentdate_sql(self, expression: exp.CurrentDate) -> str:
zone = self.sql(expression, "this")
return f"CURRENT_DATE({zone})" if zone else "CURRENT_DATE"
- def collate_sql(self, expression):
+ def collate_sql(self, expression: exp.Collate) -> str:
return self.binary(expression, "COLLATE")
- def command_sql(self, expression):
+ def command_sql(self, expression: exp.Command) -> str:
return f"{self.sql(expression, 'this').upper()} {expression.text('expression').strip()}"
- def transaction_sql(self, *_):
+ def transaction_sql(self, *_) -> str:
return "BEGIN"
- def commit_sql(self, expression):
+ def commit_sql(self, expression: exp.Commit) -> str:
chain = expression.args.get("chain")
if chain is not None:
chain = " AND CHAIN" if chain else " AND NO CHAIN"
return f"COMMIT{chain or ''}"
- def rollback_sql(self, expression):
+ def rollback_sql(self, expression: exp.Rollback) -> str:
savepoint = expression.args.get("savepoint")
savepoint = f" TO {savepoint}" if savepoint else ""
return f"ROLLBACK{savepoint}"
- def distinct_sql(self, expression):
+ def distinct_sql(self, expression: exp.Distinct) -> str:
this = self.expressions(expression, flat=True)
this = f" {this}" if this else ""
@@ -1222,13 +1251,13 @@ class Generator:
on = f" ON {on}" if on else ""
return f"DISTINCT{this}{on}"
- def ignorenulls_sql(self, expression):
+ def ignorenulls_sql(self, expression: exp.IgnoreNulls) -> str:
return f"{self.sql(expression, 'this')} IGNORE NULLS"
- def respectnulls_sql(self, expression):
+ def respectnulls_sql(self, expression: exp.RespectNulls) -> str:
return f"{self.sql(expression, 'this')} RESPECT NULLS"
- def intdiv_sql(self, expression):
+ def intdiv_sql(self, expression: exp.IntDiv) -> str:
return self.sql(
exp.Cast(
this=exp.Div(this=expression.this, expression=expression.expression),
@@ -1236,79 +1265,79 @@ class Generator:
)
)
- def dpipe_sql(self, expression):
+ def dpipe_sql(self, expression: exp.DPipe) -> str:
return self.binary(expression, "||")
- def div_sql(self, expression):
+ def div_sql(self, expression: exp.Div) -> str:
return self.binary(expression, "/")
- def distance_sql(self, expression):
+ def distance_sql(self, expression: exp.Distance) -> str:
return self.binary(expression, "<->")
- def dot_sql(self, expression):
+ def dot_sql(self, expression: exp.Dot) -> str:
return f"{self.sql(expression, 'this')}.{self.sql(expression, 'expression')}"
- def eq_sql(self, expression):
+ def eq_sql(self, expression: exp.EQ) -> str:
return self.binary(expression, "=")
- def escape_sql(self, expression):
+ def escape_sql(self, expression: exp.Escape) -> str:
return self.binary(expression, "ESCAPE")
- def gt_sql(self, expression):
+ def gt_sql(self, expression: exp.GT) -> str:
return self.binary(expression, ">")
- def gte_sql(self, expression):
+ def gte_sql(self, expression: exp.GTE) -> str:
return self.binary(expression, ">=")
- def ilike_sql(self, expression):
+ def ilike_sql(self, expression: exp.ILike) -> str:
return self.binary(expression, "ILIKE")
- def is_sql(self, expression):
+ def is_sql(self, expression: exp.Is) -> str:
return self.binary(expression, "IS")
- def like_sql(self, expression):
+ def like_sql(self, expression: exp.Like) -> str:
return self.binary(expression, "LIKE")
- def similarto_sql(self, expression):
+ def similarto_sql(self, expression: exp.SimilarTo) -> str:
return self.binary(expression, "SIMILAR TO")
- def lt_sql(self, expression):
+ def lt_sql(self, expression: exp.LT) -> str:
return self.binary(expression, "<")
- def lte_sql(self, expression):
+ def lte_sql(self, expression: exp.LTE) -> str:
return self.binary(expression, "<=")
- def mod_sql(self, expression):
+ def mod_sql(self, expression: exp.Mod) -> str:
return self.binary(expression, "%")
- def mul_sql(self, expression):
+ def mul_sql(self, expression: exp.Mul) -> str:
return self.binary(expression, "*")
- def neq_sql(self, expression):
+ def neq_sql(self, expression: exp.NEQ) -> str:
return self.binary(expression, "<>")
- def nullsafeeq_sql(self, expression):
+ def nullsafeeq_sql(self, expression: exp.NullSafeEQ) -> str:
return self.binary(expression, "IS NOT DISTINCT FROM")
- def nullsafeneq_sql(self, expression):
+ def nullsafeneq_sql(self, expression: exp.NullSafeNEQ) -> str:
return self.binary(expression, "IS DISTINCT FROM")
- def or_sql(self, expression):
+ def or_sql(self, expression: exp.Or) -> str:
return self.connector_sql(expression, "OR")
- def sub_sql(self, expression):
+ def sub_sql(self, expression: exp.Sub) -> str:
return self.binary(expression, "-")
- def trycast_sql(self, expression):
+ def trycast_sql(self, expression: exp.TryCast) -> str:
return f"TRY_CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})"
- def use_sql(self, expression):
+ def use_sql(self, expression: exp.Use) -> str:
return f"USE {self.sql(expression, 'this')}"
- def binary(self, expression, op):
+ def binary(self, expression: exp.Binary, op: str) -> str:
return f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}"
- def function_fallback_sql(self, expression):
+ def function_fallback_sql(self, expression: exp.Func) -> str:
args = []
for arg_value in expression.args.values():
if isinstance(arg_value, list):
@@ -1319,19 +1348,26 @@ class Generator:
return f"{self.normalize_func(expression.sql_name())}({self.format_args(*args)})"
- def format_args(self, *args):
- args = tuple(self.sql(arg) for arg in args if arg is not None)
- if self.pretty and self.text_width(args) > self._max_text_width:
- return self.indent("\n" + f",\n".join(args) + "\n", skip_first=True, skip_last=True)
- return ", ".join(args)
+ def format_args(self, *args: t.Optional[str | exp.Expression]) -> str:
+ arg_sqls = tuple(self.sql(arg) for arg in args if arg is not None)
+ if self.pretty and self.text_width(arg_sqls) > self._max_text_width:
+ return self.indent("\n" + f",\n".join(arg_sqls) + "\n", skip_first=True, skip_last=True)
+ return ", ".join(arg_sqls)
- def text_width(self, args):
+ def text_width(self, args: t.Iterable) -> int:
return sum(len(arg) for arg in args)
- def format_time(self, expression):
+ def format_time(self, expression: exp.Expression) -> t.Optional[str]:
return format_time(self.sql(expression, "format"), self.time_mapping, self.time_trie)
- def expressions(self, expression, key=None, flat=False, indent=True, sep=", "):
+ def expressions(
+ self,
+ expression: exp.Expression,
+ key: t.Optional[str] = None,
+ flat: bool = False,
+ indent: bool = True,
+ sep: str = ", ",
+ ) -> str:
expressions = expression.args.get(key or "expressions")
if not expressions:
@@ -1359,45 +1395,67 @@ class Generator:
else:
result_sqls.append(f"{sql}{comments}{sep if i + 1 < num_sqls else ''}")
- result_sqls = "\n".join(result_sqls) if self.pretty else "".join(result_sqls)
- return self.indent(result_sqls, skip_first=False) if indent else result_sqls
+ result_sql = "\n".join(result_sqls) if self.pretty else "".join(result_sqls)
+ return self.indent(result_sql, skip_first=False) if indent else result_sql
- def op_expressions(self, op, expression, flat=False):
+ def op_expressions(self, op: str, expression: exp.Expression, flat: bool = False) -> str:
expressions_sql = self.expressions(expression, flat=flat)
if flat:
return f"{op} {expressions_sql}"
return f"{self.seg(op)}{self.sep() if expressions_sql else ''}{expressions_sql}"
- def naked_property(self, expression):
+ def naked_property(self, expression: exp.Property) -> str:
property_name = exp.Properties.PROPERTY_TO_NAME.get(expression.__class__)
if not property_name:
self.unsupported(f"Unsupported property {expression.__class__.__name__}")
return f"{property_name} {self.sql(expression, 'this')}"
- def set_operation(self, expression, op):
+ def set_operation(self, expression: exp.Expression, op: str) -> str:
this = self.sql(expression, "this")
op = self.seg(op)
return self.query_modifiers(
expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}"
)
- def token_sql(self, token_type):
+ def token_sql(self, token_type: TokenType) -> str:
return self.TOKEN_MAPPING.get(token_type, token_type.name)
- def userdefinedfunction_sql(self, expression):
+ def userdefinedfunction_sql(self, expression: exp.UserDefinedFunction) -> str:
this = self.sql(expression, "this")
expressions = self.no_identify(lambda: self.expressions(expression))
return f"{this}({expressions})"
- def userdefinedfunctionkwarg_sql(self, expression):
+ def userdefinedfunctionkwarg_sql(self, expression: exp.UserDefinedFunctionKwarg) -> str:
this = self.sql(expression, "this")
kind = self.sql(expression, "kind")
return f"{this} {kind}"
- def joinhint_sql(self, expression):
+ def joinhint_sql(self, expression: exp.JoinHint) -> str:
this = self.sql(expression, "this")
expressions = self.expressions(expression, flat=True)
return f"{this}({expressions})"
- def kwarg_sql(self, expression):
+ def kwarg_sql(self, expression: exp.Kwarg) -> str:
return self.binary(expression, "=>")
+
+ def when_sql(self, expression: exp.When) -> str:
+ this = self.sql(expression, "this")
+ then_expression = expression.args.get("then")
+ if isinstance(then_expression, exp.Insert):
+ then = f"INSERT {self.sql(then_expression, 'this')}"
+ if "expression" in then_expression.args:
+ then += f" VALUES {self.sql(then_expression, 'expression')}"
+ elif isinstance(then_expression, exp.Update):
+ if isinstance(then_expression.args.get("expressions"), exp.Star):
+ then = f"UPDATE {self.sql(then_expression, 'expressions')}"
+ else:
+ then = f"UPDATE SET {self.expressions(then_expression, flat=True)}"
+ else:
+ then = self.sql(then_expression)
+ return f"WHEN {this} THEN {then}"
+
+ def merge_sql(self, expression: exp.Merge) -> str:
+ this = self.sql(expression, "this")
+ using = f"USING {self.sql(expression, 'using')}"
+ on = f"ON {self.sql(expression, 'on')}"
+ return f"MERGE INTO {this} {using} {on} {self.expressions(expression, sep=' ')}"
diff --git a/sqlglot/helper.py b/sqlglot/helper.py
index 8c5808d..ed37e6c 100644
--- a/sqlglot/helper.py
+++ b/sqlglot/helper.py
@@ -385,3 +385,11 @@ def dict_depth(d: t.Dict) -> int:
except StopIteration:
# d.values() returns an empty sequence
return 1
+
+
+def first(it: t.Iterable[T]) -> T:
+ """Returns the first element from an iterable.
+
+ Useful for sets.
+ """
+ return next(i for i in it)
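Editor's note: a quick illustration of the new helper; it simply pulls one element out of any iterable, which is handy for the single-element sets used by the comparison simplification later in this patch:

    from sqlglot.helper import first

    # works on anything iterable, including sets, which cannot be indexed
    assert first({42}) == 42
    assert first([1, 2, 3]) == 1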
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index 191ea52..be17f15 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -14,7 +14,7 @@ def annotate_types(expression, schema=None, annotators=None, coerces_to=None):
>>> schema = {"y": {"cola": "SMALLINT"}}
>>> sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
>>> annotated_expr = annotate_types(sqlglot.parse_one(sql), schema=schema)
- >>> annotated_expr.expressions[0].type # Get the type of "x.cola + 2.5 AS cola"
+ >>> annotated_expr.expressions[0].type.this # Get the type of "x.cola + 2.5 AS cola"
<Type.DOUBLE: 'DOUBLE'>
Args:
@@ -41,9 +41,12 @@ class TypeAnnotator:
expr_type: lambda self, expr: self._annotate_binary(expr)
for expr_type in subclasses(exp.__name__, exp.Binary)
},
- exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"].this),
- exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr.this),
+ exp.Cast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
+ exp.TryCast: lambda self, expr: self._annotate_with_type(expr, expr.args["to"]),
+ exp.DataType: lambda self, expr: self._annotate_with_type(expr, expr),
exp.Alias: lambda self, expr: self._annotate_unary(expr),
+ exp.Between: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
+ exp.In: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
exp.Literal: lambda self, expr: self._annotate_literal(expr),
exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL),
@@ -52,6 +55,9 @@ class TypeAnnotator:
expr, exp.DataType.Type.BIGINT
),
exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
+ exp.Min: lambda self, expr: self._annotate_by_args(expr, "this"),
+ exp.Max: lambda self, expr: self._annotate_by_args(expr, "this"),
+ exp.Sum: lambda self, expr: self._annotate_by_args(expr, "this", promote=True),
exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
@@ -263,10 +269,10 @@ class TypeAnnotator:
}
# First annotate the current scope's column references
for col in scope.columns:
- source = scope.sources[col.table]
+ source = scope.sources.get(col.table)
if isinstance(source, exp.Table):
col.type = self.schema.get_column_type(source, col)
- else:
+ elif source:
col.type = selects[col.table][col.name].type
# Then (possibly) annotate the remaining expressions in the scope
self._maybe_annotate(scope.expression)
@@ -280,6 +286,7 @@ class TypeAnnotator:
return expression # We've already inferred the expression's type
annotator = self.annotators.get(expression.__class__)
+
return (
annotator(self, expression)
if annotator
@@ -295,18 +302,23 @@ class TypeAnnotator:
def _maybe_coerce(self, type1, type2):
# We propagate the NULL / UNKNOWN types upwards if found
+ if isinstance(type1, exp.DataType):
+ type1 = type1.this
+ if isinstance(type2, exp.DataType):
+ type2 = type2.this
+
if exp.DataType.Type.NULL in (type1, type2):
return exp.DataType.Type.NULL
if exp.DataType.Type.UNKNOWN in (type1, type2):
return exp.DataType.Type.UNKNOWN
- return type2 if type2 in self.coerces_to[type1] else type1
+ return type2 if type2 in self.coerces_to.get(type1, {}) else type1
def _annotate_binary(self, expression):
self._annotate_args(expression)
- left_type = expression.left.type
- right_type = expression.right.type
+ left_type = expression.left.type.this
+ right_type = expression.right.type.this
if isinstance(expression, (exp.And, exp.Or)):
if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
@@ -348,7 +360,7 @@ class TypeAnnotator:
expression.type = target_type
return self._annotate_args(expression)
- def _annotate_by_args(self, expression, *args):
+ def _annotate_by_args(self, expression, *args, promote=False):
self._annotate_args(expression)
expressions = []
for arg in args:
@@ -360,4 +372,11 @@ class TypeAnnotator:
last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)
expression.type = last_datatype or exp.DataType.Type.UNKNOWN
+
+ if promote:
+ if expression.type.this in exp.DataType.INTEGER_TYPES:
+ expression.type = exp.DataType.Type.BIGINT
+ elif expression.type.this in exp.DataType.FLOAT_TYPES:
+ expression.type = exp.DataType.Type.DOUBLE
+
return expression
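Editor's note: with these changes an annotated expression carries a full exp.DataType node rather than a bare Type enum, so callers read the enum through .type.this, as the updated doctest above shows. A small sketch reusing the schema from that doctest:

    import sqlglot
    from sqlglot.optimizer.annotate_types import annotate_types

    schema = {"y": {"cola": "SMALLINT"}}
    sql = "SELECT x.cola + 2.5 AS cola FROM (SELECT y.cola AS cola FROM y AS y) AS x"
    annotated = annotate_types(sqlglot.parse_one(sql), schema=schema)

    # .type is now an exp.DataType; the underlying Type enum lives on .this
    assert annotated.expressions[0].type.this == sqlglot.exp.DataType.Type.DOUBLE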
diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py
index 9b3d98a..33529a5 100644
--- a/sqlglot/optimizer/canonicalize.py
+++ b/sqlglot/optimizer/canonicalize.py
@@ -13,13 +13,16 @@ def canonicalize(expression: exp.Expression) -> exp.Expression:
expression: The expression to canonicalize.
"""
exp.replace_children(expression, canonicalize)
+
expression = add_text_to_concat(expression)
expression = coerce_type(expression)
+ expression = remove_redundant_casts(expression)
+
return expression
def add_text_to_concat(node: exp.Expression) -> exp.Expression:
- if isinstance(node, exp.Add) and node.type in exp.DataType.TEXT_TYPES:
+ if isinstance(node, exp.Add) and node.type and node.type.this in exp.DataType.TEXT_TYPES:
node = exp.Concat(this=node.this, expression=node.expression)
return node
@@ -30,14 +33,30 @@ def coerce_type(node: exp.Expression) -> exp.Expression:
elif isinstance(node, exp.Between):
_coerce_date(node.this, node.args["low"])
elif isinstance(node, exp.Extract):
- if node.expression.type not in exp.DataType.TEMPORAL_TYPES:
+ if node.expression.type.this not in exp.DataType.TEMPORAL_TYPES:
_replace_cast(node.expression, "datetime")
return node
+def remove_redundant_casts(expression: exp.Expression) -> exp.Expression:
+ if (
+ isinstance(expression, exp.Cast)
+ and expression.to.type
+ and expression.this.type
+ and expression.to.type.this == expression.this.type.this
+ ):
+ return expression.this
+ return expression
+
+
def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
for a, b in itertools.permutations([a, b]):
- if a.type == exp.DataType.Type.DATE and b.type != exp.DataType.Type.DATE:
+ if (
+ a.type
+ and a.type.this == exp.DataType.Type.DATE
+ and b.type
+ and b.type.this != exp.DataType.Type.DATE
+ ):
_replace_cast(b, "date")
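Editor's note: remove_redundant_casts relies on the DataType annotations above: once a column is known to already have the target type, the enclosing CAST can be dropped. A sketch under an assumed one-table schema (table x and column a are hypothetical):

    from sqlglot import parse_one
    from sqlglot.optimizer.annotate_types import annotate_types
    from sqlglot.optimizer.canonicalize import canonicalize

    # x.a is declared as DATE, so CAST(x.a AS DATE) is a no-op
    expr = annotate_types(
        parse_one("SELECT CAST(x.a AS DATE) FROM x"),
        schema={"x": {"a": "DATE"}},
    )
    print(canonicalize(expr).sql())  # the redundant cast should disappear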
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index c432c59..c0719f2 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -7,7 +7,7 @@ from decimal import Decimal
from sqlglot import exp
from sqlglot.expressions import FALSE, NULL, TRUE
from sqlglot.generator import Generator
-from sqlglot.helper import while_changing
+from sqlglot.helper import first, while_changing
GENERATOR = Generator(normalize=True, identify=True)
@@ -30,6 +30,7 @@ def simplify(expression):
def _simplify(expression, root=True):
node = expression
+ node = rewrite_between(node)
node = uniq_sort(node)
node = absorb_and_eliminate(node)
exp.replace_children(node, lambda e: _simplify(e, False))
@@ -49,6 +50,19 @@ def simplify(expression):
return expression
+def rewrite_between(expression: exp.Expression) -> exp.Expression:
+ """Rewrite x between y and z to x >= y AND x <= z.
+
+ This is done because comparison simplification is only done on lt/lte/gt/gte.
+ """
+ if isinstance(expression, exp.Between):
+ return exp.and_(
+ exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
+ exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
+ )
+ return expression
+
+
def simplify_not(expression):
"""
Demorgan's Law
@@ -57,7 +71,7 @@ def simplify_not(expression):
"""
if isinstance(expression, exp.Not):
if isinstance(expression.this, exp.Null):
- return NULL
+ return exp.null()
if isinstance(expression.this, exp.Paren):
condition = expression.this.unnest()
if isinstance(condition, exp.And):
@@ -65,11 +79,11 @@ def simplify_not(expression):
if isinstance(condition, exp.Or):
return exp.and_(exp.not_(condition.left), exp.not_(condition.right))
if isinstance(condition, exp.Null):
- return NULL
+ return exp.null()
if always_true(expression.this):
- return FALSE
+ return exp.false()
if expression.this == FALSE:
- return TRUE
+ return exp.true()
if isinstance(expression.this, exp.Not):
# double negation
# NOT NOT x -> x
@@ -91,40 +105,119 @@ def flatten(expression):
def simplify_connectors(expression):
- if isinstance(expression, exp.Connector):
- left = expression.left
- right = expression.right
-
- if left == right:
- return left
-
- if isinstance(expression, exp.And):
- if FALSE in (left, right):
- return FALSE
- if NULL in (left, right):
- return NULL
- if always_true(left) and always_true(right):
- return TRUE
- if always_true(left):
- return right
- if always_true(right):
- return left
- elif isinstance(expression, exp.Or):
- if always_true(left) or always_true(right):
- return TRUE
- if left == FALSE and right == FALSE:
- return FALSE
- if (
- (left == NULL and right == NULL)
- or (left == NULL and right == FALSE)
- or (left == FALSE and right == NULL)
- ):
- return NULL
- if left == FALSE:
- return right
- if right == FALSE:
+ def _simplify_connectors(expression, left, right):
+ if isinstance(expression, exp.Connector):
+ if left == right:
return left
- return expression
+ if isinstance(expression, exp.And):
+ if FALSE in (left, right):
+ return exp.false()
+ if NULL in (left, right):
+ return exp.null()
+ if always_true(left) and always_true(right):
+ return exp.true()
+ if always_true(left):
+ return right
+ if always_true(right):
+ return left
+ return _simplify_comparison(expression, left, right)
+ elif isinstance(expression, exp.Or):
+ if always_true(left) or always_true(right):
+ return exp.true()
+ if left == FALSE and right == FALSE:
+ return exp.false()
+ if (
+ (left == NULL and right == NULL)
+ or (left == NULL and right == FALSE)
+ or (left == FALSE and right == NULL)
+ ):
+ return exp.null()
+ if left == FALSE:
+ return right
+ if right == FALSE:
+ return left
+ return _simplify_comparison(expression, left, right, or_=True)
+ return None
+
+ return _flat_simplify(expression, _simplify_connectors)
+
+
+LT_LTE = (exp.LT, exp.LTE)
+GT_GTE = (exp.GT, exp.GTE)
+
+COMPARISONS = (
+ *LT_LTE,
+ *GT_GTE,
+ exp.EQ,
+ exp.NEQ,
+)
+
+INVERSE_COMPARISONS = {
+ exp.LT: exp.GT,
+ exp.GT: exp.LT,
+ exp.LTE: exp.GTE,
+ exp.GTE: exp.LTE,
+}
+
+
+def _simplify_comparison(expression, left, right, or_=False):
+ if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
+ ll, lr = left.args.values()
+ rl, rr = right.args.values()
+
+ largs = {ll, lr}
+ rargs = {rl, rr}
+
+ matching = largs & rargs
+ columns = {m for m in matching if isinstance(m, exp.Column)}
+
+ if matching and columns:
+ try:
+ l = first(largs - columns)
+ r = first(rargs - columns)
+ except StopIteration:
+ return expression
+
+ # make sure the comparison is always of the form x > 1 instead of 1 < x
+ if left.__class__ in INVERSE_COMPARISONS and l == ll:
+ left = INVERSE_COMPARISONS[left.__class__](this=lr, expression=ll)
+ if right.__class__ in INVERSE_COMPARISONS and r == rl:
+ right = INVERSE_COMPARISONS[right.__class__](this=rr, expression=rl)
+
+ if l.is_number and r.is_number:
+ l = float(l.name)
+ r = float(r.name)
+ elif l.is_string and r.is_string:
+ l = l.name
+ r = r.name
+ else:
+ return None
+
+ for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
+ if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
+ return left if (av > bv if or_ else av <= bv) else right
+ if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
+ return left if (av < bv if or_ else av >= bv) else right
+
+ # we can't ever shortcut to true because the column could be null
+ if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
+ if not or_ and av <= bv:
+ return exp.false()
+ elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
+ if not or_ and av >= bv:
+ return exp.false()
+ elif isinstance(a, exp.EQ):
+ if isinstance(b, exp.LT):
+ return exp.false() if av >= bv else a
+ if isinstance(b, exp.LTE):
+ return exp.false() if av > bv else a
+ if isinstance(b, exp.GT):
+ return exp.false() if av <= bv else a
+ if isinstance(b, exp.GTE):
+ return exp.false() if av < bv else a
+ if isinstance(b, exp.NEQ):
+ return exp.false() if av == bv else a
+ return None
def remove_compliments(expression):
@@ -135,7 +228,7 @@ def remove_compliments(expression):
A OR NOT A -> TRUE
"""
if isinstance(expression, exp.Connector):
- compliment = FALSE if isinstance(expression, exp.And) else TRUE
+ compliment = exp.false() if isinstance(expression, exp.And) else exp.true()
for a, b in itertools.permutations(expression.flatten(), 2):
if is_complement(a, b):
@@ -211,27 +304,7 @@ def absorb_and_eliminate(expression):
def simplify_literals(expression):
if isinstance(expression, exp.Binary):
- operands = []
- queue = deque(expression.flatten(unnest=False))
- size = len(queue)
-
- while queue:
- a = queue.popleft()
-
- for b in queue:
- result = _simplify_binary(expression, a, b)
-
- if result:
- queue.remove(b)
- queue.append(result)
- break
- else:
- operands.append(a)
-
- if len(operands) < size:
- return functools.reduce(
- lambda a, b: expression.__class__(this=a, expression=b), operands
- )
+ return _flat_simplify(expression, _simplify_binary)
elif isinstance(expression, exp.Neg):
this = expression.this
if this.is_number:
@@ -254,20 +327,13 @@ def _simplify_binary(expression, a, b):
if c == NULL:
if isinstance(a, exp.Literal):
- return TRUE if not_ else FALSE
+ return exp.true() if not_ else exp.false()
if a == NULL:
- return FALSE if not_ else TRUE
- elif isinstance(expression, exp.NullSafeEQ):
- if a == b:
- return TRUE
- elif isinstance(expression, exp.NullSafeNEQ):
- if a == b:
- return FALSE
+ return exp.false() if not_ else exp.true()
+ elif isinstance(expression, (exp.NullSafeEQ, exp.NullSafeNEQ)):
+ return None
elif NULL in (a, b):
- return NULL
-
- if isinstance(expression, exp.EQ) and a == b:
- return TRUE
+ return exp.null()
if a.is_number and b.is_number:
a = int(a.name) if a.is_int else Decimal(a.name)
@@ -388,4 +454,27 @@ def date_literal(date):
def boolean_literal(condition):
- return TRUE if condition else FALSE
+ return exp.true() if condition else exp.false()
+
+
+def _flat_simplify(expression, simplifier):
+ operands = []
+ queue = deque(expression.flatten(unnest=False))
+ size = len(queue)
+
+ while queue:
+ a = queue.popleft()
+
+ for b in queue:
+ result = simplifier(expression, a, b)
+
+ if result:
+ queue.remove(b)
+ queue.append(result)
+ break
+ else:
+ operands.append(a)
+
+ if len(operands) < size:
+ return functools.reduce(lambda a, b: expression.__class__(this=a, expression=b), operands)
+ return expression
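Editor's note: taken together, rewrite_between, _simplify_comparison and the shared _flat_simplify loop let simplify() collapse chains of range predicates over the same column. An illustrative run (exact operand ordering in the output may vary):

    from sqlglot import parse_one
    from sqlglot.optimizer.simplify import simplify

    # two overlapping comparisons on the same column fold into the tighter one
    print(simplify(parse_one("x > 1 AND x > 3")).sql())  # x > 3

    # BETWEEN is first rewritten into >= / <= so the same machinery applies
    print(simplify(parse_one("x BETWEEN 1 AND 10")).sql())  # x >= 1 AND x <= 10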
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index bdf0d2d..55ab453 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -185,6 +185,7 @@ class Parser(metaclass=_Parser):
TokenType.LOCAL,
TokenType.LOCATION,
TokenType.MATERIALIZED,
+ TokenType.MERGE,
TokenType.NATURAL,
TokenType.NEXT,
TokenType.ONLY,
@@ -211,7 +212,6 @@ class Parser(metaclass=_Parser):
TokenType.TABLE,
TokenType.TABLE_FORMAT,
TokenType.TEMPORARY,
- TokenType.TRANSIENT,
TokenType.TOP,
TokenType.TRAILING,
TokenType.TRUE,
@@ -229,6 +229,8 @@ class Parser(metaclass=_Parser):
TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.NATURAL, TokenType.APPLY}
+ UPDATE_ALIAS_TOKENS = TABLE_ALIAS_TOKENS - {TokenType.SET}
+
TRIM_TYPES = {TokenType.LEADING, TokenType.TRAILING, TokenType.BOTH}
FUNC_TOKENS = {
@@ -241,6 +243,7 @@ class Parser(metaclass=_Parser):
TokenType.FORMAT,
TokenType.IDENTIFIER,
TokenType.ISNULL,
+ TokenType.MERGE,
TokenType.OFFSET,
TokenType.PRIMARY_KEY,
TokenType.REPLACE,
@@ -407,6 +410,7 @@ class Parser(metaclass=_Parser):
TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(),
TokenType.END: lambda self: self._parse_commit_or_rollback(),
TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(),
+ TokenType.MERGE: lambda self: self._parse_merge(),
}
UNARY_PARSERS = {
@@ -474,6 +478,7 @@ class Parser(metaclass=_Parser):
TokenType.SORTKEY: lambda self: self._parse_sortkey(),
TokenType.LIKE: lambda self: self._parse_create_like(),
TokenType.RETURNS: lambda self: self._parse_returns(),
+ TokenType.ROW: lambda self: self._parse_row(),
TokenType.COLLATE: lambda self: self._parse_property_assignment(exp.CollateProperty),
TokenType.COMMENT: lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
TokenType.FORMAT: lambda self: self._parse_property_assignment(exp.FileFormatProperty),
@@ -495,6 +500,8 @@ class Parser(metaclass=_Parser):
TokenType.VOLATILE: lambda self: self.expression(
exp.VolatilityProperty, this=exp.Literal.string("VOLATILE")
),
+ TokenType.WITH: lambda self: self._parse_wrapped_csv(self._parse_property),
+ TokenType.PROPERTIES: lambda self: self._parse_wrapped_csv(self._parse_property),
}
CONSTRAINT_PARSERS = {
@@ -802,7 +809,8 @@ class Parser(metaclass=_Parser):
def _parse_create(self):
replace = self._match_pair(TokenType.OR, TokenType.REPLACE)
temporary = self._match(TokenType.TEMPORARY)
- transient = self._match(TokenType.TRANSIENT)
+ transient = self._match_text_seq("TRANSIENT")
+ external = self._match_text_seq("EXTERNAL")
unique = self._match(TokenType.UNIQUE)
materialized = self._match(TokenType.MATERIALIZED)
@@ -846,6 +854,7 @@ class Parser(metaclass=_Parser):
properties=properties,
temporary=temporary,
transient=transient,
+ external=external,
replace=replace,
unique=unique,
materialized=materialized,
@@ -861,8 +870,12 @@ class Parser(metaclass=_Parser):
if self._match_pair(TokenType.COMPOUND, TokenType.SORTKEY):
return self._parse_sortkey(compound=True)
- if self._match_pair(TokenType.VAR, TokenType.EQ, advance=False):
- key = self._parse_var()
+ assignment = self._match_pair(
+ TokenType.VAR, TokenType.EQ, advance=False
+ ) or self._match_pair(TokenType.STRING, TokenType.EQ, advance=False)
+
+ if assignment:
+ key = self._parse_var() or self._parse_string()
self._match(TokenType.EQ)
return self.expression(exp.Property, this=key, value=self._parse_column())
@@ -871,7 +884,10 @@ class Parser(metaclass=_Parser):
def _parse_property_assignment(self, exp_class):
self._match(TokenType.EQ)
self._match(TokenType.ALIAS)
- return self.expression(exp_class, this=self._parse_var_or_string() or self._parse_number())
+ return self.expression(
+ exp_class,
+ this=self._parse_var_or_string() or self._parse_number() or self._parse_id_var(),
+ )
def _parse_partitioned_by(self):
self._match(TokenType.EQ)
@@ -881,7 +897,7 @@ class Parser(metaclass=_Parser):
)
def _parse_distkey(self):
- return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_var))
+ return self.expression(exp.DistKeyProperty, this=self._parse_wrapped(self._parse_id_var))
def _parse_create_like(self):
table = self._parse_table(schema=True)
@@ -898,7 +914,7 @@ class Parser(metaclass=_Parser):
def _parse_sortkey(self, compound=False):
return self.expression(
- exp.SortKeyProperty, this=self._parse_wrapped_csv(self._parse_var), compound=compound
+ exp.SortKeyProperty, this=self._parse_wrapped_csv(self._parse_id_var), compound=compound
)
def _parse_character_set(self, default=False):
@@ -929,23 +945,11 @@ class Parser(metaclass=_Parser):
properties = []
while True:
- if self._match(TokenType.WITH):
- properties.extend(self._parse_wrapped_csv(self._parse_property))
- elif self._match(TokenType.PROPERTIES):
- properties.extend(
- self._parse_wrapped_csv(
- lambda: self.expression(
- exp.Property,
- this=self._parse_string(),
- value=self._match(TokenType.EQ) and self._parse_string(),
- )
- )
- )
- else:
- identified_property = self._parse_property()
- if not identified_property:
- break
- properties.append(identified_property)
+ identified_property = self._parse_property()
+ if not identified_property:
+ break
+ for p in ensure_collection(identified_property):
+ properties.append(p)
if properties:
return self.expression(exp.Properties, expressions=properties)
@@ -963,7 +967,7 @@ class Parser(metaclass=_Parser):
exp.Directory,
this=self._parse_var_or_string(),
local=local,
- row_format=self._parse_row_format(),
+ row_format=self._parse_row_format(match_row=True),
)
else:
self._match(TokenType.INTO)
@@ -978,10 +982,18 @@ class Parser(metaclass=_Parser):
overwrite=overwrite,
)
- def _parse_row_format(self):
- if not self._match_pair(TokenType.ROW, TokenType.FORMAT):
+ def _parse_row(self):
+ if not self._match(TokenType.FORMAT):
+ return None
+ return self._parse_row_format()
+
+ def _parse_row_format(self, match_row=False):
+ if match_row and not self._match_pair(TokenType.ROW, TokenType.FORMAT):
return None
+ if self._match_text_seq("SERDE"):
+ return self.expression(exp.RowFormatSerdeProperty, this=self._parse_string())
+
self._match_text_seq("DELIMITED")
kwargs = {}
@@ -998,7 +1010,7 @@ class Parser(metaclass=_Parser):
kwargs["lines"] = self._parse_string()
if self._match_text_seq("NULL", "DEFINED", "AS"):
kwargs["null"] = self._parse_string()
- return self.expression(exp.RowFormat, **kwargs)
+ return self.expression(exp.RowFormatDelimitedProperty, **kwargs)
def _parse_load_data(self):
local = self._match(TokenType.LOCAL)
@@ -1032,7 +1044,7 @@ class Parser(metaclass=_Parser):
return self.expression(
exp.Update,
**{
- "this": self._parse_table(schema=True),
+ "this": self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS),
"expressions": self._match(TokenType.SET) and self._parse_csv(self._parse_equality),
"from": self._parse_from(),
"where": self._parse_where(),
@@ -1183,9 +1195,11 @@ class Parser(metaclass=_Parser):
alias=alias,
)
- def _parse_table_alias(self):
+ def _parse_table_alias(self, alias_tokens=None):
any_token = self._match(TokenType.ALIAS)
- alias = self._parse_id_var(any_token=any_token, tokens=self.TABLE_ALIAS_TOKENS)
+ alias = self._parse_id_var(
+ any_token=any_token, tokens=alias_tokens or self.TABLE_ALIAS_TOKENS
+ )
columns = None
if self._match(TokenType.L_PAREN):
@@ -1337,7 +1351,7 @@ class Parser(metaclass=_Parser):
columns=self._parse_expression(),
)
- def _parse_table(self, schema=False):
+ def _parse_table(self, schema=False, alias_tokens=None):
lateral = self._parse_lateral()
if lateral:
@@ -1372,7 +1386,7 @@ class Parser(metaclass=_Parser):
table = self._parse_id_var()
if not table:
- self.raise_error("Expected table name")
+ self.raise_error(f"Expected table name but got {self._curr}")
this = self.expression(
exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots()
@@ -1384,7 +1398,7 @@ class Parser(metaclass=_Parser):
if self.alias_post_tablesample:
table_sample = self._parse_table_sample()
- alias = self._parse_table_alias()
+ alias = self._parse_table_alias(alias_tokens=alias_tokens or self.TABLE_ALIAS_TOKENS)
if alias:
this.set("alias", alias)
@@ -2092,10 +2106,14 @@ class Parser(metaclass=_Parser):
kind = self.expression(exp.CheckColumnConstraint, this=constraint)
elif self._match(TokenType.COLLATE):
kind = self.expression(exp.CollateColumnConstraint, this=self._parse_var())
+ elif self._match(TokenType.ENCODE):
+ kind = self.expression(exp.EncodeColumnConstraint, this=self._parse_var())
elif self._match(TokenType.DEFAULT):
kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_conjunction())
elif self._match_pair(TokenType.NOT, TokenType.NULL):
kind = exp.NotNullColumnConstraint()
+ elif self._match(TokenType.NULL):
+ kind = exp.NotNullColumnConstraint(allow_null=True)
elif self._match(TokenType.SCHEMA_COMMENT):
kind = self.expression(exp.CommentColumnConstraint, this=self._parse_string())
elif self._match(TokenType.PRIMARY_KEY):
@@ -2234,7 +2252,7 @@ class Parser(metaclass=_Parser):
return self._parse_window(this)
def _parse_extract(self):
- this = self._parse_var() or self._parse_type()
+ this = self._parse_function() or self._parse_var() or self._parse_type()
if self._match(TokenType.FROM):
return self.expression(exp.Extract, this=this, expression=self._parse_bitwise())
@@ -2635,6 +2653,54 @@ class Parser(metaclass=_Parser):
parser = self._find_parser(self.SET_PARSERS, self._set_trie)
return parser(self) if parser else self._default_parse_set_item()
+ def _parse_merge(self):
+ self._match(TokenType.INTO)
+ target = self._parse_table(schema=True)
+
+ self._match(TokenType.USING)
+ using = self._parse_table()
+
+ self._match(TokenType.ON)
+ on = self._parse_conjunction()
+
+ whens = []
+ while self._match(TokenType.WHEN):
+ this = self._parse_conjunction()
+ self._match(TokenType.THEN)
+
+ if self._match(TokenType.INSERT):
+ _this = self._parse_star()
+ if _this:
+ then = self.expression(exp.Insert, this=_this)
+ else:
+ then = self.expression(
+ exp.Insert,
+ this=self._parse_value(),
+ expression=self._match(TokenType.VALUES) and self._parse_value(),
+ )
+ elif self._match(TokenType.UPDATE):
+ expressions = self._parse_star()
+ if expressions:
+ then = self.expression(exp.Update, expressions=expressions)
+ else:
+ then = self.expression(
+ exp.Update,
+ expressions=self._match(TokenType.SET)
+ and self._parse_csv(self._parse_equality),
+ )
+ elif self._match(TokenType.DELETE):
+ then = self.expression(exp.Var, this=self._prev.text)
+
+ whens.append(self.expression(exp.When, this=this, then=then))
+
+ return self.expression(
+ exp.Merge,
+ this=target,
+ using=using,
+ on=on,
+ expressions=whens,
+ )
+
def _parse_set(self):
return self.expression(exp.Set, expressions=self._parse_csv(self._parse_set_item))
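Editor's note: the new UPDATE_ALIAS_TOKENS set (TABLE_ALIAS_TOKENS minus SET) feeds the _parse_update change above, so a bare alias after the UPDATE target can be consumed without swallowing the SET keyword. A hedged sketch with a made-up table:

    from sqlglot import parse_one

    # "o" should be picked up as the target's alias, leaving SET to start the assignments
    print(parse_one("UPDATE orders o SET status = 'done' WHERE id = 1").sql())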
diff --git a/sqlglot/schema.py b/sqlglot/schema.py
index f6f303b..8a264a2 100644
--- a/sqlglot/schema.py
+++ b/sqlglot/schema.py
@@ -47,7 +47,7 @@ class Schema(abc.ABC):
"""
@abc.abstractmethod
- def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType.Type:
+ def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType:
"""
Get the :class:`sqlglot.exp.DataType` type of a column in the schema.
@@ -160,8 +160,8 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
super().__init__(schema)
self.visible = visible or {}
self.dialect = dialect
- self._type_mapping_cache: t.Dict[str, exp.DataType.Type] = {
- "STR": exp.DataType.Type.TEXT,
+ self._type_mapping_cache: t.Dict[str, exp.DataType] = {
+ "STR": exp.DataType.build("text"),
}
@classmethod
@@ -231,18 +231,18 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
visible = self._nested_get(self.table_parts(table_), self.visible)
return [col for col in schema if col in visible] # type: ignore
- def get_column_type(
- self, table: exp.Table | str, column: exp.Column | str
- ) -> exp.DataType.Type:
+ def get_column_type(self, table: exp.Table | str, column: exp.Column | str) -> exp.DataType:
column_name = column if isinstance(column, str) else column.name
table_ = exp.to_table(table)
if table_:
- table_schema = self.find(table_)
- schema_type = table_schema.get(column_name).upper() # type: ignore
- return self._convert_type(schema_type)
+ table_schema = self.find(table_, raise_on_missing=False)
+ if table_schema:
+ schema_type = table_schema.get(column_name).upper() # type: ignore
+ return self._convert_type(schema_type)
+ return exp.DataType(this=exp.DataType.Type.UNKNOWN)
raise SchemaError(f"Could not convert table '{table}'")
- def _convert_type(self, schema_type: str) -> exp.DataType.Type:
+ def _convert_type(self, schema_type: str) -> exp.DataType:
"""
Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType` object.
@@ -257,7 +257,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
expression = exp.maybe_parse(schema_type, into=exp.DataType, dialect=self.dialect)
if expression is None:
raise ValueError(f"Could not parse {schema_type}")
- self._type_mapping_cache[schema_type] = expression.this
+ self._type_mapping_cache[schema_type] = expression # type: ignore
except AttributeError:
raise SchemaError(f"Failed to convert type {schema_type}")
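Editor's note: get_column_type now hands back an exp.DataType node (falling back to Type.UNKNOWN for tables the schema has never seen) rather than a raw enum value, matching what the type annotator stores on expressions. A small sketch with a hypothetical one-table schema:

    from sqlglot import exp
    from sqlglot.schema import MappingSchema

    schema = MappingSchema({"orders": {"id": "INT", "status": "TEXT"}})

    dtype = schema.get_column_type("orders", "id")
    assert isinstance(dtype, exp.DataType)
    assert dtype.this == exp.DataType.Type.INT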
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index 8a7a38e..b25ef8d 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -49,6 +49,9 @@ class TokenType(AutoName):
PARAMETER = auto()
SESSION_PARAMETER = auto()
+ BLOCK_START = auto()
+ BLOCK_END = auto()
+
SPACE = auto()
BREAK = auto()
@@ -156,6 +159,7 @@ class TokenType(AutoName):
DIV = auto()
DROP = auto()
ELSE = auto()
+ ENCODE = auto()
END = auto()
ENGINE = auto()
ESCAPE = auto()
@@ -207,6 +211,7 @@ class TokenType(AutoName):
LOCATION = auto()
MAP = auto()
MATERIALIZED = auto()
+ MERGE = auto()
MOD = auto()
NATURAL = auto()
NEXT = auto()
@@ -255,6 +260,7 @@ class TokenType(AutoName):
SELECT = auto()
SEMI = auto()
SEPARATOR = auto()
+ SERDE_PROPERTIES = auto()
SET = auto()
SHOW = auto()
SIMILAR_TO = auto()
@@ -267,7 +273,6 @@ class TokenType(AutoName):
TABLE_FORMAT = auto()
TABLE_SAMPLE = auto()
TEMPORARY = auto()
- TRANSIENT = auto()
TOP = auto()
THEN = auto()
TRAILING = auto()
@@ -420,6 +425,16 @@ class Tokenizer(metaclass=_Tokenizer):
ESCAPES = ["'"]
KEYWORDS = {
+ **{
+ f"{key}{postfix}": TokenType.BLOCK_START
+ for key in ("{{", "{%", "{#")
+ for postfix in ("", "+", "-")
+ },
+ **{
+ f"{prefix}{key}": TokenType.BLOCK_END
+ for key in ("}}", "%}", "#}")
+ for prefix in ("", "+", "-")
+ },
"/*+": TokenType.HINT,
"==": TokenType.EQ,
"::": TokenType.DCOLON,
@@ -523,6 +538,7 @@ class Tokenizer(metaclass=_Tokenizer):
"LOCAL": TokenType.LOCAL,
"LOCATION": TokenType.LOCATION,
"MATERIALIZED": TokenType.MATERIALIZED,
+ "MERGE": TokenType.MERGE,
"NATURAL": TokenType.NATURAL,
"NEXT": TokenType.NEXT,
"NO ACTION": TokenType.NO_ACTION,
@@ -582,7 +598,6 @@ class Tokenizer(metaclass=_Tokenizer):
"TABLESAMPLE": TokenType.TABLE_SAMPLE,
"TEMP": TokenType.TEMPORARY,
"TEMPORARY": TokenType.TEMPORARY,
- "TRANSIENT": TokenType.TRANSIENT,
"THEN": TokenType.THEN,
"TRUE": TokenType.TRUE,
"TRAILING": TokenType.TRAILING,