summary | refs | log | tree | commit | diff | stats
path: root/sqlglot
diff options
context:
space:
mode:
author: Daniel Baumann <daniel.baumann@progress-linux.org> 2023-03-07 18:09:31 +0000
committer: Daniel Baumann <daniel.baumann@progress-linux.org> 2023-03-07 18:09:31 +0000
commit: ebec59cc5cb6c6856705bf82ced7fe8d9f75b0d0 (patch)
tree: eacad0719c5f2d113f221000ec126226f0d7fc9e /sqlglot
parent: Releasing debian version 11.2.3-1. (diff)
download: sqlglot-ebec59cc5cb6c6856705bf82ced7fe8d9f75b0d0.tar.xz
download: sqlglot-ebec59cc5cb6c6856705bf82ced7fe8d9f75b0d0.zip
Merging upstream version 11.3.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
-rw-r--r-- sqlglot/__init__.py 2
-rw-r--r-- sqlglot/dataframe/sql/column.py 8
-rw-r--r-- sqlglot/dataframe/sql/dataframe.py 2
-rw-r--r-- sqlglot/dataframe/sql/functions.py 8
-rw-r--r-- sqlglot/dialects/bigquery.py 4
-rw-r--r-- sqlglot/dialects/clickhouse.py 3
-rw-r--r-- sqlglot/dialects/dialect.py 5
-rw-r--r-- sqlglot/dialects/duckdb.py 6
-rw-r--r-- sqlglot/dialects/hive.py 4
-rw-r--r-- sqlglot/dialects/mysql.py 3
-rw-r--r-- sqlglot/dialects/oracle.py 17
-rw-r--r-- sqlglot/dialects/postgres.py 7
-rw-r--r-- sqlglot/dialects/snowflake.py 41
-rw-r--r-- sqlglot/dialects/teradata.py 19
-rw-r--r-- sqlglot/expressions.py 38
-rw-r--r-- sqlglot/generator.py 81
-rw-r--r-- sqlglot/optimizer/merge_subqueries.py 24
-rw-r--r-- sqlglot/optimizer/pushdown_projections.py 6
-rw-r--r-- sqlglot/parser.py 179
-rw-r--r-- sqlglot/tokens.py 8
20 files changed, 339 insertions, 126 deletions
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py
index 87b36b0..d026627 100644
--- a/sqlglot/__init__.py
+++ b/sqlglot/__init__.py
@@ -47,7 +47,7 @@ if t.TYPE_CHECKING:
T = t.TypeVar("T", bound=Expression)
-__version__ = "11.2.3"
+__version__ = "11.3.0"
pretty = False
"""Whether to format generated SQL by default."""
diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py
index 609b2a4..f45d467 100644
--- a/sqlglot/dataframe/sql/column.py
+++ b/sqlglot/dataframe/sql/column.py
@@ -67,10 +67,10 @@ class Column:
return self.binary_op(exp.Mul, other)
def __truediv__(self, other: ColumnOrLiteral) -> Column:
- return self.binary_op(exp.Div, other)
+ return self.binary_op(exp.FloatDiv, other)
def __div__(self, other: ColumnOrLiteral) -> Column:
- return self.binary_op(exp.Div, other)
+ return self.binary_op(exp.FloatDiv, other)
def __neg__(self) -> Column:
return self.unary_op(exp.Neg)
@@ -85,10 +85,10 @@ class Column:
return self.inverse_binary_op(exp.Mul, other)
def __rdiv__(self, other: ColumnOrLiteral) -> Column:
- return self.inverse_binary_op(exp.Div, other)
+ return self.inverse_binary_op(exp.FloatDiv, other)
def __rtruediv__(self, other: ColumnOrLiteral) -> Column:
- return self.inverse_binary_op(exp.Div, other)
+ return self.inverse_binary_op(exp.FloatDiv, other)
def __rmod__(self, other: ColumnOrLiteral) -> Column:
return self.inverse_binary_op(exp.Mod, other)
diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py
index 93ca45a..32ee927 100644
--- a/sqlglot/dataframe/sql/dataframe.py
+++ b/sqlglot/dataframe/sql/dataframe.py
@@ -260,7 +260,7 @@ class DataFrame:
@classmethod
def _get_outer_select_columns(cls, item: t.Union[exp.Expression, DataFrame]) -> t.List[Column]:
expression = item.expression if isinstance(item, DataFrame) else item
- return [Column(x) for x in expression.find(exp.Select).expressions]
+ return [Column(x) for x in (expression.find(exp.Select) or exp.Select()).expressions]
@classmethod
def _create_hash_from_expression(cls, expression: exp.Select):
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py
index 8f24746..3c98f42 100644
--- a/sqlglot/dataframe/sql/functions.py
+++ b/sqlglot/dataframe/sql/functions.py
@@ -954,10 +954,12 @@ def array_join(
col: ColumnOrName, delimiter: str, null_replacement: t.Optional[str] = None
) -> Column:
if null_replacement is not None:
- return Column.invoke_anonymous_function(
- col, "ARRAY_JOIN", lit(delimiter), lit(null_replacement)
+ return Column.invoke_expression_over_column(
+ col, expression.ArrayJoin, expression=lit(delimiter), null=lit(null_replacement)
)
- return Column.invoke_anonymous_function(col, "ARRAY_JOIN", lit(delimiter))
+ return Column.invoke_expression_over_column(
+ col, expression.ArrayJoin, expression=lit(delimiter)
+ )
def concat(*cols: ColumnOrName) -> Column:
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 32b5147..a3869c6 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -213,7 +213,11 @@ class BigQuery(Dialect):
),
}
+ INTEGER_DIVISION = False
+
class Generator(generator.Generator):
+ INTEGER_DIVISION = False
+
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.REMOVE_PRECISION_PARAMETERIZED_TYPES, # type: ignore
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index b553df2..a78d4db 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -56,6 +56,8 @@ class ClickHouse(Dialect):
TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY} # type: ignore
+ INTEGER_DIVISION = False
+
def _parse_in(
self, this: t.Optional[exp.Expression], is_global: bool = False
) -> exp.Expression:
@@ -94,6 +96,7 @@ class ClickHouse(Dialect):
class Generator(generator.Generator):
STRUCT_DELIMITER = ("(", ")")
+ INTEGER_DIVISION = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index af36256..6939705 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -360,10 +360,9 @@ def create_with_partitions_sql(self: Generator, expression: exp.Create) -> str:
if has_schema and is_partitionable:
expression = expression.copy()
prop = expression.find(exp.PartitionedByProperty)
- this = prop and prop.this
- if prop and not isinstance(this, exp.Schema):
+ if prop and prop.this and not isinstance(prop.this, exp.Schema):
schema = expression.this
- columns = {v.name.upper() for v in this.expressions}
+ columns = {v.name.upper() for v in prop.this.expressions}
partitions = [col for col in schema.expressions if col.name.upper() in columns]
schema.set("expressions", [e for e in schema.expressions if e not in partitions])
prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index 6144101..c2755cd 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -83,6 +83,7 @@ class DuckDB(Dialect):
":=": TokenType.EQ,
"ATTACH": TokenType.COMMAND,
"CHARACTER VARYING": TokenType.VARCHAR,
+ "EXCLUDE": TokenType.EXCEPT,
}
class Parser(parser.Parser):
@@ -173,3 +174,8 @@ class DuckDB(Dialect):
exp.DataType.Type.VARCHAR: "TEXT",
exp.DataType.Type.NVARCHAR: "TEXT",
}
+
+ STAR_MAPPING = {
+ **generator.Generator.STAR_MAPPING,
+ "except": "EXCLUDE",
+ }
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index ea1191e..44cd875 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -256,7 +256,11 @@ class Hive(Dialect):
),
}
+ INTEGER_DIVISION = False
+
class Generator(generator.Generator):
+ INTEGER_DIVISION = False
+
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TEXT: "STRING",
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 836bf3c..b1e20bd 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -300,6 +300,8 @@ class MySQL(Dialect):
"READ ONLY",
}
+ INTEGER_DIVISION = False
+
def _parse_show_mysql(self, this, target=False, full=None, global_=None):
if target:
if isinstance(target, str):
@@ -432,6 +434,7 @@ class MySQL(Dialect):
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
NULL_ORDERING_SUPPORTED = False
+ INTEGER_DIVISION = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index 74baa8a..795bbeb 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -82,8 +82,17 @@ class Oracle(Dialect):
"XMLTABLE": _parse_xml_table,
}
+ INTEGER_DIVISION = False
+
+ def _parse_column(self) -> t.Optional[exp.Expression]:
+ column = super()._parse_column()
+ if column:
+ column.set("join_mark", self._match(TokenType.JOIN_MARKER))
+ return column
+
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
+ INTEGER_DIVISION = False
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
@@ -108,6 +117,8 @@ class Oracle(Dialect):
exp.Trim: trim_sql,
exp.Matches: rename_func("DECODE"),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.Subquery: lambda self, e: self.subquery_sql(e, sep=" "),
+ exp.Table: lambda self, e: self.table_sql(e, sep=" "),
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
exp.UnixToTime: lambda self, e: f"TO_DATE('1970-01-01','YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
exp.Substring: rename_func("SUBSTR"),
@@ -139,8 +150,9 @@ class Oracle(Dialect):
def offset_sql(self, expression: exp.Offset) -> str:
return f"{super().offset_sql(expression)} ROWS"
- def table_sql(self, expression: exp.Table, sep: str = " ") -> str:
- return super().table_sql(expression, sep=sep)
+ def column_sql(self, expression: exp.Column) -> str:
+ column = super().column_sql(expression)
+ return f"{column} (+)" if expression.args.get("join_mark") else column
def xmltable_sql(self, expression: exp.XMLTable) -> str:
this = self.sql(expression, "this")
@@ -156,6 +168,7 @@ class Oracle(Dialect):
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
+ "(+)": TokenType.JOIN_MARKER,
"COLUMNS": TokenType.COLUMN,
"MATCH_RECOGNIZE": TokenType.MATCH_RECOGNIZE,
"MINUS": TokenType.EXCEPT,
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 3507cb5..35076db 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -222,10 +222,8 @@ class Postgres(Dialect):
"BEGIN TRANSACTION": TokenType.BEGIN,
"BIGSERIAL": TokenType.BIGSERIAL,
"CHARACTER VARYING": TokenType.VARCHAR,
- "COMMENT ON": TokenType.COMMAND,
"DECLARE": TokenType.COMMAND,
"DO": TokenType.COMMAND,
- "GRANT": TokenType.COMMAND,
"HSTORE": TokenType.HSTORE,
"JSONB": TokenType.JSONB,
"REFRESH": TokenType.COMMAND,
@@ -260,10 +258,7 @@ class Postgres(Dialect):
TokenType.HASH: exp.BitwiseXor,
}
- FACTOR = {
- **parser.Parser.FACTOR, # type: ignore
- TokenType.CARET: exp.Pow,
- }
+ FACTOR = {**parser.Parser.FACTOR, TokenType.CARET: exp.Pow}
class Generator(generator.Generator):
LOCKING_READS_SUPPORTED = True
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 5931364..4a090c2 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -1,5 +1,7 @@
from __future__ import annotations
+import typing as t
+
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
@@ -104,6 +106,20 @@ def _parse_date_part(self):
return self.expression(exp.Extract, this=this, expression=expression)
+# https://docs.snowflake.com/en/sql-reference/functions/div0
+def _div0_to_if(args):
+ cond = exp.EQ(this=seq_get(args, 1), expression=exp.Literal.number(0))
+ true = exp.Literal.number(0)
+ false = exp.FloatDiv(this=seq_get(args, 0), expression=seq_get(args, 1))
+ return exp.If(this=cond, true=true, false=false)
+
+
+# https://docs.snowflake.com/en/sql-reference/functions/zeroifnull
+def _zeroifnull_to_if(args):
+ cond = exp.EQ(this=seq_get(args, 0), expression=exp.Null())
+ return exp.If(this=cond, true=exp.Literal.number(0), false=seq_get(args, 0))
+
+
def _datatype_sql(self, expression):
if expression.this == exp.DataType.Type.ARRAY:
return "ARRAY"
@@ -150,16 +166,20 @@ class Snowflake(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"ARRAYAGG": exp.ArrayAgg.from_arg_list,
+ "ARRAY_TO_STRING": exp.ArrayJoin.from_arg_list,
"DATE_TRUNC": lambda args: exp.DateTrunc(
unit=exp.Literal.string(seq_get(args, 0).name), # type: ignore
this=seq_get(args, 1),
),
+ "DIV0": _div0_to_if,
"IFF": exp.If.from_arg_list,
+ "TO_ARRAY": exp.Array.from_arg_list,
"TO_TIMESTAMP": _snowflake_to_timestamp,
"ARRAY_CONSTRUCT": exp.Array.from_arg_list,
"RLIKE": exp.RegexpLike.from_arg_list,
"DECODE": exp.Matches.from_arg_list,
"OBJECT_CONSTRUCT": parser.parse_var_map,
+ "ZEROIFNULL": _zeroifnull_to_if,
}
FUNCTION_PARSERS = {
@@ -193,6 +213,19 @@ class Snowflake(Dialect):
),
}
+ ALTER_PARSERS = {
+ **parser.Parser.ALTER_PARSERS, # type: ignore
+ "UNSET": lambda self: self._parse_alter_table_set_tag(unset=True),
+ "SET": lambda self: self._parse_alter_table_set_tag(),
+ }
+
+ INTEGER_DIVISION = False
+
+ def _parse_alter_table_set_tag(self, unset: bool = False) -> exp.Expression:
+ self._match_text_seq("TAG")
+ parser = t.cast(t.Callable, self._parse_id_var if unset else self._parse_conjunction)
+ return self.expression(exp.SetTag, expressions=self._parse_csv(parser), unset=unset)
+
class Tokenizer(tokens.Tokenizer):
QUOTES = ["'", "$$"]
STRING_ESCAPES = ["\\", "'"]
@@ -220,12 +253,14 @@ class Snowflake(Dialect):
class Generator(generator.Generator):
PARAMETER_TOKEN = "$"
+ INTEGER_DIVISION = False
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
exp.Array: inline_array_sql,
exp.ArrayConcat: rename_func("ARRAY_CAT"),
- exp.DateAdd: rename_func("DATEADD"),
+ exp.ArrayJoin: rename_func("ARRAY_TO_STRING"),
+ exp.DateAdd: lambda self, e: self.func("DATEADD", e.text("unit"), e.expression, e.this),
exp.DateStrToDate: datestrtodate_sql,
exp.DataType: _datatype_sql,
exp.If: rename_func("IFF"),
@@ -294,6 +329,10 @@ class Snowflake(Dialect):
return self.no_identify(lambda: super(self.__class__, self).values_sql(expression))
return super().values_sql(expression)
+ def settag_sql(self, expression: exp.SetTag) -> str:
+ action = "UNSET" if expression.args.get("unset") else "SET"
+ return f"{action} TAG {self.expressions(expression)}"
+
def select_sql(self, expression: exp.Select) -> str:
"""Due to a bug in Snowflake we want to make sure that all columns in a VALUES table alias are unquoted and also
that all columns in a SELECT are unquoted. We also want to make sure that after we find matches where we need
diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py
index 7953bc5..415681c 100644
--- a/sqlglot/dialects/teradata.py
+++ b/sqlglot/dialects/teradata.py
@@ -74,6 +74,7 @@ class Teradata(Dialect):
FUNCTION_PARSERS = {
**parser.Parser.FUNCTION_PARSERS, # type: ignore
+ "RANGE_N": lambda self: self._parse_rangen(),
"TRANSLATE": lambda self: self._parse_translate(self.STRICT_CAST),
}
@@ -105,6 +106,15 @@ class Teradata(Dialect):
},
)
+ def _parse_rangen(self):
+ this = self._parse_id_var()
+ self._match(TokenType.BETWEEN)
+
+ expressions = self._parse_csv(self._parse_conjunction)
+ each = self._match_text_seq("EACH") and self._parse_conjunction()
+
+ return self.expression(exp.RangeN, this=this, expressions=expressions, each=each)
+
class Generator(generator.Generator):
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
@@ -114,7 +124,6 @@ class Teradata(Dialect):
PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION, # type: ignore
exp.PartitionedByProperty: exp.Properties.Location.POST_INDEX,
- exp.VolatilityProperty: exp.Properties.Location.POST_CREATE,
}
def partitionedbyproperty_sql(self, expression: exp.PartitionedByProperty) -> str:
@@ -137,3 +146,11 @@ class Teradata(Dialect):
type_sql = super().datatype_sql(expression)
prefix_sql = expression.args.get("prefix")
return f"SYSUDTLIB.{type_sql}" if prefix_sql else type_sql
+
+ def rangen_sql(self, expression: exp.RangeN) -> str:
+ this = self.sql(expression, "this")
+ expressions_sql = self.expressions(expression)
+ each_sql = self.sql(expression, "each")
+ each_sql = f" EACH {each_sql}" if each_sql else ""
+
+ return f"RANGE_N({this} BETWEEN {expressions_sql}{each_sql})"
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 59881d6..00a3b45 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -35,6 +35,8 @@ from sqlglot.tokens import Token
if t.TYPE_CHECKING:
from sqlglot.dialects.dialect import DialectType
+E = t.TypeVar("E", bound="Expression")
+
class _Expression(type):
def __new__(cls, clsname, bases, attrs):
@@ -293,7 +295,7 @@ class Expression(metaclass=_Expression):
return self.parent.depth + 1
return 0
- def find(self, *expression_types, bfs=True):
+ def find(self, *expression_types: t.Type[E], bfs=True) -> E | None:
"""
Returns the first node in this tree which matches at least one of
the specified types.
@@ -306,7 +308,7 @@ class Expression(metaclass=_Expression):
"""
return next(self.find_all(*expression_types, bfs=bfs), None)
- def find_all(self, *expression_types, bfs=True):
+ def find_all(self, *expression_types: t.Type[E], bfs=True) -> t.Iterator[E]:
"""
Returns a generator object which visits all nodes in this tree and only
yields those that match at least one of the specified expression types.
@@ -321,7 +323,7 @@ class Expression(metaclass=_Expression):
if isinstance(expression, expression_types):
yield expression
- def find_ancestor(self, *expression_types):
+ def find_ancestor(self, *expression_types: t.Type[E]) -> E | None:
"""
Returns a nearest parent matching expression_types.
@@ -334,7 +336,8 @@ class Expression(metaclass=_Expression):
ancestor = self.parent
while ancestor and not isinstance(ancestor, expression_types):
ancestor = ancestor.parent
- return ancestor
+ # ignore type because mypy doesn't know that we're checking type in the loop
+ return ancestor # type: ignore[return-value]
@property
def parent_select(self):
@@ -794,6 +797,7 @@ class Create(Expression):
"properties": False,
"replace": False,
"unique": False,
+ "volatile": False,
"indexes": False,
"no_schema_binding": False,
"begin": False,
@@ -883,7 +887,7 @@ class ByteString(Condition):
class Column(Condition):
- arg_types = {"this": True, "table": False, "db": False, "catalog": False}
+ arg_types = {"this": True, "table": False, "db": False, "catalog": False, "join_mark": False}
@property
def table(self) -> str:
@@ -926,6 +930,14 @@ class RenameTable(Expression):
pass
+class SetTag(Expression):
+ arg_types = {"expressions": True, "unset": False}
+
+
+class Comment(Expression):
+ arg_types = {"this": True, "kind": True, "expression": True, "exists": False}
+
+
class ColumnConstraint(Expression):
arg_types = {"this": False, "kind": True}
@@ -2829,6 +2841,14 @@ class Div(Binary):
pass
+class FloatDiv(Binary):
+ pass
+
+
+class Overlaps(Binary):
+ pass
+
+
class Dot(Binary):
@property
def name(self) -> str:
@@ -3125,6 +3145,10 @@ class ArrayFilter(Func):
_sql_names = ["FILTER", "ARRAY_FILTER"]
+class ArrayJoin(Func):
+ arg_types = {"this": True, "expression": True, "null": False}
+
+
class ArraySize(Func):
arg_types = {"this": True, "expression": False}
@@ -3510,6 +3534,10 @@ class ApproxQuantile(Quantile):
arg_types = {"this": True, "quantile": True, "accuracy": False, "weight": False}
+class RangeN(Func):
+ arg_types = {"this": True, "expressions": True, "each": False}
+
+
class ReadCSV(Func):
_sql_names = ["READ_CSV"]
is_var_len_args = True
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index 0a7a81f..79501ef 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -109,6 +109,9 @@ class Generator:
# Whether or not create function uses an AS before the RETURN
CREATE_FUNCTION_RETURN_AS = True
+ # Whether or not to treat the division operator "/" as integer division
+ INTEGER_DIVISION = True
+
TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
@@ -550,14 +553,17 @@ class Generator:
else:
expression_sql = f" AS{expression_sql}"
- replace = " OR REPLACE" if expression.args.get("replace") else ""
- unique = " UNIQUE" if expression.args.get("unique") else ""
- exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else ""
+ postindex_props_sql = ""
+ if properties_locs.get(exp.Properties.Location.POST_INDEX):
+ postindex_props_sql = self.properties(
+ exp.Properties(expressions=properties_locs[exp.Properties.Location.POST_INDEX]),
+ wrapped=False,
+ prefix=" ",
+ )
indexes = expression.args.get("indexes")
- index_sql = ""
if indexes:
- indexes_sql = []
+ indexes_sql: t.List[str] = []
for index in indexes:
ind_unique = " UNIQUE" if index.args.get("unique") else ""
ind_primary = " PRIMARY" if index.args.get("primary") else ""
@@ -568,21 +574,24 @@ class Generator:
if index.args.get("columns")
else ""
)
- if index.args.get("primary") and properties_locs.get(
- exp.Properties.Location.POST_INDEX
- ):
- postindex_props_sql = self.properties(
- exp.Properties(
- expressions=properties_locs[exp.Properties.Location.POST_INDEX]
- ),
- wrapped=False,
+ ind_sql = f"{ind_unique}{ind_primary}{ind_amp} INDEX{ind_name}{ind_columns}"
+
+ if indexes_sql:
+ indexes_sql.append(ind_sql)
+ else:
+ indexes_sql.append(
+ f"{ind_sql}{postindex_props_sql}"
+ if index.args.get("primary")
+ else f"{postindex_props_sql}{ind_sql}"
)
- ind_columns = f"{ind_columns} {postindex_props_sql}"
- indexes_sql.append(
- f"{ind_unique}{ind_primary}{ind_amp} INDEX{ind_name}{ind_columns}"
- )
index_sql = "".join(indexes_sql)
+ else:
+ index_sql = postindex_props_sql
+
+ replace = " OR REPLACE" if expression.args.get("replace") else ""
+ unique = " UNIQUE" if expression.args.get("unique") else ""
+ volatile = " VOLATILE" if expression.args.get("volatile") else ""
postcreate_props_sql = ""
if properties_locs.get(exp.Properties.Location.POST_CREATE):
@@ -593,7 +602,7 @@ class Generator:
wrapped=False,
)
- modifiers = "".join((replace, unique, postcreate_props_sql))
+ modifiers = "".join((replace, unique, volatile, postcreate_props_sql))
postexpression_props_sql = ""
if properties_locs.get(exp.Properties.Location.POST_EXPRESSION):
@@ -606,6 +615,7 @@ class Generator:
wrapped=False,
)
+ exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else ""
no_schema_binding = (
" WITH NO SCHEMA BINDING" if expression.args.get("no_schema_binding") else ""
)
@@ -1335,14 +1345,15 @@ class Generator:
def placeholder_sql(self, expression: exp.Placeholder) -> str:
return f":{expression.name}" if expression.name else "?"
- def subquery_sql(self, expression: exp.Subquery) -> str:
+ def subquery_sql(self, expression: exp.Subquery, sep: str = " AS ") -> str:
alias = self.sql(expression, "alias")
+ alias = f"{sep}{alias}" if alias else ""
sql = self.query_modifiers(
expression,
self.wrap(expression),
self.expressions(expression, key="pivots", sep=" "),
- f" AS {alias}" if alias else "",
+ alias,
)
return self.prepend_ctes(expression, sql)
@@ -1643,6 +1654,13 @@ class Generator:
def command_sql(self, expression: exp.Command) -> str:
return f"{self.sql(expression, 'this').upper()} {expression.text('expression').strip()}"
+ def comment_sql(self, expression: exp.Comment) -> str:
+ this = self.sql(expression, "this")
+ kind = expression.args["kind"]
+ exists_sql = " IF EXISTS " if expression.args.get("exists") else " "
+ expression_sql = self.sql(expression, "expression")
+ return f"COMMENT{exists_sql}ON {kind} {this} IS {expression_sql}"
+
def transaction_sql(self, *_) -> str:
return "BEGIN"
@@ -1728,19 +1746,30 @@ class Generator:
return f"{self.sql(expression, 'this')} RESPECT NULLS"
def intdiv_sql(self, expression: exp.IntDiv) -> str:
- return self.sql(
- exp.Cast(
- this=exp.Div(this=expression.this, expression=expression.expression),
- to=exp.DataType(this=exp.DataType.Type.INT),
- )
- )
+ div = self.binary(expression, "/")
+ return self.sql(exp.Cast(this=div, to=exp.DataType.build("INT")))
def dpipe_sql(self, expression: exp.DPipe) -> str:
return self.binary(expression, "||")
def div_sql(self, expression: exp.Div) -> str:
+ div = self.binary(expression, "/")
+
+ if not self.INTEGER_DIVISION:
+ return self.sql(exp.Cast(this=div, to=exp.DataType.build("INT")))
+
+ return div
+
+ def floatdiv_sql(self, expression: exp.FloatDiv) -> str:
+ if self.INTEGER_DIVISION:
+ this = exp.Cast(this=expression.this, to=exp.DataType.build("DOUBLE"))
+ return self.div_sql(exp.Div(this=this, expression=expression.expression))
+
return self.binary(expression, "/")
+ def overlaps_sql(self, expression: exp.Overlaps) -> str:
+ return self.binary(expression, "OVERLAPS")
+
def distance_sql(self, expression: exp.Distance) -> str:
return self.binary(expression, "<->")
diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py
index 16aaf17..70172f4 100644
--- a/sqlglot/optimizer/merge_subqueries.py
+++ b/sqlglot/optimizer/merge_subqueries.py
@@ -314,13 +314,27 @@ def _merge_where(outer_scope, inner_scope, from_or_join):
if not where or not where.this:
return
+ expression = outer_scope.expression
+
if isinstance(from_or_join, exp.Join):
# Merge predicates from an outer join to the ON clause
- from_or_join.on(where.this, copy=False)
- from_or_join.set("on", simplify(from_or_join.args.get("on")))
- else:
- outer_scope.expression.where(where.this, copy=False)
- outer_scope.expression.set("where", simplify(outer_scope.expression.args.get("where")))
+ # if it only has columns that are already joined
+ from_ = expression.args.get("from")
+ sources = {table.alias_or_name for table in from_.expressions} if from_ else {}
+
+ for join in expression.args["joins"]:
+ source = join.alias_or_name
+ sources.add(source)
+ if source == from_or_join.alias_or_name:
+ break
+
+ if set(exp.column_table_names(where.this)) <= sources:
+ from_or_join.on(where.this, copy=False)
+ from_or_join.set("on", simplify(from_or_join.args.get("on")))
+ return
+
+ expression.where(where.this, copy=False)
+ expression.set("where", simplify(expression.args.get("where")))
def _merge_order(outer_scope, inner_scope):
diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py
index 3f360f9..07a1b70 100644
--- a/sqlglot/optimizer/pushdown_projections.py
+++ b/sqlglot/optimizer/pushdown_projections.py
@@ -13,7 +13,7 @@ SELECT_ALL = object()
DEFAULT_SELECTION = lambda: alias("1", "_")
-def pushdown_projections(expression, schema=None):
+def pushdown_projections(expression, schema=None, remove_unused_selections=True):
"""
Rewrite sqlglot AST to remove unused columns projections.
@@ -26,6 +26,7 @@ def pushdown_projections(expression, schema=None):
Args:
expression (sqlglot.Expression): expression to optimize
+ remove_unused_selections (bool): remove selects that are unused
Returns:
sqlglot.Expression: optimized expression
"""
@@ -57,7 +58,8 @@ def pushdown_projections(expression, schema=None):
]
if isinstance(scope.expression, exp.Select):
- _remove_unused_selections(scope, parent_selections, schema)
+ if remove_unused_selections:
+ _remove_unused_selections(scope, parent_selections, schema)
# Group columns by source name
selects = defaultdict(set)
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index 9f32765..f39bb39 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -36,6 +36,10 @@ class _Parser(type):
klass = super().__new__(cls, clsname, bases, attrs)
klass._show_trie = new_trie(key.split(" ") for key in klass.SHOW_PARSERS)
klass._set_trie = new_trie(key.split(" ") for key in klass.SET_PARSERS)
+
+ if not klass.INTEGER_DIVISION:
+ klass.FACTOR = {**klass.FACTOR, TokenType.SLASH: exp.FloatDiv}
+
return klass
@@ -157,6 +161,21 @@ class Parser(metaclass=_Parser):
RESERVED_KEYWORDS = {*Tokenizer.SINGLE_TOKENS.values(), TokenType.SELECT}
+ DB_CREATABLES = {
+ TokenType.DATABASE,
+ TokenType.SCHEMA,
+ TokenType.TABLE,
+ TokenType.VIEW,
+ }
+
+ CREATABLES = {
+ TokenType.COLUMN,
+ TokenType.FUNCTION,
+ TokenType.INDEX,
+ TokenType.PROCEDURE,
+ *DB_CREATABLES,
+ }
+
ID_VAR_TOKENS = {
TokenType.VAR,
TokenType.ANTI,
@@ -168,8 +187,8 @@ class Parser(metaclass=_Parser):
TokenType.CACHE,
TokenType.CASCADE,
TokenType.COLLATE,
- TokenType.COLUMN,
TokenType.COMMAND,
+ TokenType.COMMENT,
TokenType.COMMIT,
TokenType.COMPOUND,
TokenType.CONSTRAINT,
@@ -186,9 +205,7 @@ class Parser(metaclass=_Parser):
TokenType.FILTER,
TokenType.FOLLOWING,
TokenType.FORMAT,
- TokenType.FUNCTION,
TokenType.IF,
- TokenType.INDEX,
TokenType.ISNULL,
TokenType.INTERVAL,
TokenType.LAZY,
@@ -211,13 +228,11 @@ class Parser(metaclass=_Parser):
TokenType.RIGHT,
TokenType.ROW,
TokenType.ROWS,
- TokenType.SCHEMA,
TokenType.SEED,
TokenType.SEMI,
TokenType.SET,
TokenType.SHOW,
TokenType.SORTKEY,
- TokenType.TABLE,
TokenType.TEMPORARY,
TokenType.TOP,
TokenType.TRAILING,
@@ -226,10 +241,9 @@ class Parser(metaclass=_Parser):
TokenType.UNIQUE,
TokenType.UNLOGGED,
TokenType.UNPIVOT,
- TokenType.PROCEDURE,
- TokenType.VIEW,
TokenType.VOLATILE,
TokenType.WINDOW,
+ *CREATABLES,
*SUBQUERY_PREDICATES,
*TYPE_TOKENS,
*NO_PAREN_FUNCTIONS,
@@ -428,6 +442,7 @@ class Parser(metaclass=_Parser):
TokenType.BEGIN: lambda self: self._parse_transaction(),
TokenType.CACHE: lambda self: self._parse_cache(),
TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(),
+ TokenType.COMMENT: lambda self: self._parse_comment(),
TokenType.CREATE: lambda self: self._parse_create(),
TokenType.DELETE: lambda self: self._parse_delete(),
TokenType.DESC: lambda self: self._parse_describe(),
@@ -490,6 +505,9 @@ class Parser(metaclass=_Parser):
TokenType.GLOB: lambda self, this: self._parse_escape(
self.expression(exp.Glob, this=this, expression=self._parse_bitwise())
),
+ TokenType.OVERLAPS: lambda self, this: self._parse_escape(
+ self.expression(exp.Overlaps, this=this, expression=self._parse_bitwise())
+ ),
TokenType.IN: lambda self, this: self._parse_in(this),
TokenType.IS: lambda self, this: self._parse_is(this),
TokenType.LIKE: lambda self, this: self._parse_escape(
@@ -628,6 +646,14 @@ class Parser(metaclass=_Parser):
"UPPERCASE": lambda self: self.expression(exp.UppercaseColumnConstraint),
}
+ ALTER_PARSERS = {
+ "ADD": lambda self: self._parse_alter_table_add(),
+ "ALTER": lambda self: self._parse_alter_table_alter(),
+ "DELETE": lambda self: self.expression(exp.Delete, where=self._parse_where()),
+ "DROP": lambda self: self._parse_alter_table_drop(),
+ "RENAME": lambda self: self._parse_alter_table_rename(),
+ }
+
SCHEMA_UNNAMED_CONSTRAINTS = {"CHECK", "FOREIGN KEY", "LIKE", "PRIMARY KEY", "UNIQUE"}
NO_PAREN_FUNCTION_PARSERS = {
@@ -669,16 +695,6 @@ class Parser(metaclass=_Parser):
MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table)
- CREATABLES = {
- TokenType.COLUMN,
- TokenType.FUNCTION,
- TokenType.INDEX,
- TokenType.PROCEDURE,
- TokenType.SCHEMA,
- TokenType.TABLE,
- TokenType.VIEW,
- }
-
TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"}
INSERT_ALTERNATIVES = {"ABORT", "FAIL", "IGNORE", "REPLACE", "ROLLBACK"}
@@ -689,6 +705,8 @@ class Parser(metaclass=_Parser):
STRICT_CAST = True
+ INTEGER_DIVISION = True
+
__slots__ = (
"error_level",
"error_message_context",
@@ -940,6 +958,32 @@ class Parser(metaclass=_Parser):
def _parse_command(self) -> exp.Expression:
return self.expression(exp.Command, this=self._prev.text, expression=self._parse_string())
+ def _parse_comment(self, allow_exists: bool = True) -> exp.Expression:
+ start = self._prev
+ exists = self._parse_exists() if allow_exists else None
+
+ self._match(TokenType.ON)
+
+ kind = self._match_set(self.CREATABLES) and self._prev
+
+ if not kind:
+ return self._parse_as_command(start)
+
+ if kind.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE):
+ this = self._parse_user_defined_function(kind=kind.token_type)
+ elif kind.token_type == TokenType.TABLE:
+ this = self._parse_table()
+ elif kind.token_type == TokenType.COLUMN:
+ this = self._parse_column()
+ else:
+ this = self._parse_id_var()
+
+ self._match(TokenType.IS)
+
+ return self.expression(
+ exp.Comment, this=this, kind=kind.text, expression=self._parse_string(), exists=exists
+ )
+
def _parse_statement(self) -> t.Optional[exp.Expression]:
if self._curr is None:
return None
@@ -990,6 +1034,7 @@ class Parser(metaclass=_Parser):
TokenType.OR, TokenType.REPLACE
)
unique = self._match(TokenType.UNIQUE)
+ volatile = self._match(TokenType.VOLATILE)
if self._match_pair(TokenType.TABLE, TokenType.FUNCTION, advance=False):
self._match(TokenType.TABLE)
@@ -1028,11 +1073,7 @@ class Parser(metaclass=_Parser):
expression = self.expression(exp.Return, this=expression)
elif create_token.token_type == TokenType.INDEX:
this = self._parse_index()
- elif create_token.token_type in (
- TokenType.TABLE,
- TokenType.VIEW,
- TokenType.SCHEMA,
- ):
+ elif create_token.token_type in self.DB_CREATABLES:
table_parts = self._parse_table_parts(schema=True)
# exp.Properties.Location.POST_NAME
@@ -1100,11 +1141,12 @@ class Parser(metaclass=_Parser):
exp.Create,
this=this,
kind=create_token.text,
+ replace=replace,
unique=unique,
+ volatile=volatile,
expression=expression,
exists=exists,
properties=properties,
- replace=replace,
indexes=indexes,
no_schema_binding=no_schema_binding,
begin=begin,
@@ -3648,6 +3690,47 @@ class Parser(metaclass=_Parser):
return self.expression(exp.AddConstraint, this=this, expression=expression)
+ def _parse_alter_table_add(self) -> t.List[t.Optional[exp.Expression]]:
+ index = self._index - 1
+
+ if self._match_set(self.ADD_CONSTRAINT_TOKENS):
+ return self._parse_csv(self._parse_add_constraint)
+
+ self._retreat(index)
+ return self._parse_csv(self._parse_add_column)
+
+ def _parse_alter_table_alter(self) -> exp.Expression:
+ self._match(TokenType.COLUMN)
+ column = self._parse_field(any_token=True)
+
+ if self._match_pair(TokenType.DROP, TokenType.DEFAULT):
+ return self.expression(exp.AlterColumn, this=column, drop=True)
+ if self._match_pair(TokenType.SET, TokenType.DEFAULT):
+ return self.expression(exp.AlterColumn, this=column, default=self._parse_conjunction())
+
+ self._match_text_seq("SET", "DATA")
+ return self.expression(
+ exp.AlterColumn,
+ this=column,
+ dtype=self._match_text_seq("TYPE") and self._parse_types(),
+ collate=self._match(TokenType.COLLATE) and self._parse_term(),
+ using=self._match(TokenType.USING) and self._parse_conjunction(),
+ )
+
+ def _parse_alter_table_drop(self) -> t.List[t.Optional[exp.Expression]]:
+ index = self._index - 1
+
+ partition_exists = self._parse_exists()
+ if self._match(TokenType.PARTITION, advance=False):
+ return self._parse_csv(lambda: self._parse_drop_partition(exists=partition_exists))
+
+ self._retreat(index)
+ return self._parse_csv(self._parse_drop_column)
+
+ def _parse_alter_table_rename(self) -> exp.Expression:
+ self._match_text_seq("TO")
+ return self.expression(exp.RenameTable, this=self._parse_table(schema=True))
+
def _parse_alter(self) -> t.Optional[exp.Expression]:
if not self._match(TokenType.TABLE):
return self._parse_as_command(self._prev)
@@ -3655,50 +3738,12 @@ class Parser(metaclass=_Parser):
exists = self._parse_exists()
this = self._parse_table(schema=True)
- actions: t.Optional[exp.Expression | t.List[t.Optional[exp.Expression]]] = None
-
- index = self._index
- if self._match(TokenType.DELETE):
- actions = [self.expression(exp.Delete, where=self._parse_where())]
- elif self._match_text_seq("ADD"):
- if self._match_set(self.ADD_CONSTRAINT_TOKENS):
- actions = self._parse_csv(self._parse_add_constraint)
- else:
- self._retreat(index)
- actions = self._parse_csv(self._parse_add_column)
- elif self._match_text_seq("DROP"):
- partition_exists = self._parse_exists()
+ if not self._curr:
+ return None
- if self._match(TokenType.PARTITION, advance=False):
- actions = self._parse_csv(
- lambda: self._parse_drop_partition(exists=partition_exists)
- )
- else:
- self._retreat(index)
- actions = self._parse_csv(self._parse_drop_column)
- elif self._match_text_seq("RENAME", "TO"):
- actions = self.expression(exp.RenameTable, this=self._parse_table(schema=True))
- elif self._match_text_seq("ALTER"):
- self._match(TokenType.COLUMN)
- column = self._parse_field(any_token=True)
-
- if self._match_pair(TokenType.DROP, TokenType.DEFAULT):
- actions = self.expression(exp.AlterColumn, this=column, drop=True)
- elif self._match_pair(TokenType.SET, TokenType.DEFAULT):
- actions = self.expression(
- exp.AlterColumn, this=column, default=self._parse_conjunction()
- )
- else:
- self._match_text_seq("SET", "DATA")
- actions = self.expression(
- exp.AlterColumn,
- this=column,
- dtype=self._match_text_seq("TYPE") and self._parse_types(),
- collate=self._match(TokenType.COLLATE) and self._parse_term(),
- using=self._match(TokenType.USING) and self._parse_conjunction(),
- )
+ parser = self.ALTER_PARSERS.get(self._curr.text.upper())
+ actions = ensure_list(self._advance() or parser(self)) if parser else [] # type: ignore
- actions = ensure_list(actions)
return self.expression(exp.AlterTable, this=this, exists=exists, actions=actions)
def _parse_show(self) -> t.Optional[exp.Expression]:
@@ -3772,7 +3817,9 @@ class Parser(metaclass=_Parser):
def _parse_as_command(self, start: Token) -> exp.Command:
while self._curr:
self._advance()
- return exp.Command(this=self._find_sql(start, self._prev))
+ text = self._find_sql(start, self._prev)
+ size = len(start.text)
+ return exp.Command(this=text[:size], expression=text[size:])
def _find_parser(
self, parsers: t.Dict[str, t.Callable], trie: t.Dict
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index f3f1a70..7a23803 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -60,6 +60,7 @@ class TokenType(AutoName):
STRING = auto()
NUMBER = auto()
IDENTIFIER = auto()
+ DATABASE = auto()
COLUMN = auto()
COLUMN_DEF = auto()
SCHEMA = auto()
@@ -203,6 +204,7 @@ class TokenType(AutoName):
IS = auto()
ISNULL = auto()
JOIN = auto()
+ JOIN_MARKER = auto()
LANGUAGE = auto()
LATERAL = auto()
LAZY = auto()
@@ -235,6 +237,7 @@ class TokenType(AutoName):
OUTER = auto()
OUT_OF = auto()
OVER = auto()
+ OVERLAPS = auto()
OVERWRITE = auto()
PARTITION = auto()
PARTITION_BY = auto()
@@ -491,6 +494,7 @@ class Tokenizer(metaclass=_Tokenizer):
"CURRENT_DATE": TokenType.CURRENT_DATE,
"CURRENT ROW": TokenType.CURRENT_ROW,
"CURRENT_TIMESTAMP": TokenType.CURRENT_TIMESTAMP,
+ "DATABASE": TokenType.DATABASE,
"DEFAULT": TokenType.DEFAULT,
"DELETE": TokenType.DELETE,
"DESC": TokenType.DESC,
@@ -564,6 +568,7 @@ class Tokenizer(metaclass=_Tokenizer):
"OUTER": TokenType.OUTER,
"OUT OF": TokenType.OUT_OF,
"OVER": TokenType.OVER,
+ "OVERLAPS": TokenType.OVERLAPS,
"OVERWRITE": TokenType.OVERWRITE,
"PARTITION": TokenType.PARTITION,
"PARTITION BY": TokenType.PARTITION_BY,
@@ -652,6 +657,7 @@ class Tokenizer(metaclass=_Tokenizer):
"DOUBLE PRECISION": TokenType.DOUBLE,
"JSON": TokenType.JSON,
"CHAR": TokenType.CHAR,
+ "CHARACTER": TokenType.CHAR,
"NCHAR": TokenType.NCHAR,
"VARCHAR": TokenType.VARCHAR,
"VARCHAR2": TokenType.VARCHAR,
@@ -687,8 +693,10 @@ class Tokenizer(metaclass=_Tokenizer):
"ALTER VIEW": TokenType.COMMAND,
"ANALYZE": TokenType.COMMAND,
"CALL": TokenType.COMMAND,
+ "COMMENT": TokenType.COMMENT,
"COPY": TokenType.COMMAND,
"EXPLAIN": TokenType.COMMAND,
+ "GRANT": TokenType.COMMAND,
"OPTIMIZE": TokenType.COMMAND,
"PREPARE": TokenType.COMMAND,
"TRUNCATE": TokenType.COMMAND,