author     Daniel Baumann <daniel.baumann@progress-linux.org>  2022-11-11 08:54:35 +0000
committer  Daniel Baumann <daniel.baumann@progress-linux.org>  2022-11-11 08:54:35 +0000
commit     d1f00706bff58b863b0a1c5bf4adf39d36049d4c (patch)
tree       3a8ecc5d1509d655d5df6b1455bc1e309da2c02c /sqlglot
parent     Releasing debian version 9.0.6-1. (diff)
download   sqlglot-d1f00706bff58b863b0a1c5bf4adf39d36049d4c.tar.xz
           sqlglot-d1f00706bff58b863b0a1c5bf4adf39d36049d4c.zip
Merging upstream version 10.0.1.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
-rw-r--r--  sqlglot/__init__.py | 71
-rw-r--r--  sqlglot/__main__.py | 5
-rw-r--r--  sqlglot/dataframe/sql/_typing.pyi | 14
-rw-r--r--  sqlglot/dataframe/sql/column.py | 46
-rw-r--r--  sqlglot/dataframe/sql/dataframe.py | 158
-rw-r--r--  sqlglot/dataframe/sql/functions.py | 100
-rw-r--r--  sqlglot/dataframe/sql/group.py | 10
-rw-r--r--  sqlglot/dataframe/sql/normalize.py | 13
-rw-r--r--  sqlglot/dataframe/sql/readwriter.py | 16
-rw-r--r--  sqlglot/dataframe/sql/session.py | 17
-rw-r--r--  sqlglot/dataframe/sql/types.py | 6
-rw-r--r--  sqlglot/dataframe/sql/window.py | 27
-rw-r--r--  sqlglot/dialects/bigquery.py | 57
-rw-r--r--  sqlglot/dialects/clickhouse.py | 24
-rw-r--r--  sqlglot/dialects/databricks.py | 4
-rw-r--r--  sqlglot/dialects/dialect.py | 52
-rw-r--r--  sqlglot/dialects/duckdb.py | 33
-rw-r--r--  sqlglot/dialects/hive.py | 57
-rw-r--r--  sqlglot/dialects/mysql.py | 329
-rw-r--r--  sqlglot/dialects/oracle.py | 20
-rw-r--r--  sqlglot/dialects/postgres.py | 25
-rw-r--r--  sqlglot/dialects/presto.py | 41
-rw-r--r--  sqlglot/dialects/redshift.py | 13
-rw-r--r--  sqlglot/dialects/snowflake.py | 46
-rw-r--r--  sqlglot/dialects/spark.py | 37
-rw-r--r--  sqlglot/dialects/sqlite.py | 24
-rw-r--r--  sqlglot/dialects/starrocks.py | 7
-rw-r--r--  sqlglot/dialects/tableau.py | 14
-rw-r--r--  sqlglot/dialects/trino.py | 4
-rw-r--r--  sqlglot/dialects/tsql.py | 54
-rw-r--r--  sqlglot/diff.py | 23
-rw-r--r--  sqlglot/errors.py | 9
-rw-r--r--  sqlglot/executor/context.py | 44
-rw-r--r--  sqlglot/executor/env.py | 4
-rw-r--r--  sqlglot/executor/python.py | 190
-rw-r--r--  sqlglot/executor/table.py | 27
-rw-r--r--  sqlglot/expressions.py | 258
-rw-r--r--  sqlglot/generator.py | 214
-rw-r--r--  sqlglot/helper.py | 209
-rw-r--r--  sqlglot/optimizer/annotate_types.py | 131
-rw-r--r--  sqlglot/optimizer/eliminate_joins.py | 4
-rw-r--r--  sqlglot/optimizer/eliminate_subqueries.py | 12
-rw-r--r--  sqlglot/optimizer/merge_subqueries.py | 16
-rw-r--r--  sqlglot/optimizer/normalize.py | 4
-rw-r--r--  sqlglot/optimizer/optimize_joins.py | 6
-rw-r--r--  sqlglot/optimizer/optimizer.py | 4
-rw-r--r--  sqlglot/optimizer/pushdown_predicates.py | 28
-rw-r--r--  sqlglot/optimizer/pushdown_projections.py | 4
-rw-r--r--  sqlglot/optimizer/qualify_columns.py | 28
-rw-r--r--  sqlglot/optimizer/scope.py | 14
-rw-r--r--  sqlglot/optimizer/simplify.py | 12
-rw-r--r--  sqlglot/optimizer/unnest_subqueries.py | 14
-rw-r--r--  sqlglot/parser.py | 410
-rw-r--r--  sqlglot/planner.py | 19
-rw-r--r--  sqlglot/py.typed | 0
-rw-r--r--  sqlglot/schema.py | 298
-rw-r--r--  sqlglot/time.py | 17
-rw-r--r--  sqlglot/tokens.py | 247
-rw-r--r--  sqlglot/transforms.py | 42
-rw-r--r--  sqlglot/trie.py | 48
60 files changed, 2549 insertions, 1111 deletions
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py
index d6e18fd..6e67b19 100644
--- a/sqlglot/__init__.py
+++ b/sqlglot/__init__.py
@@ -1,5 +1,9 @@
"""## Python SQL parser, transpiler and optimizer."""
+from __future__ import annotations
+
+import typing as t
+
from sqlglot import expressions as exp
from sqlglot.dialects import Dialect, Dialects
from sqlglot.diff import diff
@@ -20,51 +24,54 @@ from sqlglot.expressions import (
subquery,
)
from sqlglot.expressions import table_ as table
-from sqlglot.expressions import union
+from sqlglot.expressions import to_column, to_table, union
from sqlglot.generator import Generator
from sqlglot.parser import Parser
from sqlglot.schema import MappingSchema
from sqlglot.tokens import Tokenizer, TokenType
-__version__ = "9.0.6"
+__version__ = "10.0.1"
pretty = False
schema = MappingSchema()
-def parse(sql, read=None, **opts):
+def parse(
+ sql: str, read: t.Optional[str | Dialect] = None, **opts
+) -> t.List[t.Optional[Expression]]:
"""
- Parses the given SQL string into a collection of syntax trees, one per
- parsed SQL statement.
+ Parses the given SQL string into a collection of syntax trees, one per parsed SQL statement.
Args:
- sql (str): the SQL code string to parse.
- read (str): the SQL dialect to apply during parsing
- (eg. "spark", "hive", "presto", "mysql").
+ sql: the SQL code string to parse.
+ read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql").
**opts: other options.
Returns:
- typing.List[Expression]: the list of parsed syntax trees.
+ The resulting syntax tree collection.
"""
dialect = Dialect.get_or_raise(read)()
return dialect.parse(sql, **opts)
-def parse_one(sql, read=None, into=None, **opts):
+def parse_one(
+ sql: str,
+ read: t.Optional[str | Dialect] = None,
+ into: t.Optional[Expression | str] = None,
+ **opts,
+) -> t.Optional[Expression]:
"""
- Parses the given SQL string and returns a syntax tree for the first
- parsed SQL statement.
+ Parses the given SQL string and returns a syntax tree for the first parsed SQL statement.
Args:
- sql (str): the SQL code string to parse.
- read (str): the SQL dialect to apply during parsing
- (eg. "spark", "hive", "presto", "mysql").
- into (Expression): the SQLGlot Expression to parse into
+ sql: the SQL code string to parse.
+ read: the SQL dialect to apply during parsing (eg. "spark", "hive", "presto", "mysql").
+ into: the SQLGlot Expression to parse into.
**opts: other options.
Returns:
- Expression: the syntax tree for the first parsed statement.
+ The syntax tree for the first parsed statement.
"""
dialect = Dialect.get_or_raise(read)()
@@ -77,25 +84,29 @@ def parse_one(sql, read=None, into=None, **opts):
return result[0] if result else None
-def transpile(sql, read=None, write=None, identity=True, error_level=None, **opts):
+def transpile(
+ sql: str,
+ read: t.Optional[str | Dialect] = None,
+ write: t.Optional[str | Dialect] = None,
+ identity: bool = True,
+ error_level: t.Optional[ErrorLevel] = None,
+ **opts,
+) -> t.List[str]:
"""
- Parses the given SQL string using the source dialect and returns a list of SQL strings
- transformed to conform to the target dialect. Each string in the returned list represents
- a single transformed SQL statement.
+ Parses the given SQL string in accordance with the source dialect and returns a list of SQL strings transformed
+ to conform to the target dialect. Each string in the returned list represents a single transformed SQL statement.
Args:
- sql (str): the SQL code string to transpile.
- read (str): the source dialect used to parse the input string
- (eg. "spark", "hive", "presto", "mysql").
- write (str): the target dialect into which the input should be transformed
- (eg. "spark", "hive", "presto", "mysql").
- identity (bool): if set to True and if the target dialect is not specified
- the source dialect will be used as both: the source and the target dialect.
- error_level (ErrorLevel): the desired error level of the parser.
+ sql: the SQL code string to transpile.
+ read: the source dialect used to parse the input string (eg. "spark", "hive", "presto", "mysql").
+ write: the target dialect into which the input should be transformed (eg. "spark", "hive", "presto", "mysql").
+ identity: if set to `True` and if the target dialect is not specified the source dialect will be used as both:
+ the source and the target dialect.
+ error_level: the desired error level of the parser.
**opts: other options.
Returns:
- typing.List[str]: the list of transpiled SQL statements / expressions.
+ The list of transpiled SQL statements.
"""
write = write or read if identity else write
return [
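
Note: the three functions above are the public entry points most callers touch. A minimal usage sketch against the new typed signatures (inputs illustrative, outputs not shown since they depend on the SQL):

    import sqlglot

    # parse: one syntax tree per statement
    trees = sqlglot.parse("SELECT 1; SELECT 2", read="duckdb")

    # parse_one: first tree only (may be None for empty input)
    tree = sqlglot.parse_one("SELECT a FROM t", read="spark")

    # transpile: list of SQL strings rendered in the target dialect
    print(sqlglot.transpile("SELECT a FROM t", read="spark", write="presto")[0])
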
diff --git a/sqlglot/__main__.py b/sqlglot/__main__.py
index c0fa380..42a54bc 100644
--- a/sqlglot/__main__.py
+++ b/sqlglot/__main__.py
@@ -49,7 +49,10 @@ args = parser.parse_args()
error_level = sqlglot.ErrorLevel[args.error_level.upper()]
if args.parse:
- sqls = [repr(expression) for expression in sqlglot.parse(args.sql, read=args.read, error_level=error_level)]
+ sqls = [
+ repr(expression)
+ for expression in sqlglot.parse(args.sql, read=args.read, error_level=error_level)
+ ]
else:
sqls = sqlglot.transpile(
args.sql,
diff --git a/sqlglot/dataframe/sql/_typing.pyi b/sqlglot/dataframe/sql/_typing.pyi
index f1a03ea..67c8c09 100644
--- a/sqlglot/dataframe/sql/_typing.pyi
+++ b/sqlglot/dataframe/sql/_typing.pyi
@@ -10,11 +10,17 @@ if t.TYPE_CHECKING:
from sqlglot.dataframe.sql.types import StructType
ColumnLiterals = t.TypeVar(
- "ColumnLiterals", bound=t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
+ "ColumnLiterals",
+ bound=t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime],
)
ColumnOrName = t.TypeVar("ColumnOrName", bound=t.Union[Column, str])
ColumnOrLiteral = t.TypeVar(
- "ColumnOrLiteral", bound=t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime]
+ "ColumnOrLiteral",
+ bound=t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime],
+)
+SchemaInput = t.TypeVar(
+ "SchemaInput", bound=t.Union[str, t.List[str], StructType, t.Dict[str, str]]
+)
+OutputExpressionContainer = t.TypeVar(
+ "OutputExpressionContainer", bound=t.Union[exp.Select, exp.Create, exp.Insert]
)
-SchemaInput = t.TypeVar("SchemaInput", bound=t.Union[str, t.List[str], StructType, t.Dict[str, str]])
-OutputExpressionContainer = t.TypeVar("OutputExpressionContainer", bound=t.Union[exp.Select, exp.Create, exp.Insert])
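
Note: these bounded TypeVars let the dataframe signatures accept any of the listed input forms. A small illustration of how a bound like SchemaInput is meant to be consumed (schema_column_names is a hypothetical helper, not part of the stub):

    import typing as t

    SchemaInput = t.TypeVar(
        "SchemaInput", bound=t.Union[str, t.List[str], t.Dict[str, str]]
    )

    def schema_column_names(schema: t.Union[str, t.List[str], t.Dict[str, str]]) -> t.List[str]:
        # hypothetical: every SchemaInput form reduces to a list of column names
        if isinstance(schema, str):
            return [schema]
        return list(schema)  # dict yields its keys; list is copied
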
diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py
index e66aaa8..f9e1c5b 100644
--- a/sqlglot/dataframe/sql/column.py
+++ b/sqlglot/dataframe/sql/column.py
@@ -18,7 +18,11 @@ class Column:
expression = expression.expression # type: ignore
elif expression is None or not isinstance(expression, (str, exp.Expression)):
expression = self._lit(expression).expression # type: ignore
- self.expression: exp.Expression = sqlglot.maybe_parse(expression, dialect="spark")
+
+ expression = sqlglot.maybe_parse(expression, dialect="spark")
+ if expression is None:
+ raise ValueError(f"Could not parse {expression}")
+ self.expression: exp.Expression = expression
def __repr__(self):
return repr(self.expression)
@@ -135,21 +139,29 @@ class Column:
) -> Column:
ensured_column = None if column is None else cls.ensure_col(column)
ensure_expression_values = {
- k: [Column.ensure_col(x).expression for x in v] if is_iterable(v) else Column.ensure_col(v).expression
+ k: [Column.ensure_col(x).expression for x in v]
+ if is_iterable(v)
+ else Column.ensure_col(v).expression
for k, v in kwargs.items()
}
new_expression = (
callable_expression(**ensure_expression_values)
if ensured_column is None
- else callable_expression(this=ensured_column.column_expression, **ensure_expression_values)
+ else callable_expression(
+ this=ensured_column.column_expression, **ensure_expression_values
+ )
)
return Column(new_expression)
def binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
- return Column(klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs))
+ return Column(
+ klass(this=self.column_expression, expression=Column(other).column_expression, **kwargs)
+ )
def inverse_binary_op(self, klass: t.Callable, other: ColumnOrLiteral, **kwargs) -> Column:
- return Column(klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs))
+ return Column(
+ klass(this=Column(other).column_expression, expression=self.column_expression, **kwargs)
+ )
def unary_op(self, klass: t.Callable, **kwargs) -> Column:
return Column(klass(this=self.column_expression, **kwargs))
@@ -188,7 +200,7 @@ class Column:
expression.set("table", exp.to_identifier(table_name))
return Column(expression)
- def sql(self, **kwargs) -> Column:
+ def sql(self, **kwargs) -> str:
return self.expression.sql(**{"dialect": "spark", **kwargs})
def alias(self, name: str) -> Column:
@@ -265,10 +277,14 @@ class Column:
)
def like(self, other: str):
- return self.invoke_expression_over_column(self, exp.Like, expression=self._lit(other).expression)
+ return self.invoke_expression_over_column(
+ self, exp.Like, expression=self._lit(other).expression
+ )
def ilike(self, other: str):
- return self.invoke_expression_over_column(self, exp.ILike, expression=self._lit(other).expression)
+ return self.invoke_expression_over_column(
+ self, exp.ILike, expression=self._lit(other).expression
+ )
def substr(self, startPos: t.Union[int, Column], length: t.Union[int, Column]) -> Column:
startPos = self._lit(startPos) if not isinstance(startPos, Column) else startPos
@@ -287,10 +303,18 @@ class Column:
lowerBound: t.Union[ColumnOrLiteral],
upperBound: t.Union[ColumnOrLiteral],
) -> Column:
- lower_bound_exp = self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
- upper_bound_exp = self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
+ lower_bound_exp = (
+ self._lit(lowerBound) if not isinstance(lowerBound, Column) else lowerBound
+ )
+ upper_bound_exp = (
+ self._lit(upperBound) if not isinstance(upperBound, Column) else upperBound
+ )
return Column(
- exp.Between(this=self.column_expression, low=lower_bound_exp.expression, high=upper_bound_exp.expression)
+ exp.Between(
+ this=self.column_expression,
+ low=lower_bound_exp.expression,
+ high=upper_bound_exp.expression,
+ )
)
def over(self, window: WindowSpec) -> Column:
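
Note: the Column wrapper ultimately just assembles sqlglot expression nodes, so the reformatted between() above can be exercised directly. A quick sketch (the printed rendering is the expected spark output, not verified here):

    from sqlglot.dataframe.sql.column import Column

    col = Column("price").between(10, 20)
    print(col.sql())  # expected: price BETWEEN 10 AND 20
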
diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py
index 322dcf2..40cd6c9 100644
--- a/sqlglot/dataframe/sql/dataframe.py
+++ b/sqlglot/dataframe/sql/dataframe.py
@@ -21,7 +21,12 @@ from sqlglot.optimizer import optimize as optimize_func
from sqlglot.optimizer.qualify_columns import qualify_columns
if t.TYPE_CHECKING:
- from sqlglot.dataframe.sql._typing import ColumnLiterals, OutputExpressionContainer
+ from sqlglot.dataframe.sql._typing import (
+ ColumnLiterals,
+ ColumnOrLiteral,
+ ColumnOrName,
+ OutputExpressionContainer,
+ )
from sqlglot.dataframe.sql.session import SparkSession
@@ -83,7 +88,9 @@ class DataFrame:
return from_exp.alias_or_name
table_alias = from_exp.find(exp.TableAlias)
if not table_alias:
- raise RuntimeError(f"Could not find an alias name for this expression: {self.expression}")
+ raise RuntimeError(
+ f"Could not find an alias name for this expression: {self.expression}"
+ )
return table_alias.alias_or_name
return self.expression.ctes[-1].alias
@@ -132,12 +139,16 @@ class DataFrame:
cte.set("sequence_id", sequence_id or self.sequence_id)
return cte, name
- def _ensure_list_of_columns(
- self, cols: t.Union[str, t.Iterable[str], Column, t.Iterable[Column]]
- ) -> t.List[Column]:
- columns = ensure_list(cols)
- columns = Column.ensure_cols(columns)
- return columns
+ @t.overload
+ def _ensure_list_of_columns(self, cols: t.Collection[ColumnOrLiteral]) -> t.List[Column]:
+ ...
+
+ @t.overload
+ def _ensure_list_of_columns(self, cols: ColumnOrLiteral) -> t.List[Column]:
+ ...
+
+ def _ensure_list_of_columns(self, cols):
+ return Column.ensure_cols(ensure_list(cols))
def _ensure_and_normalize_cols(self, cols):
cols = self._ensure_list_of_columns(cols)
@@ -153,10 +164,16 @@ class DataFrame:
df = self._resolve_pending_hints()
sequence_id = sequence_id or df.sequence_id
expression = df.expression.copy()
- cte_expression, cte_name = df._create_cte_from_expression(expression=expression, sequence_id=sequence_id)
- new_expression = df._add_ctes_to_expression(exp.Select(), expression.ctes + [cte_expression])
+ cte_expression, cte_name = df._create_cte_from_expression(
+ expression=expression, sequence_id=sequence_id
+ )
+ new_expression = df._add_ctes_to_expression(
+ exp.Select(), expression.ctes + [cte_expression]
+ )
sel_columns = df._get_outer_select_columns(cte_expression)
- new_expression = new_expression.from_(cte_name).select(*[x.alias_or_name for x in sel_columns])
+ new_expression = new_expression.from_(cte_name).select(
+ *[x.alias_or_name for x in sel_columns]
+ )
return df.copy(expression=new_expression, sequence_id=sequence_id)
def _resolve_pending_hints(self) -> DataFrame:
@@ -169,16 +186,23 @@ class DataFrame:
hint_expression.args.get("expressions").append(hint)
df.pending_hints.remove(hint)
- join_aliases = {join_table.alias_or_name for join_table in get_tables_from_expression_with_join(expression)}
+ join_aliases = {
+ join_table.alias_or_name
+ for join_table in get_tables_from_expression_with_join(expression)
+ }
if join_aliases:
for hint in df.pending_join_hints:
for sequence_id_expression in hint.expressions:
sequence_id_or_name = sequence_id_expression.alias_or_name
sequence_ids_to_match = [sequence_id_or_name]
if sequence_id_or_name in df.spark.name_to_sequence_id_mapping:
- sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[sequence_id_or_name]
+ sequence_ids_to_match = df.spark.name_to_sequence_id_mapping[
+ sequence_id_or_name
+ ]
matching_ctes = [
- cte for cte in reversed(expression.ctes) if cte.args["sequence_id"] in sequence_ids_to_match
+ cte
+ for cte in reversed(expression.ctes)
+ if cte.args["sequence_id"] in sequence_ids_to_match
]
for matching_cte in matching_ctes:
if matching_cte.alias_or_name in join_aliases:
@@ -193,9 +217,14 @@ class DataFrame:
def _hint(self, hint_name: str, args: t.List[Column]) -> DataFrame:
hint_name = hint_name.upper()
hint_expression = (
- exp.JoinHint(this=hint_name, expressions=[exp.to_table(parameter.alias_or_name) for parameter in args])
+ exp.JoinHint(
+ this=hint_name,
+ expressions=[exp.to_table(parameter.alias_or_name) for parameter in args],
+ )
if hint_name in JOIN_HINTS
- else exp.Anonymous(this=hint_name, expressions=[parameter.expression for parameter in args])
+ else exp.Anonymous(
+ this=hint_name, expressions=[parameter.expression for parameter in args]
+ )
)
new_df = self.copy()
new_df.pending_hints.append(hint_expression)
@@ -245,7 +274,9 @@ class DataFrame:
def _get_select_expressions(
self,
) -> t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]]:
- select_expressions: t.List[t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]] = []
+ select_expressions: t.List[
+ t.Tuple[t.Union[t.Type[exp.Cache], OutputExpressionContainer], exp.Select]
+ ] = []
main_select_ctes: t.List[exp.CTE] = []
for cte in self.expression.ctes:
cache_storage_level = cte.args.get("cache_storage_level")
@@ -279,14 +310,19 @@ class DataFrame:
cache_table_name = df._create_hash_from_expression(select_expression)
cache_table = exp.to_table(cache_table_name)
original_alias_name = select_expression.args["cte_alias_name"]
- replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier(cache_table_name)
+
+ replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier( # type: ignore
+ cache_table_name
+ )
sqlglot.schema.add_table(cache_table_name, select_expression.named_selects)
cache_storage_level = select_expression.args["cache_storage_level"]
options = [
exp.Literal.string("storageLevel"),
exp.Literal.string(cache_storage_level),
]
- expression = exp.Cache(this=cache_table, expression=select_expression, lazy=True, options=options)
+ expression = exp.Cache(
+ this=cache_table, expression=select_expression, lazy=True, options=options
+ )
# We will drop the "view" if it exists before running the cache table
output_expressions.append(exp.Drop(this=cache_table, exists=True, kind="VIEW"))
elif expression_type == exp.Create:
@@ -305,7 +341,9 @@ class DataFrame:
raise ValueError(f"Invalid expression type: {expression_type}")
output_expressions.append(expression)
- return [expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions]
+ return [
+ expression.sql(**{"dialect": dialect, **kwargs}) for expression in output_expressions
+ ]
def copy(self, **kwargs) -> DataFrame:
return DataFrame(**object_to_dict(self, **kwargs))
@@ -317,7 +355,9 @@ class DataFrame:
if self.expression.args.get("joins"):
ambiguous_cols = [col for col in cols if not col.column_expression.table]
if ambiguous_cols:
- join_table_identifiers = [x.this for x in get_tables_from_expression_with_join(self.expression)]
+ join_table_identifiers = [
+ x.this for x in get_tables_from_expression_with_join(self.expression)
+ ]
cte_names_in_join = [x.this for x in join_table_identifiers]
for ambiguous_col in ambiguous_cols:
ctes_with_column = [
@@ -367,14 +407,20 @@ class DataFrame:
@operation(Operation.FROM)
def join(
- self, other_df: DataFrame, on: t.Union[str, t.List[str], Column, t.List[Column]], how: str = "inner", **kwargs
+ self,
+ other_df: DataFrame,
+ on: t.Union[str, t.List[str], Column, t.List[Column]],
+ how: str = "inner",
+ **kwargs,
) -> DataFrame:
other_df = other_df._convert_leaf_to_cte()
pre_join_self_latest_cte_name = self.latest_cte_name
columns = self._ensure_and_normalize_cols(on)
join_type = how.replace("_", " ")
if isinstance(columns[0].expression, exp.Column):
- join_columns = [Column(x).set_table_name(pre_join_self_latest_cte_name) for x in columns]
+ join_columns = [
+ Column(x).set_table_name(pre_join_self_latest_cte_name) for x in columns
+ ]
join_clause = functools.reduce(
lambda x, y: x & y,
[
@@ -402,7 +448,9 @@ class DataFrame:
for column in self._get_outer_select_columns(other_df)
]
column_value_mapping = {
- column.alias_or_name if not isinstance(column.expression.this, exp.Star) else column.sql(): column
+ column.alias_or_name
+ if not isinstance(column.expression.this, exp.Star)
+ else column.sql(): column
for column in other_columns + self_columns + join_columns
}
all_columns = [
@@ -410,16 +458,22 @@ class DataFrame:
for name in {x.alias_or_name: None for x in join_columns + self_columns + other_columns}
]
new_df = self.copy(
- expression=self.expression.join(other_df.latest_cte_name, on=join_clause.expression, join_type=join_type)
+ expression=self.expression.join(
+ other_df.latest_cte_name, on=join_clause.expression, join_type=join_type
+ )
+ )
+ new_df.expression = new_df._add_ctes_to_expression(
+ new_df.expression, other_df.expression.ctes
)
- new_df.expression = new_df._add_ctes_to_expression(new_df.expression, other_df.expression.ctes)
new_df.pending_hints.extend(other_df.pending_hints)
new_df = new_df.select.__wrapped__(new_df, *all_columns)
return new_df
@operation(Operation.ORDER_BY)
def orderBy(
- self, *cols: t.Union[str, Column], ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None
+ self,
+ *cols: t.Union[str, Column],
+ ascending: t.Optional[t.Union[t.Any, t.List[t.Any]]] = None,
) -> DataFrame:
"""
This implementation lets any ordered columns take priority over whatever is provided in `ascending`. Spark
@@ -429,7 +483,10 @@ class DataFrame:
columns = self._ensure_and_normalize_cols(cols)
pre_ordered_col_indexes = [
x
- for x in [i if isinstance(col.expression, exp.Ordered) else None for i, col in enumerate(columns)]
+ for x in [
+ i if isinstance(col.expression, exp.Ordered) else None
+ for i, col in enumerate(columns)
+ ]
if x is not None
]
if ascending is None:
@@ -478,7 +535,9 @@ class DataFrame:
for r_column in r_columns_unused:
l_expressions.append(exp.alias_(exp.Null(), r_column))
r_expressions.append(r_column)
- r_df = other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
+ r_df = (
+ other.copy()._convert_leaf_to_cte().select(*self._ensure_list_of_columns(r_expressions))
+ )
l_df = self.copy()
if allowMissingColumns:
l_df = l_df._convert_leaf_to_cte().select(*self._ensure_list_of_columns(l_expressions))
@@ -536,7 +595,9 @@ class DataFrame:
f"The minimum num nulls for dropna must be less than or equal to the number of columns. "
f"Minimum num nulls: {minimum_num_nulls}, Num Columns: {len(null_check_columns)}"
)
- if_null_checks = [F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns]
+ if_null_checks = [
+ F.when(column.isNull(), F.lit(1)).otherwise(F.lit(0)) for column in null_check_columns
+ ]
nulls_added_together = functools.reduce(lambda x, y: x + y, if_null_checks)
num_nulls = nulls_added_together.alias("num_nulls")
new_df = new_df.select(num_nulls, append=True)
@@ -576,11 +637,15 @@ class DataFrame:
value_columns = [lit(value) for value in values]
null_replacement_mapping = {
- column.alias_or_name: (F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name))
+ column.alias_or_name: (
+ F.when(column.isNull(), value).otherwise(column).alias(column.alias_or_name)
+ )
for column, value in zip(columns, value_columns)
}
null_replacement_mapping = {**all_column_mapping, **null_replacement_mapping}
- null_replacement_columns = [null_replacement_mapping[column.alias_or_name] for column in all_columns]
+ null_replacement_columns = [
+ null_replacement_mapping[column.alias_or_name] for column in all_columns
+ ]
new_df = new_df.select(*null_replacement_columns)
return new_df
@@ -589,12 +654,11 @@ class DataFrame:
self,
to_replace: t.Union[bool, int, float, str, t.List, t.Dict],
value: t.Optional[t.Union[bool, int, float, str, t.List]] = None,
- subset: t.Optional[t.Union[str, t.List[str]]] = None,
+ subset: t.Optional[t.Collection[ColumnOrName] | ColumnOrName] = None,
) -> DataFrame:
from sqlglot.dataframe.sql.functions import lit
old_values = None
- subset = ensure_list(subset)
new_df = self.copy()
all_columns = self._get_outer_select_columns(new_df.expression)
all_column_mapping = {column.alias_or_name: column for column in all_columns}
@@ -605,7 +669,9 @@ class DataFrame:
new_values = list(to_replace.values())
elif not old_values and isinstance(to_replace, list):
assert isinstance(value, list), "value must be a list since the replacements are a list"
- assert len(to_replace) == len(value), "the replacements and values must be the same length"
+ assert len(to_replace) == len(
+ value
+ ), "the replacements and values must be the same length"
old_values = to_replace
new_values = value
else:
@@ -635,7 +701,9 @@ class DataFrame:
def withColumn(self, colName: str, col: Column) -> DataFrame:
col = self._ensure_and_normalize_col(col)
existing_col_names = self.expression.named_selects
- existing_col_index = existing_col_names.index(colName) if colName in existing_col_names else None
+ existing_col_index = (
+ existing_col_names.index(colName) if colName in existing_col_names else None
+ )
if existing_col_index:
expression = self.expression.copy()
expression.expressions[existing_col_index] = col.expression
@@ -645,7 +713,11 @@ class DataFrame:
@operation(Operation.SELECT)
def withColumnRenamed(self, existing: str, new: str):
expression = self.expression.copy()
- existing_columns = [expression for expression in expression.expressions if expression.alias_or_name == existing]
+ existing_columns = [
+ expression
+ for expression in expression.expressions
+ if expression.alias_or_name == existing
+ ]
if not existing_columns:
raise ValueError("Tried to rename a column that doesn't exist")
for existing_column in existing_columns:
@@ -674,15 +746,19 @@ class DataFrame:
def hint(self, name: str, *parameters: t.Optional[t.Union[str, int]]) -> DataFrame:
parameter_list = ensure_list(parameters)
parameter_columns = (
- self._ensure_list_of_columns(parameter_list) if parameters else Column.ensure_cols([self.sequence_id])
+ self._ensure_list_of_columns(parameter_list)
+ if parameters
+ else Column.ensure_cols([self.sequence_id])
)
return self._hint(name, parameter_columns)
@operation(Operation.NO_OP)
- def repartition(self, numPartitions: t.Union[int, str], *cols: t.Union[int, str]) -> DataFrame:
- num_partitions = Column.ensure_cols(ensure_list(numPartitions))
+ def repartition(
+ self, numPartitions: t.Union[int, ColumnOrName], *cols: ColumnOrName
+ ) -> DataFrame:
+ num_partition_cols = self._ensure_list_of_columns(numPartitions)
columns = self._ensure_and_normalize_cols(cols)
- args = num_partitions + columns
+ args = num_partition_cols + columns
return self._hint("repartition", args)
@operation(Operation.NO_OP)
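
Note: the `_ensure_list_of_columns` rewrite above swaps a single typed signature for `t.overload` stubs plus one untyped implementation. A self-contained sketch of that pattern:

    import typing as t

    @t.overload
    def ensure_list(value: t.List[int]) -> t.List[int]:
        ...

    @t.overload
    def ensure_list(value: int) -> t.List[int]:
        ...

    def ensure_list(value):
        # one implementation backs both declared signatures
        return value if isinstance(value, list) else [value]
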
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py
index bc002e5..dbfb06f 100644
--- a/sqlglot/dataframe/sql/functions.py
+++ b/sqlglot/dataframe/sql/functions.py
@@ -45,7 +45,11 @@ def countDistinct(col: ColumnOrName, *cols: ColumnOrName) -> Column:
def when(condition: Column, value: t.Any) -> Column:
true_value = value if isinstance(value, Column) else lit(value)
- return Column(glotexp.Case(ifs=[glotexp.If(this=condition.column_expression, true=true_value.column_expression)]))
+ return Column(
+ glotexp.Case(
+ ifs=[glotexp.If(this=condition.column_expression, true=true_value.column_expression)]
+ )
+ )
def asc(col: ColumnOrName) -> Column:
@@ -407,7 +411,9 @@ def percentile_approx(
return Column.invoke_expression_over_column(
col, glotexp.ApproxQuantile, quantile=lit(percentage), accuracy=accuracy
)
- return Column.invoke_expression_over_column(col, glotexp.ApproxQuantile, quantile=lit(percentage))
+ return Column.invoke_expression_over_column(
+ col, glotexp.ApproxQuantile, quantile=lit(percentage)
+ )
def rand(seed: t.Optional[ColumnOrLiteral] = None) -> Column:
@@ -471,7 +477,9 @@ def factorial(col: ColumnOrName) -> Column:
return Column.invoke_anonymous_function(col, "FACTORIAL")
-def lag(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[ColumnOrLiteral] = None) -> Column:
+def lag(
+ col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[ColumnOrLiteral] = None
+) -> Column:
if default is not None:
return Column.invoke_anonymous_function(col, "LAG", offset, default)
if offset != 1:
@@ -479,7 +487,9 @@ def lag(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[Colu
return Column.invoke_anonymous_function(col, "LAG")
-def lead(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.Any] = None) -> Column:
+def lead(
+ col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.Any] = None
+) -> Column:
if default is not None:
return Column.invoke_anonymous_function(col, "LEAD", offset, default)
if offset != 1:
@@ -487,7 +497,9 @@ def lead(col: ColumnOrName, offset: t.Optional[int] = 1, default: t.Optional[t.A
return Column.invoke_anonymous_function(col, "LEAD")
-def nth_value(col: ColumnOrName, offset: t.Optional[int] = 1, ignoreNulls: t.Optional[bool] = None) -> Column:
+def nth_value(
+ col: ColumnOrName, offset: t.Optional[int] = 1, ignoreNulls: t.Optional[bool] = None
+) -> Column:
if ignoreNulls is not None:
raise NotImplementedError("There is currently not support for `ignoreNulls` parameter")
if offset != 1:
@@ -571,7 +583,9 @@ def add_months(start: ColumnOrName, months: t.Union[ColumnOrName, int]) -> Colum
return Column.invoke_anonymous_function(start, "ADD_MONTHS", months)
-def months_between(date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optional[bool] = None) -> Column:
+def months_between(
+ date1: ColumnOrName, date2: ColumnOrName, roundOff: t.Optional[bool] = None
+) -> Column:
if roundOff is None:
return Column.invoke_anonymous_function(date1, "MONTHS_BETWEEN", date2)
return Column.invoke_anonymous_function(date1, "MONTHS_BETWEEN", date2, roundOff)
@@ -611,9 +625,13 @@ def from_unixtime(col: ColumnOrName, format: t.Optional[str] = None) -> Column:
return Column.invoke_expression_over_column(col, glotexp.UnixToStr)
-def unix_timestamp(timestamp: t.Optional[ColumnOrName] = None, format: t.Optional[str] = None) -> Column:
+def unix_timestamp(
+ timestamp: t.Optional[ColumnOrName] = None, format: t.Optional[str] = None
+) -> Column:
if format is not None:
- return Column.invoke_expression_over_column(timestamp, glotexp.StrToUnix, format=lit(format))
+ return Column.invoke_expression_over_column(
+ timestamp, glotexp.StrToUnix, format=lit(format)
+ )
return Column.invoke_expression_over_column(timestamp, glotexp.StrToUnix)
@@ -642,7 +660,9 @@ def window(
timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration), lit(startTime)
)
if slideDuration is not None:
- return Column.invoke_anonymous_function(timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration))
+ return Column.invoke_anonymous_function(
+ timeColumn, "WINDOW", lit(windowDuration), lit(slideDuration)
+ )
if startTime is not None:
return Column.invoke_anonymous_function(
timeColumn, "WINDOW", lit(windowDuration), lit(windowDuration), lit(startTime)
@@ -731,7 +751,9 @@ def trim(col: ColumnOrName) -> Column:
def concat_ws(sep: str, *cols: ColumnOrName) -> Column:
- return Column.invoke_expression_over_column(None, glotexp.ConcatWs, expressions=[lit(sep)] + list(cols))
+ return Column.invoke_expression_over_column(
+ None, glotexp.ConcatWs, expressions=[lit(sep)] + list(cols)
+ )
def decode(col: ColumnOrName, charset: str) -> Column:
@@ -768,7 +790,9 @@ def overlay(
def sentences(
- string: ColumnOrName, language: t.Optional[ColumnOrName] = None, country: t.Optional[ColumnOrName] = None
+ string: ColumnOrName,
+ language: t.Optional[ColumnOrName] = None,
+ country: t.Optional[ColumnOrName] = None,
) -> Column:
if language is not None and country is not None:
return Column.invoke_anonymous_function(string, "SENTENCES", language, country)
@@ -794,7 +818,9 @@ def levenshtein(left: ColumnOrName, right: ColumnOrName) -> Column:
def locate(substr: str, str: ColumnOrName, pos: t.Optional[int] = None) -> Column:
substr_col = lit(substr)
if pos is not None:
- return Column.invoke_expression_over_column(str, glotexp.StrPosition, substr=substr_col, position=pos)
+ return Column.invoke_expression_over_column(
+ str, glotexp.StrPosition, substr=substr_col, position=pos
+ )
return Column.invoke_expression_over_column(str, glotexp.StrPosition, substr=substr_col)
@@ -872,7 +898,10 @@ def array(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
def create_map(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column:
cols = list(_flatten(cols)) if not isinstance(cols[0], (str, Column)) else cols # type: ignore
return Column.invoke_expression_over_column(
- None, glotexp.VarMap, keys=array(*cols[::2]).expression, values=array(*cols[1::2]).expression
+ None,
+ glotexp.VarMap,
+ keys=array(*cols[::2]).expression,
+ values=array(*cols[1::2]).expression,
)
@@ -882,29 +911,39 @@ def map_from_arrays(col1: ColumnOrName, col2: ColumnOrName) -> Column:
def array_contains(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
value_col = value if isinstance(value, Column) else lit(value)
- return Column.invoke_expression_over_column(col, glotexp.ArrayContains, expression=value_col.expression)
+ return Column.invoke_expression_over_column(
+ col, glotexp.ArrayContains, expression=value_col.expression
+ )
def arrays_overlap(col1: ColumnOrName, col2: ColumnOrName) -> Column:
return Column.invoke_anonymous_function(col1, "ARRAYS_OVERLAP", Column.ensure_col(col2))
-def slice(x: ColumnOrName, start: t.Union[ColumnOrName, int], length: t.Union[ColumnOrName, int]) -> Column:
+def slice(
+ x: ColumnOrName, start: t.Union[ColumnOrName, int], length: t.Union[ColumnOrName, int]
+) -> Column:
start_col = start if isinstance(start, Column) else lit(start)
length_col = length if isinstance(length, Column) else lit(length)
return Column.invoke_anonymous_function(x, "SLICE", start_col, length_col)
-def array_join(col: ColumnOrName, delimiter: str, null_replacement: t.Optional[str] = None) -> Column:
+def array_join(
+ col: ColumnOrName, delimiter: str, null_replacement: t.Optional[str] = None
+) -> Column:
if null_replacement is not None:
- return Column.invoke_anonymous_function(col, "ARRAY_JOIN", lit(delimiter), lit(null_replacement))
+ return Column.invoke_anonymous_function(
+ col, "ARRAY_JOIN", lit(delimiter), lit(null_replacement)
+ )
return Column.invoke_anonymous_function(col, "ARRAY_JOIN", lit(delimiter))
def concat(*cols: ColumnOrName) -> Column:
if len(cols) == 1:
return Column.invoke_anonymous_function(cols[0], "CONCAT")
- return Column.invoke_anonymous_function(cols[0], "CONCAT", *[Column.ensure_col(x).expression for x in cols[1:]])
+ return Column.invoke_anonymous_function(
+ cols[0], "CONCAT", *[Column.ensure_col(x).expression for x in cols[1:]]
+ )
def array_position(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
@@ -1076,7 +1115,9 @@ def map_concat(*cols: t.Union[ColumnOrName, t.Iterable[ColumnOrName]]) -> Column
return Column.invoke_anonymous_function(columns[0], "MAP_CONCAT", *columns[1:])
-def sequence(start: ColumnOrName, stop: ColumnOrName, step: t.Optional[ColumnOrName] = None) -> Column:
+def sequence(
+ start: ColumnOrName, stop: ColumnOrName, step: t.Optional[ColumnOrName] = None
+) -> Column:
if step is not None:
return Column.invoke_anonymous_function(start, "SEQUENCE", stop, step)
return Column.invoke_anonymous_function(start, "SEQUENCE", stop)
@@ -1103,12 +1144,15 @@ def aggregate(
merge_exp = _get_lambda_from_func(merge)
if finish is not None:
finish_exp = _get_lambda_from_func(finish)
- return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp), Column(finish_exp))
+ return Column.invoke_anonymous_function(
+ col, "AGGREGATE", initialValue, Column(merge_exp), Column(finish_exp)
+ )
return Column.invoke_anonymous_function(col, "AGGREGATE", initialValue, Column(merge_exp))
def transform(
- col: ColumnOrName, f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]]
+ col: ColumnOrName,
+ f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]],
) -> Column:
f_expression = _get_lambda_from_func(f)
return Column.invoke_anonymous_function(col, "TRANSFORM", Column(f_expression))
@@ -1124,12 +1168,17 @@ def forall(col: ColumnOrName, f: t.Callable[[Column], Column]) -> Column:
return Column.invoke_anonymous_function(col, "FORALL", Column(f_expression))
-def filter(col: ColumnOrName, f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]]) -> Column:
+def filter(
+ col: ColumnOrName,
+ f: t.Union[t.Callable[[Column], Column], t.Callable[[Column, Column], Column]],
+) -> Column:
f_expression = _get_lambda_from_func(f)
return Column.invoke_expression_over_column(col, glotexp.ArrayFilter, expression=f_expression)
-def zip_with(left: ColumnOrName, right: ColumnOrName, f: t.Callable[[Column, Column], Column]) -> Column:
+def zip_with(
+ left: ColumnOrName, right: ColumnOrName, f: t.Callable[[Column, Column], Column]
+) -> Column:
f_expression = _get_lambda_from_func(f)
return Column.invoke_anonymous_function(left, "ZIP_WITH", right, Column(f_expression))
@@ -1163,7 +1212,10 @@ def _lambda_quoted(value: str) -> t.Optional[bool]:
def _get_lambda_from_func(lambda_expression: t.Callable):
- variables = [glotexp.to_identifier(x, quoted=_lambda_quoted(x)) for x in lambda_expression.__code__.co_varnames]
+ variables = [
+ glotexp.to_identifier(x, quoted=_lambda_quoted(x))
+ for x in lambda_expression.__code__.co_varnames
+ ]
return glotexp.Lambda(
this=lambda_expression(*[Column(x) for x in variables]).expression,
expressions=variables,
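
Note: `_get_lambda_from_func` above leans on standard CPython introspection to recover a lambda's parameter names. The mechanism in isolation:

    f = lambda x, i: x + i

    # co_varnames lists parameters first; for a simple lambda with no extra
    # locals it is exactly the tuple the code above iterates over
    print(f.__code__.co_varnames)  # ('x', 'i')
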
diff --git a/sqlglot/dataframe/sql/group.py b/sqlglot/dataframe/sql/group.py
index 947aace..ba27c17 100644
--- a/sqlglot/dataframe/sql/group.py
+++ b/sqlglot/dataframe/sql/group.py
@@ -17,7 +17,9 @@ class GroupedData:
self.last_op = last_op
self.group_by_cols = group_by_cols
- def _get_function_applied_columns(self, func_name: str, cols: t.Tuple[str, ...]) -> t.List[Column]:
+ def _get_function_applied_columns(
+ self, func_name: str, cols: t.Tuple[str, ...]
+ ) -> t.List[Column]:
func_name = func_name.lower()
return [getattr(F, func_name)(name).alias(f"{func_name}({name})") for name in cols]
@@ -30,9 +32,9 @@ class GroupedData:
)
cols = self._df._ensure_and_normalize_cols(columns)
- expression = self._df.expression.group_by(*[x.expression for x in self.group_by_cols]).select(
- *[x.expression for x in self.group_by_cols + cols], append=False
- )
+ expression = self._df.expression.group_by(
+ *[x.expression for x in self.group_by_cols]
+ ).select(*[x.expression for x in self.group_by_cols + cols], append=False)
return self._df.copy(expression=expression)
def count(self) -> DataFrame:
diff --git a/sqlglot/dataframe/sql/normalize.py b/sqlglot/dataframe/sql/normalize.py
index 1513946..75feba7 100644
--- a/sqlglot/dataframe/sql/normalize.py
+++ b/sqlglot/dataframe/sql/normalize.py
@@ -23,7 +23,9 @@ def normalize(spark: SparkSession, expression_context: exp.Select, expr: t.List[
replace_branch_and_sequence_ids_with_cte_name(spark, expression_context, identifier)
-def replace_alias_name_with_cte_name(spark: SparkSession, expression_context: exp.Select, id: exp.Identifier):
+def replace_alias_name_with_cte_name(
+ spark: SparkSession, expression_context: exp.Select, id: exp.Identifier
+):
if id.alias_or_name in spark.name_to_sequence_id_mapping:
for cte in reversed(expression_context.ctes):
if cte.args["sequence_id"] in spark.name_to_sequence_id_mapping[id.alias_or_name]:
@@ -40,8 +42,12 @@ def replace_branch_and_sequence_ids_with_cte_name(
# id then it keeps that reference. This handles the weird edge case in spark that shouldn't
# be common in practice
if expression_context.args.get("joins") and id.alias_or_name in spark.known_branch_ids:
- join_table_aliases = [x.alias_or_name for x in get_tables_from_expression_with_join(expression_context)]
- ctes_in_join = [cte for cte in expression_context.ctes if cte.alias_or_name in join_table_aliases]
+ join_table_aliases = [
+ x.alias_or_name for x in get_tables_from_expression_with_join(expression_context)
+ ]
+ ctes_in_join = [
+ cte for cte in expression_context.ctes if cte.alias_or_name in join_table_aliases
+ ]
if ctes_in_join[0].args["branch_id"] == ctes_in_join[1].args["branch_id"]:
assert len(ctes_in_join) == 2
_set_alias_name(id, ctes_in_join[0].alias_or_name)
@@ -58,7 +64,6 @@ def _set_alias_name(id: exp.Identifier, name: str):
def _ensure_expressions(values: t.List[NORMALIZE_INPUT]) -> t.List[exp.Expression]:
- values = ensure_list(values)
results = []
for value in values:
if isinstance(value, str):
diff --git a/sqlglot/dataframe/sql/readwriter.py b/sqlglot/dataframe/sql/readwriter.py
index 4830035..febc664 100644
--- a/sqlglot/dataframe/sql/readwriter.py
+++ b/sqlglot/dataframe/sql/readwriter.py
@@ -19,12 +19,19 @@ class DataFrameReader:
from sqlglot.dataframe.sql.dataframe import DataFrame
sqlglot.schema.add_table(tableName)
- return DataFrame(self.spark, exp.Select().from_(tableName).select(*sqlglot.schema.column_names(tableName)))
+ return DataFrame(
+ self.spark,
+ exp.Select().from_(tableName).select(*sqlglot.schema.column_names(tableName)),
+ )
class DataFrameWriter:
def __init__(
- self, df: DataFrame, spark: t.Optional[SparkSession] = None, mode: t.Optional[str] = None, by_name: bool = False
+ self,
+ df: DataFrame,
+ spark: t.Optional[SparkSession] = None,
+ mode: t.Optional[str] = None,
+ by_name: bool = False,
):
self._df = df
self._spark = spark or df.spark
@@ -33,7 +40,10 @@ class DataFrameWriter:
def copy(self, **kwargs) -> DataFrameWriter:
return DataFrameWriter(
- **{k[1:] if k.startswith("_") else k: v for k, v in object_to_dict(self, **kwargs).items()}
+ **{
+ k[1:] if k.startswith("_") else k: v
+ for k, v in object_to_dict(self, **kwargs).items()
+ }
)
def sql(self, **kwargs) -> t.List[str]:
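
Note: `DataFrameReader.table` above registers the name and selects whatever columns the global schema knows about, so the table has to be described first. A usage sketch, assuming the usual spark-mirroring `read` property on the session (table name and columns are illustrative):

    import sqlglot
    from sqlglot.dataframe.sql.session import SparkSession

    sqlglot.schema.add_table("employee", {"employee_id": "INT", "age": "INT"})
    df = SparkSession().read.table("employee")
    print(df.sql()[0])
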
diff --git a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py
index 1ea86d1..8cb16ef 100644
--- a/sqlglot/dataframe/sql/session.py
+++ b/sqlglot/dataframe/sql/session.py
@@ -67,13 +67,20 @@ class SparkSession:
data_expressions = [
exp.Tuple(
- expressions=list(map(lambda x: F.lit(x).expression, row if not isinstance(row, dict) else row.values()))
+ expressions=list(
+ map(
+ lambda x: F.lit(x).expression,
+ row if not isinstance(row, dict) else row.values(),
+ )
+ )
)
for row in data
]
sel_columns = [
- F.col(name).cast(data_type).alias(name).expression if data_type is not None else F.col(name).expression
+ F.col(name).cast(data_type).alias(name).expression
+ if data_type is not None
+ else F.col(name).expression
for name, data_type in column_mapping.items()
]
@@ -106,10 +113,12 @@ class SparkSession:
select_expression.set("with", expression.args.get("with"))
expression.set("with", None)
del expression.args["expression"]
- df = DataFrame(self, select_expression, output_expression_container=expression)
+ df = DataFrame(self, select_expression, output_expression_container=expression) # type: ignore
df = df._convert_leaf_to_cte()
else:
- raise ValueError("Unknown expression type provided in the SQL. Please create an issue with the SQL.")
+ raise ValueError(
+ "Unknown expression type provided in the SQL. Please create an issue with the SQL."
+ )
return df
@property
diff --git a/sqlglot/dataframe/sql/types.py b/sqlglot/dataframe/sql/types.py
index dc5c05a..a63e505 100644
--- a/sqlglot/dataframe/sql/types.py
+++ b/sqlglot/dataframe/sql/types.py
@@ -158,7 +158,11 @@ class MapType(DataType):
class StructField(DataType):
def __init__(
- self, name: str, dataType: DataType, nullable: bool = True, metadata: t.Optional[t.Dict[str, t.Any]] = None
+ self,
+ name: str,
+ dataType: DataType,
+ nullable: bool = True,
+ metadata: t.Optional[t.Dict[str, t.Any]] = None,
):
self.name = name
self.dataType = dataType
diff --git a/sqlglot/dataframe/sql/window.py b/sqlglot/dataframe/sql/window.py
index 842f366..c54c07e 100644
--- a/sqlglot/dataframe/sql/window.py
+++ b/sqlglot/dataframe/sql/window.py
@@ -74,8 +74,13 @@ class WindowSpec:
window_spec.expression.args["order"].set("expressions", order_by)
return window_spec
- def _calc_start_end(self, start: int, end: int) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]:
- kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = {"start_side": None, "end_side": None}
+ def _calc_start_end(
+ self, start: int, end: int
+ ) -> t.Dict[str, t.Optional[t.Union[str, exp.Expression]]]:
+ kwargs: t.Dict[str, t.Optional[t.Union[str, exp.Expression]]] = {
+ "start_side": None,
+ "end_side": None,
+ }
if start == Window.currentRow:
kwargs["start"] = "CURRENT ROW"
else:
@@ -83,7 +88,9 @@ class WindowSpec:
**kwargs,
**{
"start_side": "PRECEDING",
- "start": "UNBOUNDED" if start <= Window.unboundedPreceding else F.lit(start).expression,
+ "start": "UNBOUNDED"
+ if start <= Window.unboundedPreceding
+ else F.lit(start).expression,
},
}
if end == Window.currentRow:
@@ -93,7 +100,9 @@ class WindowSpec:
**kwargs,
**{
"end_side": "FOLLOWING",
- "end": "UNBOUNDED" if end >= Window.unboundedFollowing else F.lit(end).expression,
+ "end": "UNBOUNDED"
+ if end >= Window.unboundedFollowing
+ else F.lit(end).expression,
},
}
return kwargs
@@ -103,7 +112,10 @@ class WindowSpec:
spec = self._calc_start_end(start, end)
spec["kind"] = "ROWS"
window_spec.expression.set(
- "spec", exp.WindowSpec(**{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec})
+ "spec",
+ exp.WindowSpec(
+ **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
+ ),
)
return window_spec
@@ -112,6 +124,9 @@ class WindowSpec:
spec = self._calc_start_end(start, end)
spec["kind"] = "RANGE"
window_spec.expression.set(
- "spec", exp.WindowSpec(**{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec})
+ "spec",
+ exp.WindowSpec(
+ **{**window_spec.expression.args.get("spec", exp.WindowSpec()).args, **spec}
+ ),
)
return window_spec
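
Note: `_calc_start_end` maps PySpark's sentinel frame bounds onto WindowSpec args, and rowsBetween/rangeBetween stamp the frame kind. A sketch through the public surface, assuming the usual PySpark-mirroring entry points on Window (output not asserted):

    from sqlglot.dataframe.sql import functions as F
    from sqlglot.dataframe.sql.window import Window

    w = (
        Window.partitionBy("dept")
        .orderBy("salary")
        .rowsBetween(Window.unboundedPreceding, Window.currentRow)
    )
    # a running-total window expression rendered as spark SQL
    print(F.sum("salary").over(w).sql())
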
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 62d042e..5bbff9d 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -1,21 +1,21 @@
-from sqlglot import exp
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
inline_array_sql,
no_ilike_sql,
rename_func,
)
-from sqlglot.generator import Generator
-from sqlglot.helper import list_get
-from sqlglot.parser import Parser
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.helper import seq_get
+from sqlglot.tokens import TokenType
def _date_add(expression_class):
def func(args):
- interval = list_get(args, 1)
+ interval = seq_get(args, 1)
return expression_class(
- this=list_get(args, 0),
+ this=seq_get(args, 0),
expression=interval.this,
unit=interval.args.get("unit"),
)
@@ -23,6 +23,13 @@ def _date_add(expression_class):
return func
+def _date_trunc(args):
+ unit = seq_get(args, 1)
+ if isinstance(unit, exp.Column):
+ unit = exp.Var(this=unit.name)
+ return exp.DateTrunc(this=seq_get(args, 0), expression=unit)
+
+
def _date_add_sql(data_type, kind):
def func(self, expression):
this = self.sql(expression, "this")
@@ -40,7 +47,8 @@ def _derived_table_values_to_unnest(self, expression):
structs = []
for row in rows:
aliases = [
- exp.alias_(value, column_name) for value, column_name in zip(row, expression.args["alias"].args["columns"])
+ exp.alias_(value, column_name)
+ for value, column_name in zip(row, expression.args["alias"].args["columns"])
]
structs.append(exp.Struct(expressions=aliases))
unnest_exp = exp.Unnest(expressions=[exp.Array(expressions=structs)])
@@ -89,18 +97,19 @@ class BigQuery(Dialect):
"%j": "%-j",
}
- class Tokenizer(Tokenizer):
+ class Tokenizer(tokens.Tokenizer):
QUOTES = [
(prefix + quote, quote) if prefix else quote
for quote in ["'", '"', '"""', "'''"]
for prefix in ["", "r", "R"]
]
+ COMMENTS = ["--", "#", ("/*", "*/")]
IDENTIFIERS = ["`"]
- ESCAPE = "\\"
+ ESCAPES = ["\\"]
HEX_STRINGS = [("0x", ""), ("0X", "")]
KEYWORDS = {
- **Tokenizer.KEYWORDS,
+ **tokens.Tokenizer.KEYWORDS,
"CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
"CURRENT_TIME": TokenType.CURRENT_TIME,
"GEOGRAPHY": TokenType.GEOGRAPHY,
@@ -111,35 +120,40 @@ class BigQuery(Dialect):
"WINDOW": TokenType.WINDOW,
"NOT DETERMINISTIC": TokenType.VOLATILE,
}
+ KEYWORDS.pop("DIV")
- class Parser(Parser):
+ class Parser(parser.Parser):
FUNCTIONS = {
- **Parser.FUNCTIONS,
+ **parser.Parser.FUNCTIONS,
+ "DATE_TRUNC": _date_trunc,
"DATE_ADD": _date_add(exp.DateAdd),
"DATETIME_ADD": _date_add(exp.DatetimeAdd),
+ "DIV": lambda args: exp.IntDiv(this=seq_get(args, 0), expression=seq_get(args, 1)),
"TIME_ADD": _date_add(exp.TimeAdd),
"TIMESTAMP_ADD": _date_add(exp.TimestampAdd),
"DATE_SUB": _date_add(exp.DateSub),
"DATETIME_SUB": _date_add(exp.DatetimeSub),
"TIME_SUB": _date_add(exp.TimeSub),
"TIMESTAMP_SUB": _date_add(exp.TimestampSub),
- "PARSE_TIMESTAMP": lambda args: exp.StrToTime(this=list_get(args, 1), format=list_get(args, 0)),
+ "PARSE_TIMESTAMP": lambda args: exp.StrToTime(
+ this=seq_get(args, 1), format=seq_get(args, 0)
+ ),
}
NO_PAREN_FUNCTIONS = {
- **Parser.NO_PAREN_FUNCTIONS,
+ **parser.Parser.NO_PAREN_FUNCTIONS,
TokenType.CURRENT_DATETIME: exp.CurrentDatetime,
TokenType.CURRENT_TIME: exp.CurrentTime,
}
NESTED_TYPE_TOKENS = {
- *Parser.NESTED_TYPE_TOKENS,
+ *parser.Parser.NESTED_TYPE_TOKENS,
TokenType.TABLE,
}
- class Generator(Generator):
+ class Generator(generator.Generator):
TRANSFORMS = {
- **Generator.TRANSFORMS,
+ **generator.Generator.TRANSFORMS,
exp.Array: inline_array_sql,
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.DateAdd: _date_add_sql("DATE", "ADD"),
@@ -148,6 +162,7 @@ class BigQuery(Dialect):
exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"),
exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})",
exp.ILike: no_ilike_sql,
+ exp.IntDiv: rename_func("DIV"),
exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})",
exp.TimeAdd: _date_add_sql("TIME", "ADD"),
exp.TimeSub: _date_add_sql("TIME", "SUB"),
@@ -157,11 +172,13 @@ class BigQuery(Dialect):
exp.Values: _derived_table_values_to_unnest,
exp.ReturnsProperty: _returnsproperty_sql,
exp.Create: _create_sql,
- exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC" if e.name == "IMMUTABLE" else "NOT DETERMINISTIC",
+ exp.VolatilityProperty: lambda self, e: f"DETERMINISTIC"
+ if e.name == "IMMUTABLE"
+ else "NOT DETERMINISTIC",
}
TYPE_MAPPING = {
- **Generator.TYPE_MAPPING,
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.TINYINT: "INT64",
exp.DataType.Type.SMALLINT: "INT64",
exp.DataType.Type.INT: "INT64",
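
Note: the new DIV handling can be checked from the public API: parsing produces the new exp.IntDiv node, and the `rename_func("DIV")` transform should render it back. A quick sketch, not an exhaustive test:

    import sqlglot
    from sqlglot import exp

    node = sqlglot.parse_one("SELECT DIV(x, y)", read="bigquery")
    assert node.find(exp.IntDiv) is not None
    print(node.sql(dialect="bigquery"))  # expected to round-trip as DIV(x, y)
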
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index f446e6d..332b4c1 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -1,8 +1,9 @@
-from sqlglot import exp
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql
-from sqlglot.generator import Generator
-from sqlglot.parser import Parser, parse_var_map
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.parser import parse_var_map
+from sqlglot.tokens import TokenType
def _lower_func(sql):
@@ -14,11 +15,12 @@ class ClickHouse(Dialect):
normalize_functions = None
null_ordering = "nulls_are_last"
- class Tokenizer(Tokenizer):
+ class Tokenizer(tokens.Tokenizer):
+ COMMENTS = ["--", "#", "#!", ("/*", "*/")]
IDENTIFIERS = ['"', "`"]
KEYWORDS = {
- **Tokenizer.KEYWORDS,
+ **tokens.Tokenizer.KEYWORDS,
"FINAL": TokenType.FINAL,
"DATETIME64": TokenType.DATETIME,
"INT8": TokenType.TINYINT,
@@ -30,9 +32,9 @@ class ClickHouse(Dialect):
"TUPLE": TokenType.STRUCT,
}
- class Parser(Parser):
+ class Parser(parser.Parser):
FUNCTIONS = {
- **Parser.FUNCTIONS,
+ **parser.Parser.FUNCTIONS,
"MAP": parse_var_map,
}
@@ -44,11 +46,11 @@ class ClickHouse(Dialect):
return this
- class Generator(Generator):
+ class Generator(generator.Generator):
STRUCT_DELIMITER = ("(", ")")
TYPE_MAPPING = {
- **Generator.TYPE_MAPPING,
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.NULLABLE: "Nullable",
exp.DataType.Type.DATETIME: "DateTime64",
exp.DataType.Type.MAP: "Map",
@@ -63,7 +65,7 @@ class ClickHouse(Dialect):
}
TRANSFORMS = {
- **Generator.TRANSFORMS,
+ **generator.Generator.TRANSFORMS,
exp.Array: inline_array_sql,
exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
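
Note: with `#` and `#!` added as comment markers, ClickHouse-style trailing comments tokenize cleanly. A quick check via transpile (the comment text itself may be dropped from the output):

    import sqlglot

    # the '#'-comment no longer breaks tokenization under the clickhouse dialect
    print(sqlglot.transpile("SELECT 1 # a clickhouse comment", read="clickhouse")[0])
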
diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py
index 9dc3c38..2498c62 100644
--- a/sqlglot/dialects/databricks.py
+++ b/sqlglot/dialects/databricks.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from sqlglot import exp
from sqlglot.dialects.dialect import parse_date_delta
from sqlglot.dialects.spark import Spark
@@ -15,7 +17,7 @@ class Databricks(Spark):
class Generator(Spark.Generator):
TRANSFORMS = {
- **Spark.Generator.TRANSFORMS,
+ **Spark.Generator.TRANSFORMS, # type: ignore
exp.DateAdd: generate_date_delta_with_unit_sql,
exp.DateDiff: generate_date_delta_with_unit_sql,
}
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 33985a7..3af08bb 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -1,8 +1,11 @@
+from __future__ import annotations
+
+import typing as t
from enum import Enum
from sqlglot import exp
from sqlglot.generator import Generator
-from sqlglot.helper import flatten, list_get
+from sqlglot.helper import flatten, seq_get
from sqlglot.parser import Parser
from sqlglot.time import format_time
from sqlglot.tokens import Tokenizer
@@ -32,7 +35,7 @@ class Dialects(str, Enum):
class _Dialect(type):
- classes = {}
+ classes: t.Dict[str, Dialect] = {}
@classmethod
def __getitem__(cls, key):
@@ -56,19 +59,30 @@ class _Dialect(type):
klass.generator_class = getattr(klass, "Generator", Generator)
klass.quote_start, klass.quote_end = list(klass.tokenizer_class._QUOTES.items())[0]
- klass.identifier_start, klass.identifier_end = list(klass.tokenizer_class._IDENTIFIERS.items())[0]
-
- if klass.tokenizer_class._BIT_STRINGS and exp.BitString not in klass.generator_class.TRANSFORMS:
+ klass.identifier_start, klass.identifier_end = list(
+ klass.tokenizer_class._IDENTIFIERS.items()
+ )[0]
+
+ if (
+ klass.tokenizer_class._BIT_STRINGS
+ and exp.BitString not in klass.generator_class.TRANSFORMS
+ ):
bs_start, bs_end = list(klass.tokenizer_class._BIT_STRINGS.items())[0]
klass.generator_class.TRANSFORMS[
exp.BitString
] = lambda self, e: f"{bs_start}{int(self.sql(e, 'this')):b}{bs_end}"
- if klass.tokenizer_class._HEX_STRINGS and exp.HexString not in klass.generator_class.TRANSFORMS:
+ if (
+ klass.tokenizer_class._HEX_STRINGS
+ and exp.HexString not in klass.generator_class.TRANSFORMS
+ ):
hs_start, hs_end = list(klass.tokenizer_class._HEX_STRINGS.items())[0]
klass.generator_class.TRANSFORMS[
exp.HexString
] = lambda self, e: f"{hs_start}{int(self.sql(e, 'this')):X}{hs_end}"
- if klass.tokenizer_class._BYTE_STRINGS and exp.ByteString not in klass.generator_class.TRANSFORMS:
+ if (
+ klass.tokenizer_class._BYTE_STRINGS
+ and exp.ByteString not in klass.generator_class.TRANSFORMS
+ ):
be_start, be_end = list(klass.tokenizer_class._BYTE_STRINGS.items())[0]
klass.generator_class.TRANSFORMS[
exp.ByteString
@@ -81,13 +95,13 @@ class Dialect(metaclass=_Dialect):
index_offset = 0
unnest_column_only = False
alias_post_tablesample = False
- normalize_functions = "upper"
+ normalize_functions: t.Optional[str] = "upper"
null_ordering = "nulls_are_small"
date_format = "'%Y-%m-%d'"
dateint_format = "'%Y%m%d'"
time_format = "'%Y-%m-%d %H:%M:%S'"
- time_mapping = {}
+ time_mapping: t.Dict[str, str] = {}
# autofilled
quote_start = None
@@ -167,7 +181,7 @@ class Dialect(metaclass=_Dialect):
"quote_end": self.quote_end,
"identifier_start": self.identifier_start,
"identifier_end": self.identifier_end,
- "escape": self.tokenizer_class.ESCAPE,
+ "escape": self.tokenizer_class.ESCAPES[0],
"index_offset": self.index_offset,
"time_mapping": self.inverse_time_mapping,
"time_trie": self.inverse_time_trie,
@@ -195,7 +209,9 @@ def approx_count_distinct_sql(self, expression):
def if_sql(self, expression):
- expressions = self.format_args(expression.this, expression.args.get("true"), expression.args.get("false"))
+ expressions = self.format_args(
+ expression.this, expression.args.get("true"), expression.args.get("false")
+ )
return f"IF({expressions})"
@@ -298,9 +314,9 @@ def format_time_lambda(exp_class, dialect, default=None):
def _format_time(args):
return exp_class(
- this=list_get(args, 0),
+ this=seq_get(args, 0),
format=Dialect[dialect].format_time(
- list_get(args, 1) or (Dialect[dialect].time_format if default is True else default)
+ seq_get(args, 1) or (Dialect[dialect].time_format if default is True else default)
),
)
@@ -328,7 +344,9 @@ def create_with_partitions_sql(self, expression):
"expressions",
[e for e in schema.expressions if e not in partitions],
)
- prop.replace(exp.PartitionedByProperty(this=prop.this, value=exp.Schema(expressions=partitions)))
+ prop.replace(
+ exp.PartitionedByProperty(this=prop.this, value=exp.Schema(expressions=partitions))
+ )
expression.set("this", schema)
return self.create_sql(expression)
@@ -337,9 +355,9 @@ def create_with_partitions_sql(self, expression):
def parse_date_delta(exp_class, unit_mapping=None):
def inner_func(args):
unit_based = len(args) == 3
- this = list_get(args, 2) if unit_based else list_get(args, 0)
- expression = list_get(args, 1) if unit_based else list_get(args, 1)
- unit = list_get(args, 0) if unit_based else exp.Literal.string("DAY")
+ this = seq_get(args, 2) if unit_based else seq_get(args, 0)
+ expression = seq_get(args, 1)
+ unit = seq_get(args, 0) if unit_based else exp.Literal.string("DAY")
unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit
return exp_class(this=this, expression=expression, unit=unit)
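
list_get is renamed to seq_get throughout this commit. A hedged sketch of what such a helper does; the real implementation lives in sqlglot.helper and may differ in details:

import typing as t

def seq_get(seq: t.Sequence[t.Any], index: int) -> t.Optional[t.Any]:
    """Return seq[index], or None when the index is out of range."""
    try:
        return seq[index]
    except IndexError:
        return None

args = ["a"]
assert seq_get(args, 0) == "a"
assert seq_get(args, 1) is None  # a missing optional SQL argument just becomes None
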
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index f3ff6d3..781edff 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -1,4 +1,6 @@
-from sqlglot import exp
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
approx_count_distinct_sql,
@@ -12,10 +14,8 @@ from sqlglot.dialects.dialect import (
rename_func,
str_position_sql,
)
-from sqlglot.generator import Generator
-from sqlglot.helper import list_get
-from sqlglot.parser import Parser
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.helper import seq_get
+from sqlglot.tokens import TokenType
def _unix_to_time(self, expression):
@@ -61,11 +61,14 @@ def _sort_array_sql(self, expression):
def _sort_array_reverse(args):
- return exp.SortArray(this=list_get(args, 0), asc=exp.FALSE)
+ return exp.SortArray(this=seq_get(args, 0), asc=exp.FALSE)
def _struct_pack_sql(self, expression):
- args = [self.binary(e, ":=") if isinstance(e, exp.EQ) else self.sql(e) for e in expression.expressions]
+ args = [
+ self.binary(e, ":=") if isinstance(e, exp.EQ) else self.sql(e)
+ for e in expression.expressions
+ ]
return f"STRUCT_PACK({', '.join(args)})"
@@ -76,15 +79,15 @@ def _datatype_sql(self, expression):
class DuckDB(Dialect):
- class Tokenizer(Tokenizer):
+ class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
- **Tokenizer.KEYWORDS,
+ **tokens.Tokenizer.KEYWORDS,
":=": TokenType.EQ,
}
- class Parser(Parser):
+ class Parser(parser.Parser):
FUNCTIONS = {
- **Parser.FUNCTIONS,
+ **parser.Parser.FUNCTIONS,
"APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
"ARRAY_LENGTH": exp.ArraySize.from_arg_list,
"ARRAY_SORT": exp.SortArray.from_arg_list,
@@ -92,7 +95,7 @@ class DuckDB(Dialect):
"EPOCH": exp.TimeToUnix.from_arg_list,
"EPOCH_MS": lambda args: exp.UnixToTime(
this=exp.Div(
- this=list_get(args, 0),
+ this=seq_get(args, 0),
expression=exp.Literal.number(1000),
)
),
@@ -112,11 +115,11 @@ class DuckDB(Dialect):
"UNNEST": exp.Explode.from_arg_list,
}
- class Generator(Generator):
+ class Generator(generator.Generator):
STRUCT_DELIMITER = ("(", ")")
TRANSFORMS = {
- **Generator.TRANSFORMS,
+ **generator.Generator.TRANSFORMS,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.Array: rename_func("LIST_VALUE"),
exp.ArraySize: rename_func("ARRAY_LENGTH"),
@@ -160,7 +163,7 @@ class DuckDB(Dialect):
}
TYPE_MAPPING = {
- **Generator.TYPE_MAPPING,
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.VARCHAR: "TEXT",
exp.DataType.Type.NVARCHAR: "TEXT",
}
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index 03049ff..ed7357c 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -1,4 +1,6 @@
-from sqlglot import exp, transforms
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
approx_count_distinct_sql,
@@ -13,10 +15,8 @@ from sqlglot.dialects.dialect import (
struct_extract_sql,
var_map_sql,
)
-from sqlglot.generator import Generator
-from sqlglot.helper import list_get
-from sqlglot.parser import Parser, parse_var_map
-from sqlglot.tokens import Tokenizer
+from sqlglot.helper import seq_get
+from sqlglot.parser import parse_var_map
# (FuncType, Multiplier)
DATE_DELTA_INTERVAL = {
@@ -34,7 +34,9 @@ def _add_date_sql(self, expression):
unit = expression.text("unit").upper()
func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1))
modified_increment = (
- int(expression.text("expression")) * multiplier if expression.expression.is_number else expression.expression
+ int(expression.text("expression")) * multiplier
+ if expression.expression.is_number
+ else expression.expression
)
modified_increment = exp.Literal.number(modified_increment)
return f"{func}({self.format_args(expression.this, modified_increment.this)})"
@@ -165,10 +167,10 @@ class Hive(Dialect):
dateint_format = "'yyyyMMdd'"
time_format = "'yyyy-MM-dd HH:mm:ss'"
- class Tokenizer(Tokenizer):
+ class Tokenizer(tokens.Tokenizer):
QUOTES = ["'", '"']
IDENTIFIERS = ["`"]
- ESCAPE = "\\"
+ ESCAPES = ["\\"]
ENCODE = "utf-8"
NUMERIC_LITERALS = {
@@ -180,40 +182,44 @@ class Hive(Dialect):
"BD": "DECIMAL",
}
- class Parser(Parser):
+ class Parser(parser.Parser):
STRICT_CAST = False
FUNCTIONS = {
- **Parser.FUNCTIONS,
+ **parser.Parser.FUNCTIONS,
"APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
"COLLECT_LIST": exp.ArrayAgg.from_arg_list,
"DATE_ADD": lambda args: exp.TsOrDsAdd(
- this=list_get(args, 0),
- expression=list_get(args, 1),
+ this=seq_get(args, 0),
+ expression=seq_get(args, 1),
unit=exp.Literal.string("DAY"),
),
"DATEDIFF": lambda args: exp.DateDiff(
- this=exp.TsOrDsToDate(this=list_get(args, 0)),
- expression=exp.TsOrDsToDate(this=list_get(args, 1)),
+ this=exp.TsOrDsToDate(this=seq_get(args, 0)),
+ expression=exp.TsOrDsToDate(this=seq_get(args, 1)),
),
"DATE_SUB": lambda args: exp.TsOrDsAdd(
- this=list_get(args, 0),
+ this=seq_get(args, 0),
expression=exp.Mul(
- this=list_get(args, 1),
+ this=seq_get(args, 1),
expression=exp.Literal.number(-1),
),
unit=exp.Literal.string("DAY"),
),
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "hive"),
- "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=list_get(args, 0))),
+ "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True),
"GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list,
"LOCATE": lambda args: exp.StrPosition(
- this=list_get(args, 1),
- substr=list_get(args, 0),
- position=list_get(args, 2),
+ this=seq_get(args, 1),
+ substr=seq_get(args, 0),
+ position=seq_get(args, 2),
+ ),
+ "LOG": (
+ lambda args: exp.Log.from_arg_list(args)
+ if len(args) > 1
+ else exp.Ln.from_arg_list(args)
),
- "LOG": (lambda args: exp.Log.from_arg_list(args) if len(args) > 1 else exp.Ln.from_arg_list(args)),
"MAP": parse_var_map,
"MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)),
"PERCENTILE": exp.Quantile.from_arg_list,
@@ -226,15 +232,16 @@ class Hive(Dialect):
"YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)),
}
- class Generator(Generator):
+ class Generator(generator.Generator):
TYPE_MAPPING = {
- **Generator.TYPE_MAPPING,
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.TEXT: "STRING",
+ exp.DataType.Type.VARBINARY: "BINARY",
}
TRANSFORMS = {
- **Generator.TRANSFORMS,
- **transforms.UNALIAS_GROUP,
+ **generator.Generator.TRANSFORMS,
+ **transforms.UNALIAS_GROUP, # type: ignore
exp.AnonymousProperty: _property_sql,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.ArrayAgg: rename_func("COLLECT_LIST"),
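
Among the Hive parser entries above, the reformatted LOG lambda keeps the function's overloading intact: one argument parses as a natural log, two arguments keep an explicit base. A quick check against the public API (hedged: the printed SQL is whatever the default dialect renders):

import sqlglot

print(sqlglot.parse_one("SELECT LOG(x)", read="hive").sql())     # natural log (Ln)
print(sqlglot.parse_one("SELECT LOG(2, x)", read="hive").sql())  # log with base 2
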
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 524390f..e742640 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -1,4 +1,8 @@
-from sqlglot import exp
+from __future__ import annotations
+
+import typing as t
+
+from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
no_ilike_sql,
@@ -6,42 +10,47 @@ from sqlglot.dialects.dialect import (
no_tablesample_sql,
no_trycast_sql,
)
-from sqlglot.generator import Generator
-from sqlglot.helper import list_get
-from sqlglot.parser import Parser
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.helper import seq_get
+from sqlglot.tokens import TokenType
+
+
+def _show_parser(*args, **kwargs):
+ def _parse(self):
+ return self._parse_show_mysql(*args, **kwargs)
+
+ return _parse
def _date_trunc_sql(self, expression):
- unit = expression.text("unit").lower()
+ unit = expression.name.lower()
- this = self.sql(expression.this)
+ expr = self.sql(expression.expression)
if unit == "day":
- return f"DATE({this})"
+ return f"DATE({expr})"
if unit == "week":
- concat = f"CONCAT(YEAR({this}), ' ', WEEK({this}, 1), ' 1')"
+ concat = f"CONCAT(YEAR({expr}), ' ', WEEK({expr}, 1), ' 1')"
date_format = "%Y %u %w"
elif unit == "month":
- concat = f"CONCAT(YEAR({this}), ' ', MONTH({this}), ' 1')"
+ concat = f"CONCAT(YEAR({expr}), ' ', MONTH({expr}), ' 1')"
date_format = "%Y %c %e"
elif unit == "quarter":
- concat = f"CONCAT(YEAR({this}), ' ', QUARTER({this}) * 3 - 2, ' 1')"
+ concat = f"CONCAT(YEAR({expr}), ' ', QUARTER({expr}) * 3 - 2, ' 1')"
date_format = "%Y %c %e"
elif unit == "year":
- concat = f"CONCAT(YEAR({this}), ' 1 1')"
+ concat = f"CONCAT(YEAR({expr}), ' 1 1')"
date_format = "%Y %c %e"
else:
self.unsupported("Unexpected interval unit: {unit}")
- return f"DATE({this})"
+ return f"DATE({expr})"
return f"STR_TO_DATE({concat}, '{date_format}')"
def _str_to_date(args):
- date_format = MySQL.format_time(list_get(args, 1))
- return exp.StrToDate(this=list_get(args, 0), format=date_format)
+ date_format = MySQL.format_time(seq_get(args, 1))
+ return exp.StrToDate(this=seq_get(args, 0), format=date_format)
def _str_to_date_sql(self, expression):
@@ -66,9 +75,9 @@ def _trim_sql(self, expression):
def _date_add(expression_class):
def func(args):
- interval = list_get(args, 1)
+ interval = seq_get(args, 1)
return expression_class(
- this=list_get(args, 0),
+ this=seq_get(args, 0),
expression=interval.this,
unit=exp.Literal.string(interval.text("unit").lower()),
)
@@ -101,15 +110,16 @@ class MySQL(Dialect):
"%l": "%-I",
}
- class Tokenizer(Tokenizer):
+ class Tokenizer(tokens.Tokenizer):
QUOTES = ["'", '"']
COMMENTS = ["--", "#", ("/*", "*/")]
IDENTIFIERS = ["`"]
+ ESCAPES = ["'", "\\"]
BIT_STRINGS = [("b'", "'"), ("B'", "'"), ("0b", "")]
HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", "")]
KEYWORDS = {
- **Tokenizer.KEYWORDS,
+ **tokens.Tokenizer.KEYWORDS,
"SEPARATOR": TokenType.SEPARATOR,
"_ARMSCII8": TokenType.INTRODUCER,
"_ASCII": TokenType.INTRODUCER,
@@ -156,20 +166,23 @@ class MySQL(Dialect):
"_UTF32": TokenType.INTRODUCER,
"_UTF8MB3": TokenType.INTRODUCER,
"_UTF8MB4": TokenType.INTRODUCER,
+ "@@": TokenType.SESSION_PARAMETER,
}
- class Parser(Parser):
+ COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SET, TokenType.SHOW}
+
+ class Parser(parser.Parser):
STRICT_CAST = False
FUNCTIONS = {
- **Parser.FUNCTIONS,
+ **parser.Parser.FUNCTIONS,
"DATE_ADD": _date_add(exp.DateAdd),
"DATE_SUB": _date_add(exp.DateSub),
"STR_TO_DATE": _str_to_date,
}
FUNCTION_PARSERS = {
- **Parser.FUNCTION_PARSERS,
+ **parser.Parser.FUNCTION_PARSERS,
"GROUP_CONCAT": lambda self: self.expression(
exp.GroupConcat,
this=self._parse_lambda(),
@@ -178,15 +191,212 @@ class MySQL(Dialect):
}
PROPERTY_PARSERS = {
- **Parser.PROPERTY_PARSERS,
+ **parser.Parser.PROPERTY_PARSERS,
TokenType.ENGINE: lambda self: self._parse_property_assignment(exp.EngineProperty),
}
- class Generator(Generator):
+ STATEMENT_PARSERS = {
+ **parser.Parser.STATEMENT_PARSERS,
+ TokenType.SHOW: lambda self: self._parse_show(),
+ TokenType.SET: lambda self: self._parse_set(),
+ }
+
+ SHOW_PARSERS = {
+ "BINARY LOGS": _show_parser("BINARY LOGS"),
+ "MASTER LOGS": _show_parser("BINARY LOGS"),
+ "BINLOG EVENTS": _show_parser("BINLOG EVENTS"),
+ "CHARACTER SET": _show_parser("CHARACTER SET"),
+ "CHARSET": _show_parser("CHARACTER SET"),
+ "COLLATION": _show_parser("COLLATION"),
+ "FULL COLUMNS": _show_parser("COLUMNS", target="FROM", full=True),
+ "COLUMNS": _show_parser("COLUMNS", target="FROM"),
+ "CREATE DATABASE": _show_parser("CREATE DATABASE", target=True),
+ "CREATE EVENT": _show_parser("CREATE EVENT", target=True),
+ "CREATE FUNCTION": _show_parser("CREATE FUNCTION", target=True),
+ "CREATE PROCEDURE": _show_parser("CREATE PROCEDURE", target=True),
+ "CREATE TABLE": _show_parser("CREATE TABLE", target=True),
+ "CREATE TRIGGER": _show_parser("CREATE TRIGGER", target=True),
+ "CREATE VIEW": _show_parser("CREATE VIEW", target=True),
+ "DATABASES": _show_parser("DATABASES"),
+ "ENGINE": _show_parser("ENGINE", target=True),
+ "STORAGE ENGINES": _show_parser("ENGINES"),
+ "ENGINES": _show_parser("ENGINES"),
+ "ERRORS": _show_parser("ERRORS"),
+ "EVENTS": _show_parser("EVENTS"),
+ "FUNCTION CODE": _show_parser("FUNCTION CODE", target=True),
+ "FUNCTION STATUS": _show_parser("FUNCTION STATUS"),
+ "GRANTS": _show_parser("GRANTS", target="FOR"),
+ "INDEX": _show_parser("INDEX", target="FROM"),
+ "MASTER STATUS": _show_parser("MASTER STATUS"),
+ "OPEN TABLES": _show_parser("OPEN TABLES"),
+ "PLUGINS": _show_parser("PLUGINS"),
+ "PROCEDURE CODE": _show_parser("PROCEDURE CODE", target=True),
+ "PROCEDURE STATUS": _show_parser("PROCEDURE STATUS"),
+ "PRIVILEGES": _show_parser("PRIVILEGES"),
+ "FULL PROCESSLIST": _show_parser("PROCESSLIST", full=True),
+ "PROCESSLIST": _show_parser("PROCESSLIST"),
+ "PROFILE": _show_parser("PROFILE"),
+ "PROFILES": _show_parser("PROFILES"),
+ "RELAYLOG EVENTS": _show_parser("RELAYLOG EVENTS"),
+ "REPLICAS": _show_parser("REPLICAS"),
+ "SLAVE HOSTS": _show_parser("REPLICAS"),
+ "REPLICA STATUS": _show_parser("REPLICA STATUS"),
+ "SLAVE STATUS": _show_parser("REPLICA STATUS"),
+ "GLOBAL STATUS": _show_parser("STATUS", global_=True),
+ "SESSION STATUS": _show_parser("STATUS"),
+ "STATUS": _show_parser("STATUS"),
+ "TABLE STATUS": _show_parser("TABLE STATUS"),
+ "FULL TABLES": _show_parser("TABLES", full=True),
+ "TABLES": _show_parser("TABLES"),
+ "TRIGGERS": _show_parser("TRIGGERS"),
+ "GLOBAL VARIABLES": _show_parser("VARIABLES", global_=True),
+ "SESSION VARIABLES": _show_parser("VARIABLES"),
+ "VARIABLES": _show_parser("VARIABLES"),
+ "WARNINGS": _show_parser("WARNINGS"),
+ }
+
+ SET_PARSERS = {
+ "GLOBAL": lambda self: self._parse_set_item_assignment("GLOBAL"),
+ "PERSIST": lambda self: self._parse_set_item_assignment("PERSIST"),
+ "PERSIST_ONLY": lambda self: self._parse_set_item_assignment("PERSIST_ONLY"),
+ "SESSION": lambda self: self._parse_set_item_assignment("SESSION"),
+ "LOCAL": lambda self: self._parse_set_item_assignment("LOCAL"),
+ "CHARACTER SET": lambda self: self._parse_set_item_charset("CHARACTER SET"),
+ "CHARSET": lambda self: self._parse_set_item_charset("CHARACTER SET"),
+ "NAMES": lambda self: self._parse_set_item_names(),
+ }
+
+ PROFILE_TYPES = {
+ "ALL",
+ "BLOCK IO",
+ "CONTEXT SWITCHES",
+ "CPU",
+ "IPC",
+ "MEMORY",
+ "PAGE FAULTS",
+ "SOURCE",
+ "SWAPS",
+ }
+
+ def _parse_show_mysql(self, this, target=False, full=None, global_=None):
+ if target:
+ if isinstance(target, str):
+ self._match_text(target)
+ target_id = self._parse_id_var()
+ else:
+ target_id = None
+
+ log = self._parse_string() if self._match_text("IN") else None
+
+ if this in {"BINLOG EVENTS", "RELAYLOG EVENTS"}:
+ position = self._parse_number() if self._match_text("FROM") else None
+ db = None
+ else:
+ position = None
+ db = self._parse_id_var() if self._match_text("FROM") else None
+
+ channel = self._parse_id_var() if self._match_text("FOR", "CHANNEL") else None
+
+ like = self._parse_string() if self._match_text("LIKE") else None
+ where = self._parse_where()
+
+ if this == "PROFILE":
+ types = self._parse_csv(self._parse_show_profile_type)
+ query = self._parse_number() if self._match_text("FOR", "QUERY") else None
+ offset = self._parse_number() if self._match_text("OFFSET") else None
+ limit = self._parse_number() if self._match_text("LIMIT") else None
+ else:
+ types, query = None, None
+ offset, limit = self._parse_oldstyle_limit()
+
+ mutex = True if self._match_text("MUTEX") else None
+ mutex = False if self._match_text("STATUS") else mutex
+
+ return self.expression(
+ exp.Show,
+ this=this,
+ target=target_id,
+ full=full,
+ log=log,
+ position=position,
+ db=db,
+ channel=channel,
+ like=like,
+ where=where,
+ types=types,
+ query=query,
+ offset=offset,
+ limit=limit,
+ mutex=mutex,
+ **{"global": global_},
+ )
+
+ def _parse_show_profile_type(self):
+ for type_ in self.PROFILE_TYPES:
+ if self._match_text(*type_.split(" ")):
+ return exp.Var(this=type_)
+ return None
+
+ def _parse_oldstyle_limit(self):
+ limit = None
+ offset = None
+ if self._match_text("LIMIT"):
+ parts = self._parse_csv(self._parse_number)
+ if len(parts) == 1:
+ limit = parts[0]
+ elif len(parts) == 2:
+ limit = parts[1]
+ offset = parts[0]
+ return offset, limit
+
+ def _default_parse_set_item(self):
+ return self._parse_set_item_assignment(kind=None)
+
+ def _parse_set_item_assignment(self, kind):
+ left = self._parse_primary() or self._parse_id_var()
+ if not self._match(TokenType.EQ):
+ self.raise_error("Expected =")
+ right = self._parse_statement() or self._parse_id_var()
+
+ this = self.expression(
+ exp.EQ,
+ this=left,
+ expression=right,
+ )
+
+ return self.expression(
+ exp.SetItem,
+ this=this,
+ kind=kind,
+ )
+
+ def _parse_set_item_charset(self, kind):
+ this = self._parse_string() or self._parse_id_var()
+
+ return self.expression(
+ exp.SetItem,
+ this=this,
+ kind=kind,
+ )
+
+ def _parse_set_item_names(self):
+ charset = self._parse_string() or self._parse_id_var()
+ if self._match_text("COLLATE"):
+ collate = self._parse_string() or self._parse_id_var()
+ else:
+ collate = None
+ return self.expression(
+ exp.SetItem,
+ this=charset,
+ collate=collate,
+ kind="NAMES",
+ )
+
+ class Generator(generator.Generator):
NULL_ORDERING_SUPPORTED = False
TRANSFORMS = {
- **Generator.TRANSFORMS,
+ **generator.Generator.TRANSFORMS,
exp.CurrentDate: no_paren_current_date_sql,
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.ILike: no_ilike_sql,
@@ -199,6 +409,8 @@ class MySQL(Dialect):
exp.StrToDate: _str_to_date_sql,
exp.StrToTime: _str_to_date_sql,
exp.Trim: _trim_sql,
+ exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
+ exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
}
ROOT_PROPERTIES = {
@@ -209,4 +421,69 @@ class MySQL(Dialect):
exp.SchemaCommentProperty,
}
- WITH_PROPERTIES = {}
+ WITH_PROPERTIES: t.Set[t.Type[exp.Property]] = set()
+
+ def show_sql(self, expression):
+ this = f" {expression.name}"
+ full = " FULL" if expression.args.get("full") else ""
+ global_ = " GLOBAL" if expression.args.get("global") else ""
+
+ target = self.sql(expression, "target")
+ target = f" {target}" if target else ""
+ if expression.name in {"COLUMNS", "INDEX"}:
+ target = f" FROM{target}"
+ elif expression.name == "GRANTS":
+ target = f" FOR{target}"
+
+ db = self._prefixed_sql("FROM", expression, "db")
+
+ like = self._prefixed_sql("LIKE", expression, "like")
+ where = self.sql(expression, "where")
+
+ types = self.expressions(expression, key="types")
+ types = f" {types}" if types else types
+ query = self._prefixed_sql("FOR QUERY", expression, "query")
+
+ if expression.name == "PROFILE":
+ offset = self._prefixed_sql("OFFSET", expression, "offset")
+ limit = self._prefixed_sql("LIMIT", expression, "limit")
+ else:
+ offset = ""
+ limit = self._oldstyle_limit_sql(expression)
+
+ log = self._prefixed_sql("IN", expression, "log")
+ position = self._prefixed_sql("FROM", expression, "position")
+
+ channel = self._prefixed_sql("FOR CHANNEL", expression, "channel")
+
+ if expression.name == "ENGINE":
+ mutex_or_status = " MUTEX" if expression.args.get("mutex") else " STATUS"
+ else:
+ mutex_or_status = ""
+
+ return f"SHOW{full}{global_}{this}{target}{types}{db}{query}{log}{position}{channel}{mutex_or_status}{like}{where}{offset}{limit}"
+
+ def _prefixed_sql(self, prefix, expression, arg):
+ sql = self.sql(expression, arg)
+ if not sql:
+ return ""
+ return f" {prefix} {sql}"
+
+ def _oldstyle_limit_sql(self, expression):
+ limit = self.sql(expression, "limit")
+ offset = self.sql(expression, "offset")
+ if limit:
+ limit_offset = f"{offset}, {limit}" if offset else limit
+ return f" LIMIT {limit_offset}"
+ return ""
+
+ def setitem_sql(self, expression):
+ kind = self.sql(expression, "kind")
+ kind = f"{kind} " if kind else ""
+ this = self.sql(expression, "this")
+ collate = self.sql(expression, "collate")
+ collate = f" COLLATE {collate}" if collate else ""
+ return f"{kind}{this}{collate}"
+
+ def set_sql(self, expression):
+ return f"SET {self.expressions(expression)}"
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index 144dba5..3bc1109 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -1,8 +1,9 @@
-from sqlglot import exp, transforms
+from __future__ import annotations
+
+from sqlglot import exp, generator, tokens, transforms
from sqlglot.dialects.dialect import Dialect, no_ilike_sql
-from sqlglot.generator import Generator
from sqlglot.helper import csv
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.tokens import TokenType
def _limit_sql(self, expression):
@@ -36,9 +37,9 @@ class Oracle(Dialect):
"YYYY": "%Y", # 2015
}
- class Generator(Generator):
+ class Generator(generator.Generator):
TYPE_MAPPING = {
- **Generator.TYPE_MAPPING,
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.TINYINT: "NUMBER",
exp.DataType.Type.SMALLINT: "NUMBER",
exp.DataType.Type.INT: "NUMBER",
@@ -49,11 +50,12 @@ class Oracle(Dialect):
exp.DataType.Type.NVARCHAR: "NVARCHAR2",
exp.DataType.Type.TEXT: "CLOB",
exp.DataType.Type.BINARY: "BLOB",
+ exp.DataType.Type.VARBINARY: "BLOB",
}
TRANSFORMS = {
- **Generator.TRANSFORMS,
- **transforms.UNALIAS_GROUP,
+ **generator.Generator.TRANSFORMS,
+ **transforms.UNALIAS_GROUP, # type: ignore
exp.ILike: no_ilike_sql,
exp.Limit: _limit_sql,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
@@ -86,9 +88,9 @@ class Oracle(Dialect):
def table_sql(self, expression):
return super().table_sql(expression, sep=" ")
- class Tokenizer(Tokenizer):
+ class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
- **Tokenizer.KEYWORDS,
+ **tokens.Tokenizer.KEYWORDS,
"TOP": TokenType.TOP,
"VARCHAR2": TokenType.VARCHAR,
"NVARCHAR2": TokenType.NVARCHAR,
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 459e926..553a73b 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -1,4 +1,6 @@
-from sqlglot import exp
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
arrow_json_extract_scalar_sql,
@@ -9,9 +11,7 @@ from sqlglot.dialects.dialect import (
no_trycast_sql,
str_position_sql,
)
-from sqlglot.generator import Generator
-from sqlglot.parser import Parser
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.tokens import TokenType
from sqlglot.transforms import delegate, preprocess
@@ -160,12 +160,12 @@ class Postgres(Dialect):
"YYYY": "%Y", # 2015
}
- class Tokenizer(Tokenizer):
+ class Tokenizer(tokens.Tokenizer):
BIT_STRINGS = [("b'", "'"), ("B'", "'")]
HEX_STRINGS = [("x'", "'"), ("X'", "'")]
BYTE_STRINGS = [("e'", "'"), ("E'", "'")]
KEYWORDS = {
- **Tokenizer.KEYWORDS,
+ **tokens.Tokenizer.KEYWORDS,
"ALWAYS": TokenType.ALWAYS,
"BY DEFAULT": TokenType.BY_DEFAULT,
"COMMENT ON": TokenType.COMMENT_ON,
@@ -179,31 +179,32 @@ class Postgres(Dialect):
}
QUOTES = ["'", "$$"]
SINGLE_TOKENS = {
- **Tokenizer.SINGLE_TOKENS,
+ **tokens.Tokenizer.SINGLE_TOKENS,
"$": TokenType.PARAMETER,
}
- class Parser(Parser):
+ class Parser(parser.Parser):
STRICT_CAST = False
FUNCTIONS = {
- **Parser.FUNCTIONS,
+ **parser.Parser.FUNCTIONS,
"TO_TIMESTAMP": _to_timestamp,
"TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
}
- class Generator(Generator):
+ class Generator(generator.Generator):
TYPE_MAPPING = {
- **Generator.TYPE_MAPPING,
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.TINYINT: "SMALLINT",
exp.DataType.Type.FLOAT: "REAL",
exp.DataType.Type.DOUBLE: "DOUBLE PRECISION",
exp.DataType.Type.BINARY: "BYTEA",
+ exp.DataType.Type.VARBINARY: "BYTEA",
exp.DataType.Type.DATETIME: "TIMESTAMP",
}
TRANSFORMS = {
- **Generator.TRANSFORMS,
+ **generator.Generator.TRANSFORMS,
exp.ColumnDef: preprocess(
[
_auto_increment_to_serial,
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index a2d392c..11ea778 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -1,4 +1,6 @@
-from sqlglot import exp, transforms
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
Dialect,
format_time_lambda,
@@ -10,10 +12,8 @@ from sqlglot.dialects.dialect import (
struct_extract_sql,
)
from sqlglot.dialects.mysql import MySQL
-from sqlglot.generator import Generator
-from sqlglot.helper import list_get
-from sqlglot.parser import Parser
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.helper import seq_get
+from sqlglot.tokens import TokenType
def _approx_distinct_sql(self, expression):
@@ -110,30 +110,29 @@ class Presto(Dialect):
index_offset = 1
null_ordering = "nulls_are_last"
time_format = "'%Y-%m-%d %H:%i:%S'"
- time_mapping = MySQL.time_mapping
+ time_mapping = MySQL.time_mapping # type: ignore
- class Tokenizer(Tokenizer):
+ class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
- **Tokenizer.KEYWORDS,
- "VARBINARY": TokenType.BINARY,
+ **tokens.Tokenizer.KEYWORDS,
"ROW": TokenType.STRUCT,
}
- class Parser(Parser):
+ class Parser(parser.Parser):
FUNCTIONS = {
- **Parser.FUNCTIONS,
+ **parser.Parser.FUNCTIONS,
"APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list,
"CARDINALITY": exp.ArraySize.from_arg_list,
"CONTAINS": exp.ArrayContains.from_arg_list,
"DATE_ADD": lambda args: exp.DateAdd(
- this=list_get(args, 2),
- expression=list_get(args, 1),
- unit=list_get(args, 0),
+ this=seq_get(args, 2),
+ expression=seq_get(args, 1),
+ unit=seq_get(args, 0),
),
"DATE_DIFF": lambda args: exp.DateDiff(
- this=list_get(args, 2),
- expression=list_get(args, 1),
- unit=list_get(args, 0),
+ this=seq_get(args, 2),
+ expression=seq_get(args, 1),
+ unit=seq_get(args, 0),
),
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"),
"DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"),
@@ -143,7 +142,7 @@ class Presto(Dialect):
"APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
}
- class Generator(Generator):
+ class Generator(generator.Generator):
STRUCT_DELIMITER = ("(", ")")
@@ -159,7 +158,7 @@ class Presto(Dialect):
}
TYPE_MAPPING = {
- **Generator.TYPE_MAPPING,
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.FLOAT: "REAL",
exp.DataType.Type.BINARY: "VARBINARY",
@@ -169,8 +168,8 @@ class Presto(Dialect):
}
TRANSFORMS = {
- **Generator.TRANSFORMS,
- **transforms.UNALIAS_GROUP,
+ **generator.Generator.TRANSFORMS,
+ **transforms.UNALIAS_GROUP, # type: ignore
exp.ApproxDistinct: _approx_distinct_sql,
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
exp.ArrayConcat: rename_func("CONCAT"),
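
Presto passes the unit first: DATE_ADD(unit, amount, date). The lambdas above map those positions onto the dialect-neutral DateAdd and DateDiff arguments, so transpilation to dialects with a different argument order keeps working. For example (hedged output):

import sqlglot

expr = sqlglot.parse_one("DATE_ADD('day', 1, x)", read="presto")
print(expr.sql(dialect="presto"))  # round-trips with the unit back in front
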
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index e1f7b78..a9b12fb 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from sqlglot import exp
from sqlglot.dialects.postgres import Postgres
from sqlglot.tokens import TokenType
@@ -6,29 +8,30 @@ from sqlglot.tokens import TokenType
class Redshift(Postgres):
time_format = "'YYYY-MM-DD HH:MI:SS'"
time_mapping = {
- **Postgres.time_mapping,
+ **Postgres.time_mapping, # type: ignore
"MON": "%b",
"HH": "%H",
}
class Tokenizer(Postgres.Tokenizer):
- ESCAPE = "\\"
+ ESCAPES = ["\\"]
KEYWORDS = {
- **Postgres.Tokenizer.KEYWORDS,
+ **Postgres.Tokenizer.KEYWORDS, # type: ignore
"GEOMETRY": TokenType.GEOMETRY,
"GEOGRAPHY": TokenType.GEOGRAPHY,
"HLLSKETCH": TokenType.HLLSKETCH,
"SUPER": TokenType.SUPER,
"TIME": TokenType.TIMESTAMP,
"TIMETZ": TokenType.TIMESTAMPTZ,
- "VARBYTE": TokenType.BINARY,
+ "VARBYTE": TokenType.VARBINARY,
"SIMILAR TO": TokenType.SIMILAR_TO,
}
class Generator(Postgres.Generator):
TYPE_MAPPING = {
- **Postgres.Generator.TYPE_MAPPING,
+ **Postgres.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.BINARY: "VARBYTE",
+ exp.DataType.Type.VARBINARY: "VARBYTE",
exp.DataType.Type.INT: "INTEGER",
}
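
With VARBYTE now tokenized as the dedicated VARBINARY type rather than plain BINARY, and both BINARY and VARBINARY generated as VARBYTE, Redshift binary columns should round-trip cleanly (a hedged example; output formatting may vary):

import sqlglot

print(sqlglot.transpile("CREATE TABLE t (c VARBYTE)", read="redshift", write="redshift")[0])
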
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 3b97e6d..d1aaded 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -1,4 +1,6 @@
-from sqlglot import exp
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
format_time_lambda,
@@ -6,10 +8,8 @@ from sqlglot.dialects.dialect import (
rename_func,
)
from sqlglot.expressions import Literal
-from sqlglot.generator import Generator
-from sqlglot.helper import list_get
-from sqlglot.parser import Parser
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.helper import seq_get
+from sqlglot.tokens import TokenType
def _check_int(s):
@@ -28,7 +28,9 @@ def _snowflake_to_timestamp(args):
# case: <numeric_expr> [ , <scale> ]
if second_arg.name not in ["0", "3", "9"]:
- raise ValueError(f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9")
+ raise ValueError(
+ f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9"
+ )
if second_arg.name == "0":
timescale = exp.UnixToTime.SECONDS
@@ -39,7 +41,7 @@ def _snowflake_to_timestamp(args):
return exp.UnixToTime(this=first_arg, scale=timescale)
- first_arg = list_get(args, 0)
+ first_arg = seq_get(args, 0)
if not isinstance(first_arg, Literal):
# case: <variant_expr>
return format_time_lambda(exp.StrToTime, "snowflake", default=True)(args)
@@ -56,7 +58,7 @@ def _snowflake_to_timestamp(args):
return exp.UnixToTime.from_arg_list(args)
-def _unix_to_time(self, expression):
+def _unix_to_time_sql(self, expression):
scale = expression.args.get("scale")
timestamp = self.sql(expression, "this")
if scale in [None, exp.UnixToTime.SECONDS]:
@@ -132,9 +134,9 @@ class Snowflake(Dialect):
"ff6": "%f",
}
- class Parser(Parser):
+ class Parser(parser.Parser):
FUNCTIONS = {
- **Parser.FUNCTIONS,
+ **parser.Parser.FUNCTIONS,
"ARRAYAGG": exp.ArrayAgg.from_arg_list,
"IFF": exp.If.from_arg_list,
"TO_TIMESTAMP": _snowflake_to_timestamp,
@@ -143,18 +145,18 @@ class Snowflake(Dialect):
}
FUNCTION_PARSERS = {
- **Parser.FUNCTION_PARSERS,
+ **parser.Parser.FUNCTION_PARSERS,
"DATE_PART": _parse_date_part,
}
FUNC_TOKENS = {
- *Parser.FUNC_TOKENS,
+ *parser.Parser.FUNC_TOKENS,
TokenType.RLIKE,
TokenType.TABLE,
}
COLUMN_OPERATORS = {
- **Parser.COLUMN_OPERATORS,
+ **parser.Parser.COLUMN_OPERATORS, # type: ignore
TokenType.COLON: lambda self, this, path: self.expression(
exp.Bracket,
this=this,
@@ -163,21 +165,21 @@ class Snowflake(Dialect):
}
PROPERTY_PARSERS = {
- **Parser.PROPERTY_PARSERS,
+ **parser.Parser.PROPERTY_PARSERS,
TokenType.PARTITION_BY: lambda self: self._parse_partitioned_by(),
}
- class Tokenizer(Tokenizer):
+ class Tokenizer(tokens.Tokenizer):
QUOTES = ["'", "$$"]
- ESCAPE = "\\"
+ ESCAPES = ["\\"]
SINGLE_TOKENS = {
- **Tokenizer.SINGLE_TOKENS,
+ **tokens.Tokenizer.SINGLE_TOKENS,
"$": TokenType.PARAMETER,
}
KEYWORDS = {
- **Tokenizer.KEYWORDS,
+ **tokens.Tokenizer.KEYWORDS,
"QUALIFY": TokenType.QUALIFY,
"DOUBLE PRECISION": TokenType.DOUBLE,
"TIMESTAMP_LTZ": TokenType.TIMESTAMPLTZ,
@@ -187,15 +189,15 @@ class Snowflake(Dialect):
"SAMPLE": TokenType.TABLE_SAMPLE,
}
- class Generator(Generator):
+ class Generator(generator.Generator):
CREATE_TRANSIENT = True
TRANSFORMS = {
- **Generator.TRANSFORMS,
+ **generator.Generator.TRANSFORMS,
exp.ArrayConcat: rename_func("ARRAY_CAT"),
exp.If: rename_func("IFF"),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
- exp.UnixToTime: _unix_to_time,
+ exp.UnixToTime: _unix_to_time_sql,
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
exp.Array: inline_array_sql,
exp.StrPosition: rename_func("POSITION"),
@@ -204,7 +206,7 @@ class Snowflake(Dialect):
}
TYPE_MAPPING = {
- **Generator.TYPE_MAPPING,
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ",
}
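
The numeric form of Snowflake's TO_TIMESTAMP only accepts scales 0 (seconds), 3 (milliseconds) and 9 (nanoseconds); anything else raises during parsing, as the rewrapped ValueError above spells out:

import sqlglot

sqlglot.parse_one("TO_TIMESTAMP(1659981729000, 3)", read="snowflake")  # milliseconds: fine
try:
    sqlglot.parse_one("TO_TIMESTAMP(1659981729000, 5)", read="snowflake")
except ValueError as e:
    print(e)  # Scale for snowflake numeric timestamp is 5, but should be 0, 3, or 9
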
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index 572f411..4e404b8 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -1,8 +1,9 @@
-from sqlglot import exp
+from __future__ import annotations
+
+from sqlglot import exp, parser
from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func
from sqlglot.dialects.hive import Hive
-from sqlglot.helper import list_get
-from sqlglot.parser import Parser
+from sqlglot.helper import seq_get
def _create_sql(self, e):
@@ -46,36 +47,36 @@ def _unix_to_time(self, expression):
class Spark(Hive):
class Parser(Hive.Parser):
FUNCTIONS = {
- **Hive.Parser.FUNCTIONS,
+ **Hive.Parser.FUNCTIONS, # type: ignore
"MAP_FROM_ARRAYS": exp.Map.from_arg_list,
"TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
"LEFT": lambda args: exp.Substring(
- this=list_get(args, 0),
+ this=seq_get(args, 0),
start=exp.Literal.number(1),
- length=list_get(args, 1),
+ length=seq_get(args, 1),
),
"SHIFTLEFT": lambda args: exp.BitwiseLeftShift(
- this=list_get(args, 0),
- expression=list_get(args, 1),
+ this=seq_get(args, 0),
+ expression=seq_get(args, 1),
),
"SHIFTRIGHT": lambda args: exp.BitwiseRightShift(
- this=list_get(args, 0),
- expression=list_get(args, 1),
+ this=seq_get(args, 0),
+ expression=seq_get(args, 1),
),
"RIGHT": lambda args: exp.Substring(
- this=list_get(args, 0),
+ this=seq_get(args, 0),
start=exp.Sub(
- this=exp.Length(this=list_get(args, 0)),
- expression=exp.Add(this=list_get(args, 1), expression=exp.Literal.number(1)),
+ this=exp.Length(this=seq_get(args, 0)),
+ expression=exp.Add(this=seq_get(args, 1), expression=exp.Literal.number(1)),
),
- length=list_get(args, 1),
+ length=seq_get(args, 1),
),
"APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
"IIF": exp.If.from_arg_list,
}
FUNCTION_PARSERS = {
- **Parser.FUNCTION_PARSERS,
+ **parser.Parser.FUNCTION_PARSERS,
"BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
"BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
"MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
@@ -88,14 +89,14 @@ class Spark(Hive):
class Generator(Hive.Generator):
TYPE_MAPPING = {
- **Hive.Generator.TYPE_MAPPING,
+ **Hive.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TINYINT: "BYTE",
exp.DataType.Type.SMALLINT: "SHORT",
exp.DataType.Type.BIGINT: "LONG",
}
TRANSFORMS = {
- **{k: v for k, v in Hive.Generator.TRANSFORMS.items() if k not in {exp.ArraySort, exp.ILike}},
+ **Hive.Generator.TRANSFORMS, # type: ignore
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
exp.FileFormatProperty: lambda self, e: f"USING {e.text('value').upper()}",
exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
@@ -114,6 +115,8 @@ class Spark(Hive):
exp.VariancePop: rename_func("VAR_POP"),
exp.DateFromParts: rename_func("MAKE_DATE"),
}
+ TRANSFORMS.pop(exp.ArraySort)
+ TRANSFORMS.pop(exp.ILike)
WRAP_DERIVED_VALUES = False
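
Replacing the filtering dict comprehension with a full merge followed by TRANSFORMS.pop(...) reads more directly and keeps the exclusions greppable. The equivalence, sketched with plain dicts:

PARENT = {"a": 1, "b": 2, "c": 3}

# old style: filter while copying
old = {k: v for k, v in PARENT.items() if k not in {"b"}}

# new style: merge everything, then drop the overridden keys explicitly
new = {**PARENT}
new.pop("b")

assert old == new == {"a": 1, "c": 3}
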
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index 62b7617..8c9fb76 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -1,4 +1,6 @@
-from sqlglot import exp
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
arrow_json_extract_scalar_sql,
@@ -8,31 +10,28 @@ from sqlglot.dialects.dialect import (
no_trycast_sql,
rename_func,
)
-from sqlglot.generator import Generator
-from sqlglot.parser import Parser
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.tokens import TokenType
class SQLite(Dialect):
- class Tokenizer(Tokenizer):
+ class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = ['"', ("[", "]"), "`"]
HEX_STRINGS = [("x'", "'"), ("X'", "'"), ("0x", ""), ("0X", "")]
KEYWORDS = {
- **Tokenizer.KEYWORDS,
- "VARBINARY": TokenType.BINARY,
+ **tokens.Tokenizer.KEYWORDS,
"AUTOINCREMENT": TokenType.AUTO_INCREMENT,
}
- class Parser(Parser):
+ class Parser(parser.Parser):
FUNCTIONS = {
- **Parser.FUNCTIONS,
+ **parser.Parser.FUNCTIONS,
"EDITDIST3": exp.Levenshtein.from_arg_list,
}
- class Generator(Generator):
+ class Generator(generator.Generator):
TYPE_MAPPING = {
- **Generator.TYPE_MAPPING,
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.BOOLEAN: "INTEGER",
exp.DataType.Type.TINYINT: "INTEGER",
exp.DataType.Type.SMALLINT: "INTEGER",
@@ -46,6 +45,7 @@ class SQLite(Dialect):
exp.DataType.Type.VARCHAR: "TEXT",
exp.DataType.Type.NVARCHAR: "TEXT",
exp.DataType.Type.BINARY: "BLOB",
+ exp.DataType.Type.VARBINARY: "BLOB",
}
TOKEN_MAPPING = {
@@ -53,7 +53,7 @@ class SQLite(Dialect):
}
TRANSFORMS = {
- **Generator.TRANSFORMS,
+ **generator.Generator.TRANSFORMS,
exp.ILike: no_ilike_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py
index 0cba6fe..3519c09 100644
--- a/sqlglot/dialects/starrocks.py
+++ b/sqlglot/dialects/starrocks.py
@@ -1,10 +1,12 @@
+from __future__ import annotations
+
from sqlglot import exp
from sqlglot.dialects.dialect import arrow_json_extract_sql, rename_func
from sqlglot.dialects.mysql import MySQL
class StarRocks(MySQL):
- class Generator(MySQL.Generator):
+ class Generator(MySQL.Generator): # type: ignore
TYPE_MAPPING = {
**MySQL.Generator.TYPE_MAPPING,
exp.DataType.Type.TEXT: "STRING",
@@ -13,7 +15,7 @@ class StarRocks(MySQL):
}
TRANSFORMS = {
- **MySQL.Generator.TRANSFORMS,
+ **MySQL.Generator.TRANSFORMS, # type: ignore
exp.JSONExtractScalar: arrow_json_extract_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.DateDiff: rename_func("DATEDIFF"),
@@ -22,3 +24,4 @@ class StarRocks(MySQL):
exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({self.sql(e, 'this')}, {self.format_time(e)})",
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
}
+ TRANSFORMS.pop(exp.DateTrunc)
diff --git a/sqlglot/dialects/tableau.py b/sqlglot/dialects/tableau.py
index 45aa041..63e7275 100644
--- a/sqlglot/dialects/tableau.py
+++ b/sqlglot/dialects/tableau.py
@@ -1,7 +1,7 @@
-from sqlglot import exp
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser
from sqlglot.dialects.dialect import Dialect
-from sqlglot.generator import Generator
-from sqlglot.parser import Parser
def _if_sql(self, expression):
@@ -20,17 +20,17 @@ def _count_sql(self, expression):
class Tableau(Dialect):
- class Generator(Generator):
+ class Generator(generator.Generator):
TRANSFORMS = {
- **Generator.TRANSFORMS,
+ **generator.Generator.TRANSFORMS, # type: ignore
exp.If: _if_sql,
exp.Coalesce: _coalesce_sql,
exp.Count: _count_sql,
}
- class Parser(Parser):
+ class Parser(parser.Parser):
FUNCTIONS = {
- **Parser.FUNCTIONS,
+ **parser.Parser.FUNCTIONS,
"IFNULL": exp.Coalesce.from_arg_list,
"COUNTD": lambda args: exp.Count(this=exp.Distinct(expressions=args)),
}
diff --git a/sqlglot/dialects/trino.py b/sqlglot/dialects/trino.py
index 9a6f7fe..c7b34fe 100644
--- a/sqlglot/dialects/trino.py
+++ b/sqlglot/dialects/trino.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from sqlglot import exp
from sqlglot.dialects.presto import Presto
@@ -5,7 +7,7 @@ from sqlglot.dialects.presto import Presto
class Trino(Presto):
class Generator(Presto.Generator):
TRANSFORMS = {
- **Presto.Generator.TRANSFORMS,
+ **Presto.Generator.TRANSFORMS, # type: ignore
exp.ArraySum: lambda self, e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
}
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index 0f93c75..a233d4b 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -1,15 +1,22 @@
+from __future__ import annotations
+
import re
-from sqlglot import exp
+from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import Dialect, parse_date_delta, rename_func
from sqlglot.expressions import DataType
-from sqlglot.generator import Generator
-from sqlglot.helper import list_get
-from sqlglot.parser import Parser
+from sqlglot.helper import seq_get
from sqlglot.time import format_time
-from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.tokens import TokenType
-FULL_FORMAT_TIME_MAPPING = {"weekday": "%A", "dw": "%A", "w": "%A", "month": "%B", "mm": "%B", "m": "%B"}
+FULL_FORMAT_TIME_MAPPING = {
+ "weekday": "%A",
+ "dw": "%A",
+ "w": "%A",
+ "month": "%B",
+ "mm": "%B",
+ "m": "%B",
+}
DATE_DELTA_INTERVAL = {
"year": "year",
"yyyy": "year",
@@ -37,11 +44,13 @@ TRANSPILE_SAFE_NUMBER_FMT = {"N", "C"}
def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None):
def _format_time(args):
return exp_class(
- this=list_get(args, 1),
+ this=seq_get(args, 1),
format=exp.Literal.string(
format_time(
- list_get(args, 0).name or (TSQL.time_format if default is True else default),
- {**TSQL.time_mapping, **FULL_FORMAT_TIME_MAPPING} if full_format_mapping else TSQL.time_mapping,
+ seq_get(args, 0).name or (TSQL.time_format if default is True else default),
+ {**TSQL.time_mapping, **FULL_FORMAT_TIME_MAPPING}
+ if full_format_mapping
+ else TSQL.time_mapping,
)
),
)
@@ -50,12 +59,12 @@ def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None):
def parse_format(args):
- fmt = list_get(args, 1)
+ fmt = seq_get(args, 1)
number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.this)
if number_fmt:
- return exp.NumberToStr(this=list_get(args, 0), format=fmt)
+ return exp.NumberToStr(this=seq_get(args, 0), format=fmt)
return exp.TimeToStr(
- this=list_get(args, 0),
+ this=seq_get(args, 0),
format=exp.Literal.string(
format_time(fmt.name, TSQL.format_time_mapping)
if len(fmt.name) == 1
@@ -188,11 +197,11 @@ class TSQL(Dialect):
"Y": "%a %Y",
}
- class Tokenizer(Tokenizer):
+ class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = ['"', ("[", "]")]
KEYWORDS = {
- **Tokenizer.KEYWORDS,
+ **tokens.Tokenizer.KEYWORDS,
"BIT": TokenType.BOOLEAN,
"REAL": TokenType.FLOAT,
"NTEXT": TokenType.TEXT,
@@ -200,7 +209,6 @@ class TSQL(Dialect):
"DATETIME2": TokenType.DATETIME,
"DATETIMEOFFSET": TokenType.TIMESTAMPTZ,
"TIME": TokenType.TIMESTAMP,
- "VARBINARY": TokenType.BINARY,
"IMAGE": TokenType.IMAGE,
"MONEY": TokenType.MONEY,
"SMALLMONEY": TokenType.SMALLMONEY,
@@ -213,9 +221,9 @@ class TSQL(Dialect):
"TOP": TokenType.TOP,
}
- class Parser(Parser):
+ class Parser(parser.Parser):
FUNCTIONS = {
- **Parser.FUNCTIONS,
+ **parser.Parser.FUNCTIONS,
"CHARINDEX": exp.StrPosition.from_arg_list,
"ISNULL": exp.Coalesce.from_arg_list,
"DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
@@ -243,14 +251,16 @@ class TSQL(Dialect):
this = self._parse_column()
# Retrieve length of datatype and override to default if not specified
- if list_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES:
+ if seq_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES:
to = exp.DataType.build(to.this, expressions=[exp.Literal.number(30)], nested=False)
# Check whether a conversion with format is applicable
if self._match(TokenType.COMMA):
format_val = self._parse_number().name
if format_val not in TSQL.convert_format_mapping:
- raise ValueError(f"CONVERT function at T-SQL does not support format style {format_val}")
+ raise ValueError(
+ f"CONVERT function at T-SQL does not support format style {format_val}"
+ )
format_norm = exp.Literal.string(TSQL.convert_format_mapping[format_val])
# Check whether the convert entails a string to date format
@@ -272,9 +282,9 @@ class TSQL(Dialect):
# Entails a simple cast without any format requirement
return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
- class Generator(Generator):
+ class Generator(generator.Generator):
TYPE_MAPPING = {
- **Generator.TYPE_MAPPING,
+ **generator.Generator.TYPE_MAPPING,
exp.DataType.Type.BOOLEAN: "BIT",
exp.DataType.Type.INT: "INTEGER",
exp.DataType.Type.DECIMAL: "NUMERIC",
@@ -283,7 +293,7 @@ class TSQL(Dialect):
}
TRANSFORMS = {
- **Generator.TRANSFORMS,
+ **generator.Generator.TRANSFORMS, # type: ignore
exp.DateAdd: generate_date_delta_with_unit_sql,
exp.DateDiff: generate_date_delta_with_unit_sql,
exp.CurrentDate: rename_func("GETDATE"),
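
T-SQL's CONVERT takes an optional numeric style; styles present in TSQL.convert_format_mapping are rewritten into a portable format string, and unknown styles raise the rewrapped ValueError above. A hedged illustration, assuming style 112 (yyyymmdd) is in the mapping:

import sqlglot

print(sqlglot.parse_one("SELECT CONVERT(VARCHAR, GETDATE(), 112)", read="tsql").sql())
try:
    sqlglot.parse_one("SELECT CONVERT(VARCHAR, GETDATE(), 999)", read="tsql")
except ValueError as e:
    print(e)
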
diff --git a/sqlglot/diff.py b/sqlglot/diff.py
index 0567c12..2d959ab 100644
--- a/sqlglot/diff.py
+++ b/sqlglot/diff.py
@@ -4,7 +4,7 @@ from heapq import heappop, heappush
from sqlglot import Dialect
from sqlglot import expressions as exp
-from sqlglot.helper import ensure_list
+from sqlglot.helper import ensure_collection
@dataclass(frozen=True)
@@ -116,7 +116,9 @@ class ChangeDistiller:
source_node = self._source_index[kept_source_node_id]
target_node = self._target_index[kept_target_node_id]
if not isinstance(source_node, LEAF_EXPRESSION_TYPES) or source_node == target_node:
- edit_script.extend(self._generate_move_edits(source_node, target_node, matching_set))
+ edit_script.extend(
+ self._generate_move_edits(source_node, target_node, matching_set)
+ )
edit_script.append(Keep(source_node, target_node))
else:
edit_script.append(Update(source_node, target_node))
@@ -158,13 +160,16 @@ class ChangeDistiller:
max_leaves_num = max(len(source_leaf_ids), len(target_leaf_ids))
if max_leaves_num:
common_leaves_num = sum(
- 1 if s in source_leaf_ids and t in target_leaf_ids else 0 for s, t in leaves_matching_set
+ 1 if s in source_leaf_ids and t in target_leaf_ids else 0
+ for s, t in leaves_matching_set
)
leaf_similarity_score = common_leaves_num / max_leaves_num
else:
leaf_similarity_score = 0.0
- adjusted_t = self.t if min(len(source_leaf_ids), len(target_leaf_ids)) > 4 else 0.4
+ adjusted_t = (
+ self.t if min(len(source_leaf_ids), len(target_leaf_ids)) > 4 else 0.4
+ )
if leaf_similarity_score >= 0.8 or (
leaf_similarity_score >= adjusted_t
@@ -201,7 +206,10 @@ class ChangeDistiller:
matching_set = set()
while candidate_matchings:
_, _, source_leaf, target_leaf = heappop(candidate_matchings)
- if id(source_leaf) in self._unmatched_source_nodes and id(target_leaf) in self._unmatched_target_nodes:
+ if (
+ id(source_leaf) in self._unmatched_source_nodes
+ and id(target_leaf) in self._unmatched_target_nodes
+ ):
matching_set.add((id(source_leaf), id(target_leaf)))
self._unmatched_source_nodes.remove(id(source_leaf))
self._unmatched_target_nodes.remove(id(target_leaf))
@@ -241,8 +249,7 @@ def _get_leaves(expression):
has_child_exprs = False
for a in expression.args.values():
- nodes = ensure_list(a)
- for node in nodes:
+ for node in ensure_collection(a):
if isinstance(node, exp.Expression):
has_child_exprs = True
yield from _get_leaves(node)
@@ -268,7 +275,7 @@ def _expression_only_args(expression):
args = []
if expression:
for a in expression.args.values():
- args.extend(ensure_list(a))
+ args.extend(ensure_collection(a))
return [a for a in args if isinstance(a, exp.Expression)]
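
diff.py now relies on ensure_collection, which, unlike an ensure_list-style helper, leaves existing collections untouched instead of wrapping or copying them. A sketch of the distinction (the real helper lives in sqlglot.helper):

def ensure_collection(value):
    if value is None:
        return []
    return value if isinstance(value, (list, tuple, set)) else [value]

assert ensure_collection((1, 2)) == (1, 2)  # tuples pass through unchanged
assert ensure_collection(3) == [3]          # scalars get wrapped
assert ensure_collection(None) == []
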
diff --git a/sqlglot/errors.py b/sqlglot/errors.py
index 89aa935..2ef908f 100644
--- a/sqlglot/errors.py
+++ b/sqlglot/errors.py
@@ -1,3 +1,6 @@
+from __future__ import annotations
+
+import typing as t
from enum import auto
from sqlglot.helper import AutoName
@@ -30,7 +33,11 @@ class OptimizeError(SqlglotError):
pass
-def concat_errors(errors, maximum):
+class SchemaError(SqlglotError):
+ pass
+
+
+def concat_errors(errors: t.Sequence[t.Any], maximum: int) -> str:
msg = [str(e) for e in errors[:maximum]]
remaining = len(errors) - maximum
if remaining > 0:
diff --git a/sqlglot/executor/context.py b/sqlglot/executor/context.py
index d265a2c..393347b 100644
--- a/sqlglot/executor/context.py
+++ b/sqlglot/executor/context.py
@@ -19,6 +19,7 @@ class Context:
env (Optional[dict]): dictionary of functions within the execution context
"""
self.tables = tables
+ self._table = None
self.range_readers = {name: table.range_reader for name, table in self.tables.items()}
self.row_readers = {name: table.reader for name, table in tables.items()}
self.env = {**(env or {}), "scope": self.row_readers}
@@ -29,8 +30,27 @@ class Context:
def eval_tuple(self, codes):
return tuple(self.eval(code) for code in codes)
+ @property
+ def table(self):
+ if self._table is None:
+ self._table = list(self.tables.values())[0]
+ for other in self.tables.values():
+ if self._table.columns != other.columns:
+ raise Exception(f"Columns are different.")
+ if len(self._table.rows) != len(other.rows):
+ raise Exception(f"Rows are different.")
+ return self._table
+
+ @property
+ def columns(self):
+ return self.table.columns
+
def __iter__(self):
- return self.table_iter(list(self.tables)[0])
+ self.env["scope"] = self.row_readers
+ for i in range(len(self.table.rows)):
+ for table in self.tables.values():
+ reader = table[i]
+ yield reader, self
def table_iter(self, table):
self.env["scope"] = self.row_readers
@@ -38,8 +58,8 @@ class Context:
for reader in self.tables[table]:
yield reader, self
- def sort(self, table, key):
- table = self.tables[table]
+ def sort(self, key):
+ table = self.table
def sort_key(row):
table.reader.row = row
@@ -47,20 +67,20 @@ class Context:
table.rows.sort(key=sort_key)
- def set_row(self, table, row):
- self.row_readers[table].row = row
+ def set_row(self, row):
+ for table in self.tables.values():
+ table.reader.row = row
self.env["scope"] = self.row_readers
- def set_index(self, table, index):
- self.row_readers[table].row = self.tables[table].rows[index]
+ def set_index(self, index):
+ for table in self.tables.values():
+ table[index]
self.env["scope"] = self.row_readers
- def set_range(self, table, start, end):
- self.range_readers[table].range = range(start, end)
+ def set_range(self, start, end):
+ for name in self.tables:
+ self.range_readers[name].range = range(start, end)
self.env["scope"] = self.range_readers
- def __getitem__(self, table):
- return self.env["scope"][table]
-
def __contains__(self, table):
return table in self.tables
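
Context now treats its tables as views over one shared set of physical rows: the new table property picks any of them and asserts they all agree on columns and row count, and set_row/set_index/set_range advance every reader in lockstep, which is why those methods lost their table argument. The invariant, in miniature (plain tuples stand in for executor tables):

tables = {
    "a": (("x", "y"), 3),  # (columns, row_count)
    "b": (("x", "y"), 3),
}

columns, n_rows = next(iter(tables.values()))
assert all(t == (columns, n_rows) for t in tables.values())  # what Context.table checks
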
diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py
index 9c49dd1..bbe6c81 100644
--- a/sqlglot/executor/env.py
+++ b/sqlglot/executor/env.py
@@ -2,6 +2,8 @@ import datetime
import re
import statistics
+from sqlglot.helper import PYTHON_VERSION
+
class reverse_key:
def __init__(self, obj):
@@ -25,7 +27,7 @@ ENV = {
"str": str,
"desc": reverse_key,
"SUM": sum,
- "AVG": statistics.fmean if hasattr(statistics, "fmean") else statistics.mean,
+ "AVG": statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean, # type: ignore
"COUNT": lambda acc: sum(1 for e in acc if e is not None),
"MAX": max,
"MIN": min,
diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py
index fcb016b..7d1db32 100644
--- a/sqlglot/executor/python.py
+++ b/sqlglot/executor/python.py
@@ -1,15 +1,14 @@
import ast
import collections
import itertools
+import math
-from sqlglot import exp, planner
+from sqlglot import exp, generator, planner, tokens
from sqlglot.dialects.dialect import Dialect, inline_array_sql
from sqlglot.executor.context import Context
from sqlglot.executor.env import ENV
from sqlglot.executor.table import Table
-from sqlglot.generator import Generator
from sqlglot.helper import csv_reader
-from sqlglot.tokens import Tokenizer
class PythonExecutor:
@@ -26,7 +25,11 @@ class PythonExecutor:
while queue:
node = queue.pop()
context = self.context(
- {name: table for dep in node.dependencies for name, table in contexts[dep].tables.items()}
+ {
+ name: table
+ for dep in node.dependencies
+ for name, table in contexts[dep].tables.items()
+ }
)
running.add(node)
@@ -76,13 +79,10 @@ class PythonExecutor:
return Table(expression.alias_or_name for expression in expressions)
def scan(self, step, context):
- if hasattr(step, "source"):
- source = step.source
+ source = step.source
- if isinstance(source, exp.Expression):
- source = source.name or source.alias
- else:
- source = step.name
+ if isinstance(source, exp.Expression):
+ source = source.name or source.alias
condition = self.generate(step.condition)
projections = self.generate_tuple(step.projections)
@@ -96,14 +96,12 @@ class PythonExecutor:
if projections:
sink = self.table(step.projections)
- elif source in context:
- sink = Table(context[source].columns)
else:
sink = None
for reader, ctx in table_iter:
if sink is None:
- sink = Table(ctx[source].columns)
+ sink = Table(reader.columns)
if condition and not ctx.eval(condition):
continue
@@ -135,98 +133,79 @@ class PythonExecutor:
types.append(type(ast.literal_eval(v)))
except (ValueError, SyntaxError):
types.append(str)
- context.set_row(alias, tuple(t(v) for t, v in zip(types, row)))
- yield context[alias], context
+ context.set_row(tuple(t(v) for t, v in zip(types, row)))
+ yield context.table.reader, context
def join(self, step, context):
source = step.name
- join_context = self.context({source: context.tables[source]})
-
- def merge_context(ctx, table):
- # create a new context where all existing tables are mapped to a new one
- return self.context({name: table for name in ctx.tables})
+ source_table = context.tables[source]
+ source_context = self.context({source: source_table})
+ column_ranges = {source: range(0, len(source_table.columns))}
for name, join in step.joins.items():
- join_context = self.context({**join_context.tables, name: context.tables[name]})
+ table = context.tables[name]
+ start = max(r.stop for r in column_ranges.values())
+ column_ranges[name] = range(start, len(table.columns) + start)
+ join_context = self.context({name: table})
if join.get("source_key"):
- table = self.hash_join(join, source, name, join_context)
+ table = self.hash_join(join, source_context, join_context)
else:
- table = self.nested_loop_join(join, source, name, join_context)
+ table = self.nested_loop_join(join, source_context, join_context)
- join_context = merge_context(join_context, table)
-
- # apply projections or conditions
- context = self.scan(step, join_context)
+ source_context = self.context(
+ {
+ name: Table(table.columns, table.rows, column_range)
+ for name, column_range in column_ranges.items()
+ }
+ )
- # use the scan context since it returns a single table
- # otherwise there are no projections so all other tables are still in scope
- if step.projections:
- return context
+ condition = self.generate(step.condition)
+ projections = self.generate_tuple(step.projections)
- return merge_context(join_context, context.tables[source])
+ if not condition or not projections:
+ return source_context
- def nested_loop_join(self, _join, a, b, context):
- table = Table(context.tables[a].columns + context.tables[b].columns)
+ sink = self.table(step.projections if projections else source_context.columns)
- for reader_a, _ in context.table_iter(a):
- for reader_b, _ in context.table_iter(b):
- table.append(reader_a.row + reader_b.row)
+ for reader, ctx in join_context:
+ if condition and not ctx.eval(condition):
+ continue
- return table
+ if projections:
+ sink.append(ctx.eval_tuple(projections))
+ else:
+ sink.append(reader.row)
- def hash_join(self, join, a, b, context):
- a_key = self.generate_tuple(join["source_key"])
- b_key = self.generate_tuple(join["join_key"])
+ if len(sink) >= step.limit:
+ break
- results = collections.defaultdict(lambda: ([], []))
+ return self.context({step.name: sink})
- for reader, ctx in context.table_iter(a):
- results[ctx.eval_tuple(a_key)][0].append(reader.row)
- for reader, ctx in context.table_iter(b):
- results[ctx.eval_tuple(b_key)][1].append(reader.row)
+ def nested_loop_join(self, _join, source_context, join_context):
+ table = Table(source_context.columns + join_context.columns)
- table = Table(context.tables[a].columns + context.tables[b].columns)
- for a_group, b_group in results.values():
- for a_row, b_row in itertools.product(a_group, b_group):
- table.append(a_row + b_row)
+ for reader_a, _ in source_context:
+ for reader_b, _ in join_context:
+ table.append(reader_a.row + reader_b.row)
return table
- def sort_merge_join(self, join, a, b, context):
- a_key = self.generate_tuple(join["source_key"])
- b_key = self.generate_tuple(join["join_key"])
-
- context.sort(a, a_key)
- context.sort(b, b_key)
-
- a_i = 0
- b_i = 0
- a_n = len(context.tables[a])
- b_n = len(context.tables[b])
-
- table = Table(context.tables[a].columns + context.tables[b].columns)
-
- def get_key(source, key, i):
- context.set_index(source, i)
- return context.eval_tuple(key)
+ def hash_join(self, join, source_context, join_context):
+ source_key = self.generate_tuple(join["source_key"])
+ join_key = self.generate_tuple(join["join_key"])
- while a_i < a_n and b_i < b_n:
- key = min(get_key(a, a_key, a_i), get_key(b, b_key, b_i))
-
- a_group = []
-
- while a_i < a_n and key == get_key(a, a_key, a_i):
- a_group.append(context[a].row)
- a_i += 1
+ results = collections.defaultdict(lambda: ([], []))
- b_group = []
+ for reader, ctx in source_context:
+ results[ctx.eval_tuple(source_key)][0].append(reader.row)
+ for reader, ctx in join_context:
+ results[ctx.eval_tuple(join_key)][1].append(reader.row)
- while b_i < b_n and key == get_key(b, b_key, b_i):
- b_group.append(context[b].row)
- b_i += 1
+ table = Table(source_context.columns + join_context.columns)
+ for a_group, b_group in results.values():
for a_row, b_row in itertools.product(a_group, b_group):
table.append(a_row + b_row)
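
The rewritten hash join is a two-phase bucket-and-product scheme: rows from each side are bucketed by their join-key tuple, then matching buckets are combined with a cartesian product. A minimal standalone sketch of the same idea, using plain tuples and hypothetical sample data instead of the executor's Table/context machinery:

```python
import collections
import itertools

def hash_join(left_rows, right_rows, left_key, right_key):
    # Phase 1: bucket rows from both sides by their join-key tuple.
    buckets = collections.defaultdict(lambda: ([], []))
    for row in left_rows:
        buckets[left_key(row)][0].append(row)
    for row in right_rows:
        buckets[right_key(row)][1].append(row)

    # Phase 2: emit the cartesian product of each pair of matching buckets.
    return [
        a + b
        for a_group, b_group in buckets.values()
        for a, b in itertools.product(a_group, b_group)
    ]

# (id, name) joined with (id, score) on id.
users = [(1, "ada"), (2, "bob")]
scores = [(1, 95), (1, 88), (3, 70)]
print(hash_join(users, scores, lambda r: (r[0],), lambda r: (r[0],)))
# [(1, 'ada', 1, 95), (1, 'ada', 1, 88)]
```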
@@ -238,16 +217,18 @@ class PythonExecutor:
aggregations = self.generate_tuple(step.aggregations)
operands = self.generate_tuple(step.operands)
- context.sort(source, group_by)
-
- if step.operands:
+ if operands:
source_table = context.tables[source]
operand_table = Table(source_table.columns + self.table(step.operands).columns)
for reader, ctx in context:
operand_table.append(reader.row + ctx.eval_tuple(operands))
- context = self.context({source: operand_table})
+ context = self.context(
+ {None: operand_table, **{table: operand_table for table in context.tables}}
+ )
+
+ context.sort(group_by)
group = None
start = 0
@@ -256,15 +237,15 @@ class PythonExecutor:
table = self.table(step.group + step.aggregations)
for i in range(length):
- context.set_index(source, i)
+ context.set_index(i)
key = context.eval_tuple(group_by)
group = key if group is None else group
end += 1
if i == length - 1:
- context.set_range(source, start, end - 1)
+ context.set_range(start, end - 1)
elif key != group:
- context.set_range(source, start, end - 2)
+ context.set_range(start, end - 2)
else:
continue
@@ -272,13 +253,32 @@ class PythonExecutor:
group = key
start = end - 2
- return self.scan(step, self.context({source: table}))
+ context = self.context({step.name: table, **{name: table for name in context.tables}})
+
+ if step.projections:
+ return self.scan(step, context)
+ return context
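
The reworked aggregate step keeps the same sort-then-scan pattern: rows are sorted by the group key so each group occupies a contiguous range, and a single pass closes a group whenever the key changes. A standalone sketch of that pattern, with hypothetical rows and sum as the only aggregate:

```python
def sorted_group_sum(rows, key_index, value_index):
    # Sort by the grouping key so each group occupies a contiguous range.
    rows = sorted(rows, key=lambda r: r[key_index])

    results = []
    group, total = None, 0
    for row in rows:
        key = row[key_index]
        if group is not None and key != group:
            results.append((group, total))  # key changed: close the group
            total = 0
        group = key
        total += row[value_index]
    if group is not None:
        results.append((group, total))  # flush the final group
    return results

print(sorted_group_sum([("b", 2), ("a", 1), ("b", 3)], 0, 1))
# [('a', 1), ('b', 5)]
```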
def sort(self, step, context):
- table = list(context.tables)[0]
- key = self.generate_tuple(step.key)
- context.sort(table, key)
- return self.scan(step, context)
+ projections = self.generate_tuple(step.projections)
+
+ sink = self.table(step.projections)
+
+ for reader, ctx in context:
+ sink.append(ctx.eval_tuple(projections))
+
+ context = self.context(
+ {
+ None: sink,
+ **{table: sink for table in context.tables},
+ }
+ )
+ context.sort(self.generate_tuple(step.key))
+
+ if not math.isinf(step.limit):
+ context.table.rows = context.table.rows[0 : step.limit]
+
+ return self.context({step.name: context.table})
def _cast_py(self, expression):
@@ -293,7 +293,7 @@ def _cast_py(self, expression):
def _column_py(self, expression):
- table = self.sql(expression, "table")
+ table = self.sql(expression, "table") or None
this = self.sql(expression, "this")
return f"scope[{table}][{this}]"
@@ -319,10 +319,10 @@ def _ordered_py(self, expression):
class Python(Dialect):
- class Tokenizer(Tokenizer):
- ESCAPE = "\\"
+ class Tokenizer(tokens.Tokenizer):
+ ESCAPES = ["\\"]
- class Generator(Generator):
+ class Generator(generator.Generator):
TRANSFORMS = {
exp.Alias: lambda self, e: self.sql(e.this),
exp.Array: inline_array_sql,
diff --git a/sqlglot/executor/table.py b/sqlglot/executor/table.py
index 80674cb..6796740 100644
--- a/sqlglot/executor/table.py
+++ b/sqlglot/executor/table.py
@@ -1,10 +1,12 @@
class Table:
- def __init__(self, *columns, rows=None):
- self.columns = tuple(columns if isinstance(columns[0], str) else columns[0])
+ def __init__(self, columns, rows=None, column_range=None):
+ self.columns = tuple(columns)
+ self.column_range = column_range
+ self.reader = RowReader(self.columns, self.column_range)
+
self.rows = rows or []
if rows:
assert len(rows[0]) == len(self.columns)
- self.reader = RowReader(self.columns)
self.range_reader = RangeReader(self)
def append(self, row):
@@ -29,15 +31,22 @@ class Table:
return self.reader
def __repr__(self):
- widths = {column: len(column) for column in self.columns}
- lines = [" ".join(column for column in self.columns)]
+ columns = tuple(
+ column
+ for i, column in enumerate(self.columns)
+ if not self.column_range or i in self.column_range
+ )
+ widths = {column: len(column) for column in columns}
+ lines = [" ".join(column for column in columns)]
for i, row in enumerate(self):
if i > 10:
break
lines.append(
- " ".join(str(row[column]).rjust(widths[column])[0 : widths[column]] for column in self.columns)
+ " ".join(
+ str(row[column]).rjust(widths[column])[0 : widths[column]] for column in columns
+ )
)
return "\n".join(lines)
@@ -70,8 +79,10 @@ class RangeReader:
class RowReader:
- def __init__(self, columns):
- self.columns = {column: i for i, column in enumerate(columns)}
+ def __init__(self, columns, column_range=None):
+ self.columns = {
+ column: i for i, column in enumerate(columns) if not column_range or i in column_range
+ }
self.row = None
def __getitem__(self, column):
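
The point of the new column_range parameter is that several logical tables can share one physical row while each reader only resolves its own slice of columns. A standalone sketch of the same indexing trick, with hypothetical column names:

```python
class SlicedReader:
    """Resolves column names to row positions, optionally restricted to a range."""

    def __init__(self, columns, column_range=None):
        self.columns = {
            column: i
            for i, column in enumerate(columns)
            if not column_range or i in column_range
        }
        self.row = None

    def __getitem__(self, column):
        return self.row[self.columns[column]]

# One joined row, two readers over disjoint column ranges.
columns = ["id", "name", "id", "score"]
left = SlicedReader(columns, range(0, 2))
right = SlicedReader(columns, range(2, 4))
left.row = right.row = (1, "ada", 1, 95)
print(left["name"], right["score"])  # ada 95
```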
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 1691d85..57a2c88 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -1,6 +1,9 @@
+from __future__ import annotations
+
import datetime
import numbers
import re
+import typing as t
from collections import deque
from copy import deepcopy
from enum import auto
@@ -9,12 +12,15 @@ from sqlglot.errors import ParseError
from sqlglot.helper import (
AutoName,
camel_to_snake_case,
- ensure_list,
- list_get,
+ ensure_collection,
+ seq_get,
split_num_words,
subclasses,
)
+if t.TYPE_CHECKING:
+ from sqlglot.dialects.dialect import Dialect
+
class _Expression(type):
def __new__(cls, clsname, bases, attrs):
@@ -35,27 +41,30 @@ class Expression(metaclass=_Expression):
or optional (False).
"""
- key = None
+ key = "Expression"
arg_types = {"this": True}
- __slots__ = ("args", "parent", "arg_key", "type")
+ __slots__ = ("args", "parent", "arg_key", "type", "comment")
def __init__(self, **args):
self.args = args
self.parent = None
self.arg_key = None
self.type = None
+ self.comment = None
for arg_key, value in self.args.items():
self._set_parent(arg_key, value)
- def __eq__(self, other):
+ def __eq__(self, other) -> bool:
return type(self) is type(other) and _norm_args(self) == _norm_args(other)
- def __hash__(self):
+ def __hash__(self) -> int:
return hash(
(
self.key,
- tuple((k, tuple(v) if isinstance(v, list) else v) for k, v in _norm_args(self).items()),
+ tuple(
+ (k, tuple(v) if isinstance(v, list) else v) for k, v in _norm_args(self).items()
+ ),
)
)
@@ -79,6 +88,19 @@ class Expression(metaclass=_Expression):
return field.this
return ""
+ def find_comment(self, key: str) -> str:
+ """
+ Finds the comment that is attached to a specified child node.
+
+ Args:
+ key: the key of the target child node (e.g. "this", "expression", etc).
+
+ Returns:
+ The comment attached to the child node, or the empty string, if it doesn't exist.
+ """
+ field = self.args.get(key)
+ return field.comment if isinstance(field, Expression) else ""
+
@property
def is_string(self):
return isinstance(self, Literal) and self.args["is_string"]
@@ -114,7 +136,10 @@ class Expression(metaclass=_Expression):
return self.alias or self.name
def __deepcopy__(self, memo):
- return self.__class__(**deepcopy(self.args))
+ copy = self.__class__(**deepcopy(self.args))
+ copy.comment = self.comment
+ copy.type = self.type
+ return copy
def copy(self):
new = deepcopy(self)
@@ -249,9 +274,7 @@ class Expression(metaclass=_Expression):
return
for k, v in self.args.items():
- nodes = ensure_list(v)
-
- for node in nodes:
+ for node in ensure_collection(v):
if isinstance(node, Expression):
yield from node.dfs(self, k, prune)
@@ -274,9 +297,7 @@ class Expression(metaclass=_Expression):
if isinstance(item, Expression):
for k, v in item.args.items():
- nodes = ensure_list(v)
-
- for node in nodes:
+ for node in ensure_collection(v):
if isinstance(node, Expression):
queue.append((node, item, k))
@@ -319,7 +340,7 @@ class Expression(metaclass=_Expression):
def __repr__(self):
return self.to_s()
- def sql(self, dialect=None, **opts):
+ def sql(self, dialect: Dialect | str | None = None, **opts) -> str:
"""
Returns SQL string representation of this tree.
@@ -335,7 +356,7 @@ class Expression(metaclass=_Expression):
return Dialect.get_or_raise(dialect)().generate(self, **opts)
- def to_s(self, hide_missing=True, level=0):
+ def to_s(self, hide_missing: bool = True, level: int = 0) -> str:
indent = "" if not level else "\n"
indent += "".join([" "] * level)
left = f"({self.key.upper()} "
@@ -343,11 +364,13 @@ class Expression(metaclass=_Expression):
args = {
k: ", ".join(
v.to_s(hide_missing=hide_missing, level=level + 1) if hasattr(v, "to_s") else str(v)
- for v in ensure_list(vs)
+ for v in ensure_collection(vs)
if v is not None
)
for k, vs in self.args.items()
}
+ args["comment"] = self.comment
+ args["type"] = self.type
args = {k: v for k, v in args.items() if v or not hide_missing}
right = ", ".join(f"{k}: {v}" for k, v in args.items())
@@ -578,17 +601,6 @@ class UDTF(DerivedTable, Unionable):
pass
-class Annotation(Expression):
- arg_types = {
- "this": True,
- "expression": True,
- }
-
- @property
- def alias(self):
- return self.expression.alias_or_name
-
-
class Cache(Expression):
arg_types = {
"with": False,
@@ -623,6 +635,38 @@ class Describe(Expression):
pass
+class Set(Expression):
+ arg_types = {"expressions": True}
+
+
+class SetItem(Expression):
+ arg_types = {
+ "this": True,
+ "kind": False,
+ "collate": False, # MySQL SET NAMES statement
+ }
+
+
+class Show(Expression):
+ arg_types = {
+ "this": True,
+ "target": False,
+ "offset": False,
+ "limit": False,
+ "like": False,
+ "where": False,
+ "db": False,
+ "full": False,
+ "mutex": False,
+ "query": False,
+ "channel": False,
+ "global": False,
+ "log": False,
+ "position": False,
+ "types": False,
+ }
+
+
class UserDefinedFunction(Expression):
arg_types = {"this": True, "expressions": False}
@@ -864,18 +908,20 @@ class Literal(Condition):
def __eq__(self, other):
return (
- isinstance(other, Literal) and self.this == other.this and self.args["is_string"] == other.args["is_string"]
+ isinstance(other, Literal)
+ and self.this == other.this
+ and self.args["is_string"] == other.args["is_string"]
)
def __hash__(self):
return hash((self.key, self.this, self.args["is_string"]))
@classmethod
- def number(cls, number):
+ def number(cls, number) -> Literal:
return cls(this=str(number), is_string=False)
@classmethod
- def string(cls, string):
+ def string(cls, string) -> Literal:
return cls(this=str(string), is_string=True)
@@ -1087,7 +1133,7 @@ class Properties(Expression):
}
@classmethod
- def from_dict(cls, properties_dict):
+ def from_dict(cls, properties_dict) -> Properties:
expressions = []
for key, value in properties_dict.items():
property_cls = cls.PROPERTY_KEY_MAPPING.get(key.upper(), AnonymousProperty)
@@ -1323,7 +1369,7 @@ class Select(Subqueryable):
**QUERY_MODIFIERS,
}
- def from_(self, *expressions, append=True, dialect=None, copy=True, **opts):
+ def from_(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
"""
Set the FROM expression.
@@ -1356,7 +1402,7 @@ class Select(Subqueryable):
**opts,
)
- def group_by(self, *expressions, append=True, dialect=None, copy=True, **opts):
+ def group_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
"""
Set the GROUP BY expression.
@@ -1392,7 +1438,7 @@ class Select(Subqueryable):
**opts,
)
- def order_by(self, *expressions, append=True, dialect=None, copy=True, **opts):
+ def order_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
"""
Set the ORDER BY expression.
@@ -1425,7 +1471,7 @@ class Select(Subqueryable):
**opts,
)
- def sort_by(self, *expressions, append=True, dialect=None, copy=True, **opts):
+ def sort_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
"""
Set the SORT BY expression.
@@ -1458,7 +1504,7 @@ class Select(Subqueryable):
**opts,
)
- def cluster_by(self, *expressions, append=True, dialect=None, copy=True, **opts):
+ def cluster_by(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
"""
Set the CLUSTER BY expression.
@@ -1491,7 +1537,7 @@ class Select(Subqueryable):
**opts,
)
- def limit(self, expression, dialect=None, copy=True, **opts):
+ def limit(self, expression, dialect=None, copy=True, **opts) -> Select:
"""
Set the LIMIT expression.
@@ -1522,7 +1568,7 @@ class Select(Subqueryable):
**opts,
)
- def offset(self, expression, dialect=None, copy=True, **opts):
+ def offset(self, expression, dialect=None, copy=True, **opts) -> Select:
"""
Set the OFFSET expression.
@@ -1553,7 +1599,7 @@ class Select(Subqueryable):
**opts,
)
- def select(self, *expressions, append=True, dialect=None, copy=True, **opts):
+ def select(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
"""
Append to or set the SELECT expressions.
@@ -1583,7 +1629,7 @@ class Select(Subqueryable):
**opts,
)
- def lateral(self, *expressions, append=True, dialect=None, copy=True, **opts):
+ def lateral(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
"""
Append to or set the LATERAL expressions.
@@ -1626,7 +1672,7 @@ class Select(Subqueryable):
dialect=None,
copy=True,
**opts,
- ):
+ ) -> Select:
"""
Append to or set the JOIN expressions.
@@ -1672,7 +1718,7 @@ class Select(Subqueryable):
join.this.replace(join.this.subquery())
if join_type:
- natural, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args)
+ natural, side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args) # type: ignore
if natural:
join.set("natural", True)
if side:
@@ -1681,12 +1727,12 @@ class Select(Subqueryable):
join.set("kind", kind.text)
if on:
- on = and_(*ensure_list(on), dialect=dialect, **opts)
+ on = and_(*ensure_collection(on), dialect=dialect, **opts)
join.set("on", on)
if using:
join = _apply_list_builder(
- *ensure_list(using),
+ *ensure_collection(using),
instance=join,
arg="using",
append=append,
@@ -1705,7 +1751,7 @@ class Select(Subqueryable):
**opts,
)
- def where(self, *expressions, append=True, dialect=None, copy=True, **opts):
+ def where(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
"""
Append to or set the WHERE expressions.
@@ -1737,7 +1783,7 @@ class Select(Subqueryable):
**opts,
)
- def having(self, *expressions, append=True, dialect=None, copy=True, **opts):
+ def having(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select:
"""
Append to or set the HAVING expressions.
@@ -1769,7 +1815,7 @@ class Select(Subqueryable):
**opts,
)
- def distinct(self, distinct=True, copy=True):
+ def distinct(self, distinct=True, copy=True) -> Select:
"""
Set the OFFSET expression.
@@ -1788,7 +1834,7 @@ class Select(Subqueryable):
instance.set("distinct", Distinct() if distinct else None)
return instance
- def ctas(self, table, properties=None, dialect=None, copy=True, **opts):
+ def ctas(self, table, properties=None, dialect=None, copy=True, **opts) -> Create:
"""
Convert this expression to a CREATE TABLE AS statement.
@@ -1826,11 +1872,11 @@ class Select(Subqueryable):
)
@property
- def named_selects(self):
+ def named_selects(self) -> t.List[str]:
return [e.alias_or_name for e in self.expressions if e.alias_or_name]
@property
- def selects(self):
+ def selects(self) -> t.List[Expression]:
return self.expressions
@@ -1910,12 +1956,16 @@ class Parameter(Expression):
pass
+class SessionParameter(Expression):
+ arg_types = {"this": True, "kind": False}
+
+
class Placeholder(Expression):
arg_types = {"this": False}
class Null(Condition):
- arg_types = {}
+ arg_types: t.Dict[str, t.Any] = {}
class Boolean(Condition):
@@ -1936,6 +1986,7 @@ class DataType(Expression):
NVARCHAR = auto()
TEXT = auto()
BINARY = auto()
+ VARBINARY = auto()
INT = auto()
TINYINT = auto()
SMALLINT = auto()
@@ -1975,7 +2026,7 @@ class DataType(Expression):
UNKNOWN = auto() # Sentinel value, useful for type annotation
@classmethod
- def build(cls, dtype, **kwargs):
+ def build(cls, dtype, **kwargs) -> DataType:
return DataType(
this=dtype if isinstance(dtype, DataType.Type) else DataType.Type[dtype.upper()],
**kwargs,
@@ -2077,6 +2128,18 @@ class EQ(Binary, Predicate):
pass
+class NullSafeEQ(Binary, Predicate):
+ pass
+
+
+class NullSafeNEQ(Binary, Predicate):
+ pass
+
+
+class Distance(Binary):
+ pass
+
+
class Escape(Binary):
pass
@@ -2101,15 +2164,11 @@ class Is(Binary, Predicate):
pass
-class Like(Binary, Predicate):
- pass
-
-
-class SimilarTo(Binary, Predicate):
- pass
+class Kwarg(Binary):
+ """Kwarg in special functions like func(kwarg => y)."""
-class Distance(Binary):
+class Like(Binary, Predicate):
pass
@@ -2133,6 +2192,10 @@ class NEQ(Binary, Predicate):
pass
+class SimilarTo(Binary, Predicate):
+ pass
+
+
class Sub(Binary):
pass
@@ -2189,7 +2252,13 @@ class Distinct(Expression):
class In(Predicate):
- arg_types = {"this": True, "expressions": False, "query": False, "unnest": False, "field": False}
+ arg_types = {
+ "this": True,
+ "expressions": False,
+ "query": False,
+ "unnest": False,
+ "field": False,
+ }
class TimeUnit(Expression):
@@ -2255,7 +2324,9 @@ class Func(Condition):
@classmethod
def sql_names(cls):
if cls is Func:
- raise NotImplementedError("SQL name is only supported by concrete function implementations")
+ raise NotImplementedError(
+ "SQL name is only supported by concrete function implementations"
+ )
if not hasattr(cls, "_sql_names"):
cls._sql_names = [camel_to_snake_case(cls.__name__)]
return cls._sql_names
@@ -2408,8 +2479,8 @@ class DateDiff(Func, TimeUnit):
arg_types = {"this": True, "expression": True, "unit": False}
-class DateTrunc(Func, TimeUnit):
- arg_types = {"this": True, "unit": True, "zone": False}
+class DateTrunc(Func):
+ arg_types = {"this": True, "expression": True, "zone": False}
class DatetimeAdd(Func, TimeUnit):
@@ -2791,6 +2862,10 @@ class Year(Func):
pass
+class Use(Expression):
+ pass
+
+
def _norm_args(expression):
args = {}
@@ -2822,7 +2897,7 @@ def maybe_parse(
dialect=None,
prefix=None,
**opts,
-):
+) -> t.Optional[Expression]:
"""Gracefully handle a possible string or expression.
Example:
@@ -3073,7 +3148,7 @@ def except_(left, right, distinct=True, dialect=None, **opts):
return Except(this=left, expression=right, distinct=distinct)
-def select(*expressions, dialect=None, **opts):
+def select(*expressions, dialect=None, **opts) -> Select:
"""
Initializes a syntax tree from one or multiple SELECT expressions.
@@ -3095,7 +3170,7 @@ def select(*expressions, dialect=None, **opts):
return Select().select(*expressions, dialect=dialect, **opts)
-def from_(*expressions, dialect=None, **opts):
+def from_(*expressions, dialect=None, **opts) -> Select:
"""
Initializes a syntax tree from a FROM expression.
@@ -3117,7 +3192,7 @@ def from_(*expressions, dialect=None, **opts):
return Select().from_(*expressions, dialect=dialect, **opts)
-def update(table, properties, where=None, from_=None, dialect=None, **opts):
+def update(table, properties, where=None, from_=None, dialect=None, **opts) -> Update:
"""
Creates an update statement.
@@ -3139,7 +3214,10 @@ def update(table, properties, where=None, from_=None, dialect=None, **opts):
update = Update(this=maybe_parse(table, into=Table, dialect=dialect))
update.set(
"expressions",
- [EQ(this=maybe_parse(k, dialect=dialect, **opts), expression=convert(v)) for k, v in properties.items()],
+ [
+ EQ(this=maybe_parse(k, dialect=dialect, **opts), expression=convert(v))
+ for k, v in properties.items()
+ ],
)
if from_:
update.set("from", maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts))
@@ -3150,7 +3228,7 @@ def update(table, properties, where=None, from_=None, dialect=None, **opts):
return update
-def delete(table, where=None, dialect=None, **opts):
+def delete(table, where=None, dialect=None, **opts) -> Delete:
"""
Builds a delete statement.
@@ -3174,7 +3252,7 @@ def delete(table, where=None, dialect=None, **opts):
)
-def condition(expression, dialect=None, **opts):
+def condition(expression, dialect=None, **opts) -> Condition:
"""
Initialize a logical condition expression.
@@ -3199,7 +3277,7 @@ def condition(expression, dialect=None, **opts):
Returns:
Condition: the expression
"""
- return maybe_parse(
+ return maybe_parse( # type: ignore
expression,
into=Condition,
dialect=dialect,
@@ -3207,7 +3285,7 @@ def condition(expression, dialect=None, **opts):
)
-def and_(*expressions, dialect=None, **opts):
+def and_(*expressions, dialect=None, **opts) -> And:
"""
Combine multiple conditions with an AND logical operator.
@@ -3227,7 +3305,7 @@ def and_(*expressions, dialect=None, **opts):
return _combine(expressions, And, dialect, **opts)
-def or_(*expressions, dialect=None, **opts):
+def or_(*expressions, dialect=None, **opts) -> Or:
"""
Combine multiple conditions with an OR logical operator.
@@ -3247,7 +3325,7 @@ def or_(*expressions, dialect=None, **opts):
return _combine(expressions, Or, dialect, **opts)
-def not_(expression, dialect=None, **opts):
+def not_(expression, dialect=None, **opts) -> Not:
"""
Wrap a condition with a NOT operator.
@@ -3272,14 +3350,14 @@ def not_(expression, dialect=None, **opts):
return Not(this=_wrap_operator(this))
-def paren(expression):
+def paren(expression) -> Paren:
return Paren(this=expression)
SAFE_IDENTIFIER_RE = re.compile(r"^[a-zA-Z][\w]*$")
-def to_identifier(alias, quoted=None):
+def to_identifier(alias, quoted=None) -> t.Optional[Identifier]:
if alias is None:
return None
if isinstance(alias, Identifier):
@@ -3293,16 +3371,16 @@ def to_identifier(alias, quoted=None):
return identifier
-def to_table(sql_path: str, **kwargs) -> Table:
+def to_table(sql_path: t.Optional[str | Table], **kwargs) -> t.Optional[Table]:
"""
Create a table expression from a `[catalog].[schema].[table]` sql path. Catalog and schema are optional.
-
If a table is passed in then that table is returned.
Args:
- sql_path(str|Table): `[catalog].[schema].[table]` string
+ sql_path: a `[catalog].[schema].[table]` string.
+
Returns:
- Table: A table expression
+ A table expression.
"""
if sql_path is None or isinstance(sql_path, Table):
return sql_path
@@ -3393,7 +3471,7 @@ def subquery(expression, alias=None, dialect=None, **opts):
return Select().from_(expression, dialect=dialect, **opts)
-def column(col, table=None, quoted=None):
+def column(col, table=None, quoted=None) -> Column:
"""
Build a Column.
Args:
@@ -3408,7 +3486,7 @@ def column(col, table=None, quoted=None):
)
-def table_(table, db=None, catalog=None, quoted=None, alias=None):
+def table_(table, db=None, catalog=None, quoted=None, alias=None) -> Table:
"""Build a Table.
Args:
@@ -3427,7 +3505,7 @@ def table_(table, db=None, catalog=None, quoted=None, alias=None):
)
-def values(values, alias=None):
+def values(values, alias=None) -> Values:
"""Build VALUES statement.
Example:
@@ -3449,7 +3527,7 @@ def values(values, alias=None):
)
-def convert(value):
+def convert(value) -> Expression:
"""Convert a python value into an expression object.
Raises an error if a conversion is not possible.
@@ -3500,15 +3578,14 @@ def replace_children(expression, fun):
for cn in child_nodes:
if isinstance(cn, Expression):
- cns = ensure_list(fun(cn))
- for child_node in cns:
+ for child_node in ensure_collection(fun(cn)):
new_child_nodes.append(child_node)
child_node.parent = expression
child_node.arg_key = k
else:
new_child_nodes.append(cn)
- expression.args[k] = new_child_nodes if is_list_arg else list_get(new_child_nodes, 0)
+ expression.args[k] = new_child_nodes if is_list_arg else seq_get(new_child_nodes, 0)
def column_table_names(expression):
@@ -3529,7 +3606,7 @@ def column_table_names(expression):
return list(dict.fromkeys(column.table for column in expression.find_all(Column)))
-def table_name(table):
+def table_name(table) -> str:
"""Get the full name of a table as a string.
Args:
@@ -3546,6 +3623,9 @@ def table_name(table):
table = maybe_parse(table, into=Table)
+ if not table:
+ raise ValueError(f"Cannot parse {table}")
+
return ".".join(
part
for part in (
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index ca14425..11d9073 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -1,4 +1,8 @@
+from __future__ import annotations
+
import logging
+import re
+import typing as t
from sqlglot import exp
from sqlglot.errors import ErrorLevel, UnsupportedError, concat_errors
@@ -8,6 +12,8 @@ from sqlglot.tokens import TokenType
logger = logging.getLogger("sqlglot")
+NEWLINE_RE = re.compile("\r\n?|\n")
+
class Generator:
"""
@@ -47,8 +53,7 @@ class Generator:
The default is on the smaller end because the length only represents a segment and not the true
line length.
Default: 80
- annotations: Whether or not to show annotations in the SQL when `pretty` is True.
- Annotations can only be shown in pretty mode otherwise they may clobber resulting sql.
+ comments: Whether or not to preserve comments in the output SQL code.
Default: True
"""
@@ -65,14 +70,16 @@ class Generator:
exp.VolatilityProperty: lambda self, e: self.sql(e.name),
}
- # whether 'CREATE ... TRANSIENT ... TABLE' is allowed
- # can override in dialects
+ # Whether 'CREATE ... TRANSIENT ... TABLE' is allowed
CREATE_TRANSIENT = False
- # whether or not null ordering is supported in order by
+
+ # Whether or not null ordering is supported in order by
NULL_ORDERING_SUPPORTED = True
- # always do union distinct or union all
+
+ # Always do union distinct or union all
EXPLICIT_UNION = False
- # wrap derived values in parens, usually standard but spark doesn't support it
+
+ # Wrap derived values in parens, usually standard but spark doesn't support it
WRAP_DERIVED_VALUES = True
TYPE_MAPPING = {
@@ -80,7 +87,7 @@ class Generator:
exp.DataType.Type.NVARCHAR: "VARCHAR",
}
- TOKEN_MAPPING = {}
+ TOKEN_MAPPING: t.Dict[TokenType, str] = {}
STRUCT_DELIMITER = ("<", ">")
@@ -96,6 +103,8 @@ class Generator:
exp.TableFormatProperty,
}
+ WITH_SEPARATED_COMMENTS = (exp.Select,)
+
__slots__ = (
"time_mapping",
"time_trie",
@@ -122,7 +131,7 @@ class Generator:
"_escaped_quote_end",
"_leading_comma",
"_max_text_width",
- "_annotations",
+ "_comments",
)
def __init__(
@@ -148,7 +157,7 @@ class Generator:
max_unsupported=3,
leading_comma=False,
max_text_width=80,
- annotations=True,
+ comments=True,
):
import sqlglot
@@ -177,7 +186,7 @@ class Generator:
self._escaped_quote_end = self.escape + self.quote_end
self._leading_comma = leading_comma
self._max_text_width = max_text_width
- self._annotations = annotations
+ self._comments = comments
def generate(self, expression):
"""
@@ -204,7 +213,6 @@ class Generator:
return sql
def unsupported(self, message):
-
if self.unsupported_level == ErrorLevel.IMMEDIATE:
raise UnsupportedError(message)
self.unsupported_messages.append(message)
@@ -215,9 +223,31 @@ class Generator:
def seg(self, sql, sep=" "):
return f"{self.sep(sep)}{sql}"
+ def maybe_comment(self, sql, expression, single_line=False):
+ comment = expression.comment if self._comments else None
+
+ if not comment:
+ return sql
+
+ comment = " " + comment if comment[0].strip() else comment
+ comment = comment + " " if comment[-1].strip() else comment
+
+ if isinstance(expression, self.WITH_SEPARATED_COMMENTS):
+ return f"/*{comment}*/{self.sep()}{sql}"
+
+ if not self.pretty:
+ return f"{sql} /*{comment}*/"
+
+ if not NEWLINE_RE.search(comment):
+ return f"{sql} --{comment.rstrip()}" if single_line else f"{sql} /*{comment}*/"
+
+ return f"/*{comment}*/\n{sql}"
+
def wrap(self, expression):
this_sql = self.indent(
- self.sql(expression) if isinstance(expression, (exp.Select, exp.Union)) else self.sql(expression, "this"),
+ self.sql(expression)
+ if isinstance(expression, (exp.Select, exp.Union))
+ else self.sql(expression, "this"),
level=1,
pad=0,
)
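
The placement rules in maybe_comment come down to three cases: comments on WITH_SEPARATED_COMMENTS nodes go on their own line before the SQL, single-line comments render inline (as /* ... */, or as -- when single_line is requested in pretty mode), and multi-line comments are emitted above the SQL. A standalone mirror of that decision, with the generator's separator simplified to a newline (place_comment is a hypothetical name):

```python
import re

NEWLINE_RE = re.compile("\r\n?|\n")

def place_comment(sql, comment, separated=False, pretty=True, single_line=False):
    if not comment:
        return sql

    # Pad the comment so "/*x*/" renders as "/* x */".
    comment = " " + comment if comment[0].strip() else comment
    comment = comment + " " if comment[-1].strip() else comment

    if separated:  # e.g. a comment attached to a SELECT
        return f"/*{comment}*/\n{sql}"
    if not pretty:
        return f"{sql} /*{comment}*/"
    if not NEWLINE_RE.search(comment):
        return f"{sql} --{comment.rstrip()}" if single_line else f"{sql} /*{comment}*/"
    return f"/*{comment}*/\n{sql}"

print(place_comment("SELECT 1", "top", separated=True))      # "/* top */" then "SELECT 1"
print(place_comment("a", "alias comment", single_line=True))  # a -- alias comment
```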
@@ -251,7 +281,7 @@ class Generator:
for i, line in enumerate(lines)
)
- def sql(self, expression, key=None):
+ def sql(self, expression, key=None, comment=True):
if not expression:
return ""
@@ -264,29 +294,24 @@ class Generator:
transform = self.TRANSFORMS.get(expression.__class__)
if callable(transform):
- return transform(self, expression)
- if transform:
- return transform
-
- if not isinstance(expression, exp.Expression):
+ sql = transform(self, expression)
+ elif transform:
+ sql = transform
+ elif isinstance(expression, exp.Expression):
+ exp_handler_name = f"{expression.key}_sql"
+
+ if hasattr(self, exp_handler_name):
+ sql = getattr(self, exp_handler_name)(expression)
+ elif isinstance(expression, exp.Func):
+ sql = self.function_fallback_sql(expression)
+ elif isinstance(expression, exp.Property):
+ sql = self.property_sql(expression)
+ else:
+ raise ValueError(f"Unsupported expression type {expression.__class__.__name__}")
+ else:
raise ValueError(f"Expected an Expression. Received {type(expression)}: {expression}")
- exp_handler_name = f"{expression.key}_sql"
- if hasattr(self, exp_handler_name):
- return getattr(self, exp_handler_name)(expression)
-
- if isinstance(expression, exp.Func):
- return self.function_fallback_sql(expression)
-
- if isinstance(expression, exp.Property):
- return self.property_sql(expression)
-
- raise ValueError(f"Unsupported expression type {expression.__class__.__name__}")
-
- def annotation_sql(self, expression):
- if self._annotations and self.pretty:
- return f"{self.sql(expression, 'expression')} # {expression.name}"
- return self.sql(expression, "expression")
+ return self.maybe_comment(sql, expression) if self._comments and comment else sql
def uncache_sql(self, expression):
table = self.sql(expression, "this")
@@ -371,7 +396,9 @@ class Generator:
expression_sql = self.sql(expression, "expression")
expression_sql = f"AS{self.sep()}{expression_sql}" if expression_sql else ""
temporary = " TEMPORARY" if expression.args.get("temporary") else ""
- transient = " TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else ""
+ transient = (
+ " TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else ""
+ )
replace = " OR REPLACE" if expression.args.get("replace") else ""
exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else ""
unique = " UNIQUE" if expression.args.get("unique") else ""
@@ -434,7 +461,9 @@ class Generator:
def delete_sql(self, expression):
this = self.sql(expression, "this")
using_sql = (
- f" USING {self.expressions(expression, 'using', sep=', USING ')}" if expression.args.get("using") else ""
+ f" USING {self.expressions(expression, 'using', sep=', USING ')}"
+ if expression.args.get("using")
+ else ""
)
where_sql = self.sql(expression, "where")
sql = f"DELETE FROM {this}{using_sql}{where_sql}"
@@ -481,15 +510,18 @@ class Generator:
return f"{this} ON {table} {columns}"
def identifier_sql(self, expression):
- value = expression.name
- value = value.lower() if self.normalize else value
+ text = expression.name
+ text = text.lower() if self.normalize else text
if expression.args.get("quoted") or self.identify:
- return f"{self.identifier_start}{value}{self.identifier_end}"
- return value
+ text = f"{self.identifier_start}{text}{self.identifier_end}"
+ return text
def partition_sql(self, expression):
keys = csv(
- *[f"{k.args['this']}='{v.args['this']}'" if v else k.args["this"] for k, v in expression.args.get("this")]
+ *[
+ f"""{prop.name}='{prop.text("value")}'""" if prop.text("value") else prop.name
+ for prop in expression.this
+ ]
)
return f"PARTITION({keys})"
@@ -504,9 +536,9 @@ class Generator:
elif p_class in self.ROOT_PROPERTIES:
root_properties.append(p)
- return self.root_properties(exp.Properties(expressions=root_properties)) + self.with_properties(
- exp.Properties(expressions=with_properties)
- )
+ return self.root_properties(
+ exp.Properties(expressions=root_properties)
+ ) + self.with_properties(exp.Properties(expressions=with_properties))
def root_properties(self, properties):
if properties.expressions:
@@ -551,7 +583,9 @@ class Generator:
this = f"{this}{self.sql(expression, 'this')}"
exists = " IF EXISTS " if expression.args.get("exists") else " "
- partition_sql = self.sql(expression, "partition") if expression.args.get("partition") else ""
+ partition_sql = (
+ self.sql(expression, "partition") if expression.args.get("partition") else ""
+ )
expression_sql = self.sql(expression, "expression")
sep = self.sep() if partition_sql else ""
sql = f"INSERT {this}{exists}{partition_sql}{sep}{expression_sql}"
@@ -669,7 +703,9 @@ class Generator:
def group_sql(self, expression):
group_by = self.op_expressions("GROUP BY", expression)
grouping_sets = self.expressions(expression, key="grouping_sets", indent=False)
- grouping_sets = f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" if grouping_sets else ""
+ grouping_sets = (
+ f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" if grouping_sets else ""
+ )
cube = self.expressions(expression, key="cube", indent=False)
cube = f"{self.seg('CUBE')} {self.wrap(cube)}" if cube else ""
rollup = self.expressions(expression, key="rollup", indent=False)
@@ -711,10 +747,10 @@ class Generator:
this_sql = self.sql(expression, "this")
return f"{expression_sql}{op_sql} {this_sql}{on_sql}"
- def lambda_sql(self, expression):
+ def lambda_sql(self, expression, arrow_sep="->"):
args = self.expressions(expression, flat=True)
args = f"({args})" if len(args.split(",")) > 1 else args
- return self.no_identify(lambda: f"{args} -> {self.sql(expression, 'this')}")
+ return self.no_identify(lambda: f"{args} {arrow_sep} {self.sql(expression, 'this')}")
def lateral_sql(self, expression):
this = self.sql(expression, "this")
@@ -748,7 +784,7 @@ class Generator:
if self._replace_backslash:
text = text.replace("\\", "\\\\")
text = text.replace(self.quote_end, self._escaped_quote_end)
- return f"{self.quote_start}{text}{self.quote_end}"
+ text = f"{self.quote_start}{text}{self.quote_end}"
return text
def loaddata_sql(self, expression):
@@ -796,13 +832,21 @@ class Generator:
sort_order = " DESC" if desc else ""
nulls_sort_change = ""
- if nulls_first and ((asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last):
+ if nulls_first and (
+ (asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last
+ ):
nulls_sort_change = " NULLS FIRST"
- elif nulls_last and ((asc and nulls_are_small) or (desc and nulls_are_large)) and not nulls_are_last:
+ elif (
+ nulls_last
+ and ((asc and nulls_are_small) or (desc and nulls_are_large))
+ and not nulls_are_last
+ ):
nulls_sort_change = " NULLS LAST"
if nulls_sort_change and not self.NULL_ORDERING_SUPPORTED:
- self.unsupported("Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect")
+ self.unsupported(
+ "Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect"
+ )
nulls_sort_change = ""
return f"{self.sql(expression, 'this')}{sort_order}{nulls_sort_change}"
@@ -835,7 +879,7 @@ class Generator:
sql = self.query_modifiers(
expression,
f"SELECT{hint}{distinct}{expressions}",
- self.sql(expression, "from"),
+ self.sql(expression, "from", comment=False),
)
return self.prepend_ctes(expression, sql)
@@ -858,6 +902,13 @@ class Generator:
def parameter_sql(self, expression):
return f"@{self.sql(expression, 'this')}"
+ def sessionparameter_sql(self, expression):
+ this = self.sql(expression, "this")
+ kind = expression.text("kind")
+ if kind:
+ kind = f"{kind}."
+ return f"@@{kind}{this}"
+
def placeholder_sql(self, expression):
return f":{expression.name}" if expression.name else "?"
@@ -931,7 +982,10 @@ class Generator:
def window_spec_sql(self, expression):
kind = self.sql(expression, "kind")
start = csv(self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" ")
- end = csv(self.sql(expression, "end"), self.sql(expression, "end_side"), sep=" ") or "CURRENT ROW"
+ end = (
+ csv(self.sql(expression, "end"), self.sql(expression, "end_side"), sep=" ")
+ or "CURRENT ROW"
+ )
return f"{kind} BETWEEN {start} AND {end}"
def withingroup_sql(self, expression):
@@ -1020,7 +1074,9 @@ class Generator:
return f"UNIQUE ({columns})"
def if_sql(self, expression):
- return self.case_sql(exp.Case(ifs=[expression.copy()], default=expression.args.get("false")))
+ return self.case_sql(
+ exp.Case(ifs=[expression.copy()], default=expression.args.get("false"))
+ )
def in_sql(self, expression):
query = expression.args.get("query")
@@ -1196,6 +1252,12 @@ class Generator:
def neq_sql(self, expression):
return self.binary(expression, "<>")
+ def nullsafeeq_sql(self, expression):
+ return self.binary(expression, "IS NOT DISTINCT FROM")
+
+ def nullsafeneq_sql(self, expression):
+ return self.binary(expression, "IS DISTINCT FROM")
+
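
The two null-safe comparison renderings can be exercised by building the new nodes directly, e.g. with the column builder defined in expressions.py:

```python
from sqlglot import exp

eq = exp.NullSafeEQ(this=exp.column("a"), expression=exp.column("b"))
print(eq.sql())  # a IS NOT DISTINCT FROM b
```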
def or_sql(self, expression):
return self.connector_sql(expression, "OR")
@@ -1205,6 +1267,9 @@ class Generator:
def trycast_sql(self, expression):
return f"TRY_CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})"
+ def use_sql(self, expression):
+ return f"USE {self.sql(expression, 'this')}"
+
def binary(self, expression, op):
return f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}"
@@ -1240,17 +1305,27 @@ class Generator:
if flat:
return sep.join(self.sql(e) for e in expressions)
- sql = (self.sql(e) for e in expressions)
- # the only time leading_comma changes the output is if pretty print is enabled
- if self._leading_comma and self.pretty:
- pad = " " * self.pad
- expressions = "\n".join(f"{sep}{s}" if i > 0 else f"{pad}{s}" for i, s in enumerate(sql))
- else:
- expressions = self.sep(sep).join(sql)
+ num_sqls = len(expressions)
+
+ # These are calculated once in case the leading_comma / pretty options are set, respectively
+ pad = " " * self.pad
+ stripped_sep = sep.strip()
- if indent:
- return self.indent(expressions, skip_first=False)
- return expressions
+ result_sqls = []
+ for i, e in enumerate(expressions):
+ sql = self.sql(e, comment=False)
+ comment = self.maybe_comment("", e, single_line=True)
+
+ if self.pretty:
+ if self._leading_comma:
+ result_sqls.append(f"{sep if i > 0 else pad}{sql}{comment}")
+ else:
+ result_sqls.append(f"{sql}{stripped_sep if i + 1 < num_sqls else ''}{comment}")
+ else:
+ result_sqls.append(f"{sql}{comment}{sep if i + 1 < num_sqls else ''}")
+
+ result_sqls = "\n".join(result_sqls) if self.pretty else "".join(result_sqls)
+ return self.indent(result_sqls, skip_first=False) if indent else result_sqls
def op_expressions(self, op, expression, flat=False):
expressions_sql = self.expressions(expression, flat=flat)
@@ -1264,7 +1339,9 @@ class Generator:
def set_operation(self, expression, op):
this = self.sql(expression, "this")
op = self.seg(op)
- return self.query_modifiers(expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}")
+ return self.query_modifiers(
+ expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}"
+ )
def token_sql(self, token_type):
return self.TOKEN_MAPPING.get(token_type, token_type.name)
@@ -1283,3 +1360,6 @@ class Generator:
this = self.sql(expression, "this")
expressions = self.expressions(expression, flat=True)
return f"{this}({expressions})"
+
+ def kwarg_sql(self, expression):
+ return self.binary(expression, "=>")
diff --git a/sqlglot/helper.py b/sqlglot/helper.py
index 42965d1..379c2e7 100644
--- a/sqlglot/helper.py
+++ b/sqlglot/helper.py
@@ -1,48 +1,125 @@
+from __future__ import annotations
+
import inspect
import logging
import re
import sys
import typing as t
+from collections.abc import Collection
from contextlib import contextmanager
from copy import copy
from enum import Enum
+if t.TYPE_CHECKING:
+ from sqlglot.expressions import Expression, Table
+
+ T = t.TypeVar("T")
+ E = t.TypeVar("E", bound=Expression)
+
CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")
+PYTHON_VERSION = sys.version_info[:2]
logger = logging.getLogger("sqlglot")
class AutoName(Enum):
- def _generate_next_value_(name, _start, _count, _last_values):
+ """This is used for creating enum classes where `auto()` is the string form of the corresponding value's name."""
+
+ def _generate_next_value_(name, _start, _count, _last_values): # type: ignore
return name
-def list_get(arr, index):
+def seq_get(seq: t.Sequence[T], index: int) -> t.Optional[T]:
+ """Returns the value in `seq` at position `index`, or `None` if `index` is out of bounds."""
try:
- return arr[index]
+ return seq[index]
except IndexError:
return None
+@t.overload
+def ensure_list(value: t.Collection[T]) -> t.List[T]:
+ ...
+
+
+@t.overload
+def ensure_list(value: T) -> t.List[T]:
+ ...
+
+
def ensure_list(value):
+ """
+ Ensures that a value is a list, otherwise casts or wraps it into one.
+
+ Args:
+ value: the value of interest.
+
+ Returns:
+ The value cast as a list if it's a list or a tuple, or else the value wrapped in a list.
+ """
if value is None:
return []
- return value if isinstance(value, (list, tuple, set)) else [value]
+ elif isinstance(value, (list, tuple)):
+ return list(value)
+
+ return [value]
+
+
+@t.overload
+def ensure_collection(value: t.Collection[T]) -> t.Collection[T]:
+ ...
-def csv(*args, sep=", "):
+@t.overload
+def ensure_collection(value: T) -> t.Collection[T]:
+ ...
+
+
+def ensure_collection(value):
+ """
+ Ensures that a value is a collection (excluding `str` and `bytes`), otherwise wraps it into a list.
+
+ Args:
+ value: the value of interest.
+
+ Returns:
+ The value if it's a collection, or else the value wrapped in a list.
+ """
+ if value is None:
+ return []
+ return (
+ value if isinstance(value, Collection) and not isinstance(value, (str, bytes)) else [value]
+ )
+
+
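
The practical difference between the two helpers: ensure_list casts lists and tuples into a fresh list and wraps everything else, while ensure_collection passes any non-string collection through untouched. A quick illustration of the behavior defined above:

```python
from sqlglot.helper import ensure_collection, ensure_list

assert ensure_list((1, 2)) == [1, 2]        # tuples are cast to lists
assert ensure_list("ab") == ["ab"]          # strings are wrapped, not iterated
assert ensure_collection({1, 2}) == {1, 2}  # non-string collections pass through
assert ensure_collection("ab") == ["ab"]    # str and bytes don't count as collections
assert ensure_collection(None) == []
```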
+def csv(*args, sep: str = ", ") -> str:
+ """
+ Formats any number of string arguments as CSV.
+
+ Args:
+ args: the string arguments to format.
+ sep: the argument separator.
+
+ Returns:
+ The arguments formatted as a CSV string.
+ """
return sep.join(arg for arg in args if arg)
-def subclasses(module_name, classes, exclude=()):
+def subclasses(
+ module_name: str,
+ classes: t.Type | t.Tuple[t.Type, ...],
+ exclude: t.Type | t.Tuple[t.Type, ...] = (),
+) -> t.List[t.Type]:
"""
- Returns a list of all subclasses for a specified class set, posibly excluding some of them.
+ Returns all subclasses for a collection of classes, possibly excluding some of them.
Args:
- module_name (str): The name of the module to search for subclasses in.
- classes (type|tuple[type]): Class(es) we want to find the subclasses of.
- exclude (type|tuple[type]): Class(es) we want to exclude from the returned list.
+ module_name: the name of the module to search for subclasses in.
+ classes: class(es) we want to find the subclasses of.
+ exclude: class(es) we want to exclude from the returned list.
+
Returns:
- A list of all the target subclasses.
+ The target subclasses.
"""
return [
obj
@@ -53,7 +130,18 @@ def subclasses(module_name, classes, exclude=()):
]
-def apply_index_offset(expressions, offset):
+def apply_index_offset(expressions: t.List[E], offset: int) -> t.List[E]:
+ """
+ Applies an offset to a given integer literal expression.
+
+ Args:
+ expressions: the expression the offset will be applied to, wrapped in a list.
+ offset: the offset that will be applied.
+
+ Returns:
+ The original expression with the offset applied to it, wrapped in a list. If the provided
+ `expressions` argument contains more than one expression, it's returned unaffected.
+ """
if not offset or len(expressions) != 1:
return expressions
@@ -64,14 +152,28 @@ def apply_index_offset(expressions, offset):
logger.warning("Applying array index offset (%s)", offset)
expression.args["this"] = str(int(expression.args["this"]) + offset)
return [expression]
+
return expressions
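
A usage note for the offset helper, assuming the elided guard (not visible in this hunk) only rewrites single integer literals:

```python
from sqlglot import exp
from sqlglot.helper import apply_index_offset

# A lone integer literal is shifted, e.g. to convert 1-based array indexing to 0-based.
(shifted,) = apply_index_offset([exp.Literal.number(1)], -1)
assert shifted.args["this"] == "0"

# Multiple expressions (or a zero offset) are returned unchanged.
many = [exp.Literal.number(1), exp.Literal.number(2)]
assert apply_index_offset(many, -1) is many
```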
-def camel_to_snake_case(name):
+def camel_to_snake_case(name: str) -> str:
+ """Converts `name` from camelCase to snake_case and returns the result."""
return CAMEL_CASE_PATTERN.sub("_", name).upper()
-def while_changing(expression, func):
+def while_changing(
+ expression: t.Optional[Expression], func: t.Callable[[t.Optional[Expression]], E]
+) -> E:
+ """
+ Applies a transformation to a given expression until a fixed point is reached.
+
+ Args:
+ expression: the expression to be transformed.
+ func: the transformation to be applied.
+
+ Returns:
+ The transformed expression.
+ """
while True:
start = hash(expression)
expression = func(expression)
@@ -80,10 +182,19 @@ def while_changing(expression, func):
return expression
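
Because the loop stops once hash(expression) is stable, while_changing works at runtime with any hashable value whose rewrite eventually reaches a fixed point. A toy demonstration under that contract (plain ints and a hypothetical halving rule instead of expressions):

```python
from sqlglot.helper import while_changing

# Halving reaches a fixed point at 0; hash(int) == int for small ints.
print(while_changing(20, lambda n: n // 2))  # 0
```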
-def tsort(dag):
+def tsort(dag: t.Dict[T, t.List[T]]) -> t.List[T]:
+ """
+ Sorts a given directed acyclic graph in topological order.
+
+ Args:
+ dag: the graph to be sorted.
+
+ Returns:
+ A list that contains all of the graph's nodes in topological order.
+ """
result = []
- def visit(node, visited):
+ def visit(node: T, visited: t.Set[T]) -> None:
if node in result:
return
if node in visited:
@@ -103,10 +214,8 @@ def tsort(dag):
return result
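
A usage sketch for the topological sort, assuming (as the executor's dependency handling above does) that each node's value lists its dependencies, which therefore sort first:

```python
from sqlglot.helper import tsort

print(tsort({"c": ["b"], "b": ["a"], "a": []}))  # ['a', 'b', 'c']
```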
-def open_file(file_name):
- """
- Open a file that may be compressed as gzip and return in newline mode.
- """
+def open_file(file_name: str) -> t.TextIO:
+ """Open a file that may be compressed as gzip and return it in universal newline mode."""
with open(file_name, "rb") as f:
gzipped = f.read(2) == b"\x1f\x8b"
@@ -119,14 +228,14 @@ def open_file(file_name):
@contextmanager
-def csv_reader(table):
+def csv_reader(table: Table) -> t.Any:
"""
- Returns a csv reader given the expression READ_CSV(name, ['delimiter', '|', ...])
+ Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.
Args:
- table (exp.Table): A table expression with an anonymous function READ_CSV in it
+ table: a `Table` expression with an anonymous function `READ_CSV` in it.
- Returns:
+ Yields:
A python csv reader.
"""
file, *args = table.this.expressions
@@ -147,13 +256,16 @@ def csv_reader(table):
file.close()
-def find_new_name(taken, base):
+def find_new_name(taken: t.Sequence[str], base: str) -> str:
"""
Searches for a new name.
Args:
- taken (Sequence[str]): set of taken names
- base (str): base name to alter
+ taken: a collection of taken names.
+ base: base name to alter.
+
+ Returns:
+ The new, available name.
"""
if base not in taken:
return base
@@ -163,22 +275,26 @@ def find_new_name(taken, base):
while new in taken:
i += 1
new = f"{base}_{i}"
+
return new
-def object_to_dict(obj, **kwargs):
+def object_to_dict(obj: t.Any, **kwargs) -> t.Dict:
+ """Returns a dictionary created from an object's attributes."""
return {**{k: copy(v) for k, v in vars(obj).copy().items()}, **kwargs}
-def split_num_words(value: str, sep: str, min_num_words: int, fill_from_start: bool = True) -> t.List[t.Optional[str]]:
+def split_num_words(
+ value: str, sep: str, min_num_words: int, fill_from_start: bool = True
+) -> t.List[t.Optional[str]]:
"""
- Perform a split on a value and return N words as a result with None used for words that don't exist.
+ Perform a split on a value and return N words as a result with `None` used for words that don't exist.
Args:
- value: The value to be split
- sep: The value to use to split on
- min_num_words: The minimum number of words that are going to be in the result
- fill_from_start: Indicates that if None values should be inserted at the start or end of the list
+ value: the value to be split.
+ sep: the value to use to split on.
+ min_num_words: the minimum number of words that are going to be in the result.
+ fill_from_start: indicates whether `None` values should be inserted at the start or end of the list.
Examples:
>>> split_num_words("db.table", ".", 3)
@@ -187,6 +303,9 @@ def split_num_words(value: str, sep: str, min_num_words: int, fill_from_start: b
['db', 'table', None]
>>> split_num_words("db.table", ".", 1)
['db', 'table']
+
+ Returns:
+ The list of words returned by `split`, possibly augmented by a number of `None` values.
"""
words = value.split(sep)
if fill_from_start:
@@ -196,7 +315,7 @@ def split_num_words(value: str, sep: str, min_num_words: int, fill_from_start: b
def is_iterable(value: t.Any) -> bool:
"""
- Checks if the value is an iterable but does not include strings and bytes
+ Checks if the value is an iterable, excluding the types `str` and `bytes`.
Examples:
>>> is_iterable([1,2])
@@ -205,28 +324,30 @@ def is_iterable(value: t.Any) -> bool:
False
Args:
- value: The value to check if it is an interable
+ value: the value to check if it is an iterable.
- Returns: Bool indicating if it is an iterable
+ Returns:
+ A `bool` value indicating if it is an iterable.
"""
return hasattr(value, "__iter__") and not isinstance(value, (str, bytes))
-def flatten(values: t.Iterable[t.Union[t.Iterable[t.Any], t.Any]]) -> t.Generator[t.Any, None, None]:
+def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Generator[t.Any, None, None]:
"""
- Flattens a list that can contain both iterables and non-iterable elements
+ Flattens an iterable that can contain both iterable and non-iterable elements. Objects of
+ type `str` and `bytes` are not regarded as iterables.
Examples:
- >>> list(flatten([[1, 2], 3]))
- [1, 2, 3]
+ >>> list(flatten([[1, 2], 3, {4}, (5, "bla")]))
+ [1, 2, 3, 4, 5, 'bla']
>>> list(flatten([1, 2, 3]))
[1, 2, 3]
Args:
- values: The value to be flattened
+ values: the value to be flattened.
- Returns:
- Yields non-iterable elements (not including str or byte as iterable)
+ Yields:
+ Non-iterable elements in `values`.
"""
for value in values:
if is_iterable(value):
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index 30055bc..96331e2 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -1,5 +1,5 @@
from sqlglot import exp
-from sqlglot.helper import ensure_list, subclasses
+from sqlglot.helper import ensure_collection, ensure_list, subclasses
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema
@@ -48,35 +48,65 @@ class TypeAnnotator:
exp.Boolean: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
exp.Null: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.NULL),
exp.Anonymous: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.UNKNOWN),
- exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
+ exp.ApproxDistinct: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.BIGINT
+ ),
exp.Avg: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Ceil: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.Count: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
exp.CurrentDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
- exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME),
- exp.CurrentTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
- exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
+ exp.CurrentDatetime: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.DATETIME
+ ),
+ exp.CurrentTime: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.TIMESTAMP
+ ),
+ exp.CurrentTimestamp: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.TIMESTAMP
+ ),
exp.DateAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.DateSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.DateDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME),
- exp.DatetimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATETIME),
+ exp.DatetimeAdd: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.DATETIME
+ ),
+ exp.DatetimeSub: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.DATETIME
+ ),
exp.DatetimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.Extract: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.TimestampAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
- exp.TimestampSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
+ exp.TimestampAdd: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.TIMESTAMP
+ ),
+ exp.TimestampSub: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.TIMESTAMP
+ ),
exp.TimestampDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.TimeAdd: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.TimeSub: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
exp.TimeDiff: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.DateStrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
- exp.DateToDateStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
+ exp.DateStrToDate: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.DATE
+ ),
+ exp.DateToDateStr: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.VARCHAR
+ ),
exp.DateToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.Day: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
exp.DiToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.Exp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Floor: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
- exp.If: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
+ exp.Case: lambda self, expr: self._annotate_by_args(expr, "default", "ifs"),
+ exp.If: lambda self, expr: self._annotate_by_args(expr, "true", "false"),
+ exp.Coalesce: lambda self, expr: self._annotate_by_args(expr, "this", "expressions"),
+ exp.IfNull: lambda self, expr: self._annotate_by_args(expr, "this", "expression"),
+ exp.ConcatWs: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
+ exp.GroupConcat: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.VARCHAR
+ ),
+ exp.ArrayConcat: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.VARCHAR
+ ),
exp.Initcap: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.Length: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BIGINT),
exp.Levenshtein: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
@@ -88,32 +118,52 @@ class TypeAnnotator:
exp.Month: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
exp.Pow: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Quantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.RegexpLike: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.BOOLEAN),
+ exp.ApproxQuantile: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.DOUBLE
+ ),
+ exp.RegexpLike: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.BOOLEAN
+ ),
exp.Round: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.SafeDivide: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Substring: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.StrPosition: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.StrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
- exp.StrToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
+ exp.StrToTime: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.TIMESTAMP
+ ),
exp.Sqrt: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.Stddev: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.StddevPop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.StddevSamp: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
exp.TimeToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
- exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
- exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
- exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
+ exp.TimeToTimeStr: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.VARCHAR
+ ),
+ exp.TimeStrToDate: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.DATE
+ ),
+ exp.TimeStrToTime: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.TIMESTAMP
+ ),
exp.Trim: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
- exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
+ exp.TsOrDsToDateStr: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.VARCHAR
+ ),
exp.TsOrDsToDate: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DATE),
exp.TsOrDiToDi: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.INT),
exp.UnixToStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
- exp.UnixToTime: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TIMESTAMP),
- exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
+ exp.UnixToTime: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.TIMESTAMP
+ ),
+ exp.UnixToTimeStr: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.VARCHAR
+ ),
exp.Upper: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.VARCHAR),
exp.Variance: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
- exp.VariancePop: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.DOUBLE),
+ exp.VariancePop: lambda self, expr: self._annotate_with_type(
+ expr, exp.DataType.Type.DOUBLE
+ ),
exp.Week: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
exp.Year: lambda self, expr: self._annotate_with_type(expr, exp.DataType.Type.TINYINT),
}
@@ -124,7 +174,11 @@ class TypeAnnotator:
exp.DataType.Type.TEXT: set(),
exp.DataType.Type.NVARCHAR: {exp.DataType.Type.TEXT},
exp.DataType.Type.VARCHAR: {exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
- exp.DataType.Type.NCHAR: {exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR, exp.DataType.Type.TEXT},
+ exp.DataType.Type.NCHAR: {
+ exp.DataType.Type.VARCHAR,
+ exp.DataType.Type.NVARCHAR,
+ exp.DataType.Type.TEXT,
+ },
exp.DataType.Type.CHAR: {
exp.DataType.Type.NCHAR,
exp.DataType.Type.VARCHAR,
@@ -135,7 +189,11 @@ class TypeAnnotator:
exp.DataType.Type.DOUBLE: set(),
exp.DataType.Type.FLOAT: {exp.DataType.Type.DOUBLE},
exp.DataType.Type.DECIMAL: {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
- exp.DataType.Type.BIGINT: {exp.DataType.Type.DECIMAL, exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE},
+ exp.DataType.Type.BIGINT: {
+ exp.DataType.Type.DECIMAL,
+ exp.DataType.Type.FLOAT,
+ exp.DataType.Type.DOUBLE,
+ },
exp.DataType.Type.INT: {
exp.DataType.Type.BIGINT,
exp.DataType.Type.DECIMAL,
@@ -160,7 +218,10 @@ class TypeAnnotator:
# DATE < DATETIME < TIMESTAMP < TIMESTAMPTZ < TIMESTAMPLTZ
exp.DataType.Type.TIMESTAMPLTZ: set(),
exp.DataType.Type.TIMESTAMPTZ: {exp.DataType.Type.TIMESTAMPLTZ},
- exp.DataType.Type.TIMESTAMP: {exp.DataType.Type.TIMESTAMPTZ, exp.DataType.Type.TIMESTAMPLTZ},
+ exp.DataType.Type.TIMESTAMP: {
+ exp.DataType.Type.TIMESTAMPTZ,
+ exp.DataType.Type.TIMESTAMPLTZ,
+ },
exp.DataType.Type.DATETIME: {
exp.DataType.Type.TIMESTAMP,
exp.DataType.Type.TIMESTAMPTZ,
@@ -219,7 +280,7 @@ class TypeAnnotator:
def _annotate_args(self, expression):
for value in expression.args.values():
- for v in ensure_list(value):
+ for v in ensure_collection(value):
self._maybe_annotate(v)
return expression
@@ -243,7 +304,9 @@ class TypeAnnotator:
if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
expression.type = exp.DataType.Type.NULL
elif exp.DataType.Type.NULL in (left_type, right_type):
- expression.type = exp.DataType.build("NULLABLE", expressions=exp.DataType.build("BOOLEAN"))
+ expression.type = exp.DataType.build(
+ "NULLABLE", expressions=exp.DataType.build("BOOLEAN")
+ )
else:
expression.type = exp.DataType.Type.BOOLEAN
elif isinstance(expression, (exp.Condition, exp.Predicate)):
@@ -276,3 +339,17 @@ class TypeAnnotator:
def _annotate_with_type(self, expression, target_type):
expression.type = target_type
return self._annotate_args(expression)
+
+ def _annotate_by_args(self, expression, *args):
+ self._annotate_args(expression)
+ expressions = []
+ for arg in args:
+ arg_expr = expression.args.get(arg)
+ expressions.extend(expr for expr in ensure_list(arg_expr) if expr)
+
+ last_datatype = None
+ for expr in expressions:
+ last_datatype = self._maybe_coerce(last_datatype or expr.type, expr.type)
+
+ expression.type = last_datatype or exp.DataType.Type.UNKNOWN
+ return expression
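
The new `_annotate_by_args` folds `_maybe_coerce` over the chosen argument slots, so `CASE`, `IF` and `COALESCE` now take the common supertype of their branches instead of a hard-coded type. A minimal standalone sketch of that fold, assuming `_maybe_coerce` keeps the wider of the two types (the coercion map here is a toy subset, not sqlglot's full matrix):

COERCES_TO = {
    "INT": {"BIGINT", "DOUBLE"},  # each type may widen to anything in its set
    "BIGINT": {"DOUBLE"},
    "DOUBLE": set(),
}

def maybe_coerce(type_a, type_b):
    # Assumption: prefer type_b when type_a can widen to it, else keep type_a.
    return type_b if type_b in COERCES_TO.get(type_a, set()) else type_a

def annotate_by_args(arg_types):
    last = None
    for current in arg_types:
        last = maybe_coerce(last or current, current)
    return last or "UNKNOWN"

# COALESCE(int_col, bigint_col, double_col) would annotate as DOUBLE.
assert annotate_by_args(["INT", "BIGINT", "DOUBLE"]) == "DOUBLE"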
diff --git a/sqlglot/optimizer/eliminate_joins.py b/sqlglot/optimizer/eliminate_joins.py
index 0854336..29621af 100644
--- a/sqlglot/optimizer/eliminate_joins.py
+++ b/sqlglot/optimizer/eliminate_joins.py
@@ -60,7 +60,9 @@ def _join_is_used(scope, join, alias):
on_clause_columns = set(id(column) for column in on.find_all(exp.Column))
else:
on_clause_columns = set()
- return any(column for column in scope.source_columns(alias) if id(column) not in on_clause_columns)
+ return any(
+ column for column in scope.source_columns(alias) if id(column) not in on_clause_columns
+ )
def _is_joined_on_all_unique_outputs(scope, join):
diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py
index e30c263..8704e90 100644
--- a/sqlglot/optimizer/eliminate_subqueries.py
+++ b/sqlglot/optimizer/eliminate_subqueries.py
@@ -45,7 +45,13 @@ def eliminate_subqueries(expression):
# All table names are taken
for scope in root.traverse():
- taken.update({source.name: source for _, source in scope.sources.items() if isinstance(source, exp.Table)})
+ taken.update(
+ {
+ source.name: source
+ for _, source in scope.sources.items()
+ if isinstance(source, exp.Table)
+ }
+ )
# Map of Expression->alias
# Existing CTES in the root expression. We'll use this for deduplication.
@@ -70,7 +76,9 @@ def eliminate_subqueries(expression):
new_ctes.append(cte_scope.expression.parent)
# Now append the rest
- for scope in itertools.chain(root.union_scopes, root.subquery_scopes, root.derived_table_scopes):
+ for scope in itertools.chain(
+ root.union_scopes, root.subquery_scopes, root.derived_table_scopes
+ ):
for child_scope in scope.traverse():
new_cte = _eliminate(child_scope, existing_ctes, taken)
if new_cte:
diff --git a/sqlglot/optimizer/merge_subqueries.py b/sqlglot/optimizer/merge_subqueries.py
index 70e4629..9ae4966 100644
--- a/sqlglot/optimizer/merge_subqueries.py
+++ b/sqlglot/optimizer/merge_subqueries.py
@@ -122,7 +122,9 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
unmergable_window_columns = [
column
for column in outer_scope.columns
- if column.find_ancestor(exp.Where, exp.Group, exp.Order, exp.Join, exp.Having, exp.AggFunc)
+ if column.find_ancestor(
+ exp.Where, exp.Group, exp.Order, exp.Join, exp.Having, exp.AggFunc
+ )
]
window_expressions_in_unmergable = [
column
@@ -147,7 +149,9 @@ def _mergeable(outer_scope, inner_select, leave_tables_isolated, from_or_join):
and not (
isinstance(from_or_join, exp.From)
and inner_select.args.get("where")
- and any(j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", []))
+ and any(
+ j.side in {"FULL", "RIGHT"} for j in outer_scope.expression.args.get("joins", [])
+ )
)
and not _is_a_window_expression_in_unmergable_operation()
)
@@ -203,7 +207,9 @@ def _merge_from(outer_scope, inner_scope, node_to_replace, alias):
if table.alias_or_name == node_to_replace.alias_or_name:
table.set("this", exp.to_identifier(new_subquery.alias_or_name))
outer_scope.remove_source(alias)
- outer_scope.add_source(new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name])
+ outer_scope.add_source(
+ new_subquery.alias_or_name, inner_scope.sources[new_subquery.alias_or_name]
+ )
def _merge_joins(outer_scope, inner_scope, from_or_join):
@@ -296,7 +302,9 @@ def _merge_order(outer_scope, inner_scope):
inner_scope (sqlglot.optimizer.scope.Scope)
"""
if (
- any(outer_scope.expression.args.get(arg) for arg in ["group", "distinct", "having", "order"])
+ any(
+ outer_scope.expression.args.get(arg) for arg in ["group", "distinct", "having", "order"]
+ )
or len(outer_scope.selected_sources) != 1
or any(expression.find(exp.AggFunc) for expression in outer_scope.expression.expressions)
):
diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py
index ab30d7a..db538ef 100644
--- a/sqlglot/optimizer/normalize.py
+++ b/sqlglot/optimizer/normalize.py
@@ -50,7 +50,9 @@ def normalization_distance(expression, dnf=False):
Returns:
int: difference
"""
- return sum(_predicate_lengths(expression, dnf)) - (len(list(expression.find_all(exp.Connector))) + 1)
+ return sum(_predicate_lengths(expression, dnf)) - (
+ len(list(expression.find_all(exp.Connector))) + 1
+ )
def _predicate_lengths(expression, dnf):
diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py
index 0c74e36..40e4ab1 100644
--- a/sqlglot/optimizer/optimize_joins.py
+++ b/sqlglot/optimizer/optimize_joins.py
@@ -68,4 +68,8 @@ def normalize(expression):
def other_table_names(join, exclude):
- return [name for name in (exp.column_table_names(join.args.get("on") or exp.TRUE)) if name != exclude]
+ return [
+ name
+ for name in (exp.column_table_names(join.args.get("on") or exp.TRUE))
+ if name != exclude
+ ]
diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py
index 5ad8f46..b2ed062 100644
--- a/sqlglot/optimizer/optimizer.py
+++ b/sqlglot/optimizer/optimizer.py
@@ -58,6 +58,8 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar
# Find any additional rule parameters, beyond `expression`
rule_params = rule.__code__.co_varnames
- rule_kwargs = {param: possible_kwargs[param] for param in rule_params if param in possible_kwargs}
+ rule_kwargs = {
+ param: possible_kwargs[param] for param in rule_params if param in possible_kwargs
+ }
expression = rule(expression, **rule_kwargs)
return expression
diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py
index 583d059..6364f65 100644
--- a/sqlglot/optimizer/pushdown_predicates.py
+++ b/sqlglot/optimizer/pushdown_predicates.py
@@ -45,7 +45,11 @@ def pushdown(condition, sources, scope_ref_count):
condition = condition.replace(simplify(condition))
cnf_like = normalized(condition) or not normalized(condition, dnf=True)
- predicates = list(condition.flatten() if isinstance(condition, exp.And if cnf_like else exp.Or) else [condition])
+ predicates = list(
+ condition.flatten()
+ if isinstance(condition, exp.And if cnf_like else exp.Or)
+ else [condition]
+ )
if cnf_like:
pushdown_cnf(predicates, sources, scope_ref_count)
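
For context on the wrapped expression above: pushdown only splits the condition into independently pushable predicates when it is CNF-like; a DNF-like condition stays whole and goes through `pushdown_dnf`. A hedged sketch of the split using sqlglot's public API:

import sqlglot
from sqlglot import exp

condition = sqlglot.parse_one("x = 1 AND y > 2 AND z < 3")
# In CNF every AND-child can be pushed on its own.
predicates = list(condition.flatten()) if isinstance(condition, exp.And) else [condition]
print([p.sql() for p in predicates])  # ['x = 1', 'y > 2', 'z < 3']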
@@ -104,11 +108,17 @@ def pushdown_dnf(predicates, scope, scope_ref_count):
for column in predicate.find_all(exp.Column):
if column.table == table:
condition = column.find_ancestor(exp.Condition)
- predicate_condition = exp.and_(predicate_condition, condition) if predicate_condition else condition
+ predicate_condition = (
+ exp.and_(predicate_condition, condition)
+ if predicate_condition
+ else condition
+ )
if predicate_condition:
conditions[table] = (
- exp.or_(conditions[table], predicate_condition) if table in conditions else predicate_condition
+ exp.or_(conditions[table], predicate_condition)
+ if table in conditions
+ else predicate_condition
)
for name, node in nodes.items():
@@ -146,10 +156,16 @@ def nodes_for_predicate(predicate, sources, scope_ref_count):
nodes[table] = node
elif isinstance(node, exp.Select) and len(tables) == 1:
# We can't push down window expressions
- has_window_expression = any(select for select in node.selects if select.find(exp.Window))
+ has_window_expression = any(
+ select for select in node.selects if select.find(exp.Window)
+ )
# we can't push down predicates to select statements if they are referenced in
# multiple places.
- if not node.args.get("group") and scope_ref_count[id(source)] < 2 and not has_window_expression:
+ if (
+ not node.args.get("group")
+ and scope_ref_count[id(source)] < 2
+ and not has_window_expression
+ ):
nodes[table] = node
return nodes
@@ -165,7 +181,7 @@ def replace_aliases(source, predicate):
def _replace_alias(column):
if isinstance(column, exp.Column) and column.name in aliases:
- return aliases[column.name]
+ return aliases[column.name].copy()
return column
return predicate.transform(_replace_alias)
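
The `.copy()` added in `_replace_alias` is a behavior fix rather than formatting: an alias may be referenced by several predicates, and splicing the same AST node into every reference would let a later rewrite of one occurrence corrupt the others. A small illustration with a hypothetical alias map:

import sqlglot

aliases = {"x": sqlglot.parse_one("a + b")}  # stands in for a projection alias
predicate = sqlglot.parse_one("x > 1 AND x < 10")

def _replace_alias(node):
    if isinstance(node, sqlglot.exp.Column) and node.name in aliases:
        return aliases[node.name].copy()  # fresh tree per occurrence
    return node

print(predicate.transform(_replace_alias).sql())  # a + b > 1 AND a + b < 10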
diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py
index 5820851..abd9492 100644
--- a/sqlglot/optimizer/pushdown_projections.py
+++ b/sqlglot/optimizer/pushdown_projections.py
@@ -98,7 +98,9 @@ def _remove_unused_selections(scope, parent_selections):
def _remove_indexed_selections(scope, indexes_to_remove):
- new_selections = [selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove]
+ new_selections = [
+ selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove
+ ]
if not new_selections:
new_selections.append(DEFAULT_SELECTION)
scope.expression.set("expressions", new_selections)
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index ebee92a..69fe2b8 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -215,13 +215,21 @@ def _qualify_columns(scope, resolver):
# Determine whether each reference in the order by clause is to a column or an alias.
for ordered in scope.find_all(exp.Ordered):
for column in ordered.find_all(exp.Column):
- if not column.table and column.parent is not ordered and column.name in resolver.all_columns:
+ if (
+ not column.table
+ and column.parent is not ordered
+ and column.name in resolver.all_columns
+ ):
columns_missing_from_scope.append(column)
# Determine whether each reference in the having clause is to a column or an alias.
for having in scope.find_all(exp.Having):
for column in having.find_all(exp.Column):
- if not column.table and column.find_ancestor(exp.AggFunc) and column.name in resolver.all_columns:
+ if (
+ not column.table
+ and column.find_ancestor(exp.AggFunc)
+ and column.name in resolver.all_columns
+ ):
columns_missing_from_scope.append(column)
for column in columns_missing_from_scope:
@@ -295,7 +303,9 @@ def _qualify_outputs(scope):
"""Ensure all output columns are aliased"""
new_selections = []
- for i, (selection, aliased_column) in enumerate(itertools.zip_longest(scope.selects, scope.outer_column_list)):
+ for i, (selection, aliased_column) in enumerate(
+ itertools.zip_longest(scope.selects, scope.outer_column_list)
+ ):
if isinstance(selection, exp.Column):
# convoluted setter because a simple selection.replace(alias) would require a copy
alias_ = alias(exp.column(""), alias=selection.name)
@@ -343,14 +353,18 @@ class _Resolver:
(str) table name
"""
if self._unambiguous_columns is None:
- self._unambiguous_columns = self._get_unambiguous_columns(self._get_all_source_columns())
+ self._unambiguous_columns = self._get_unambiguous_columns(
+ self._get_all_source_columns()
+ )
return self._unambiguous_columns.get(column_name)
@property
def all_columns(self):
"""All available columns of all sources in this scope"""
if self._all_columns is None:
- self._all_columns = set(column for columns in self._get_all_source_columns().values() for column in columns)
+ self._all_columns = set(
+ column for columns in self._get_all_source_columns().values() for column in columns
+ )
return self._all_columns
def get_source_columns(self, name, only_visible=False):
@@ -377,7 +391,9 @@ class _Resolver:
def _get_all_source_columns(self):
if self._source_columns is None:
- self._source_columns = {k: self.get_source_columns(k) for k in self.scope.selected_sources}
+ self._source_columns = {
+ k: self.get_source_columns(k) for k in self.scope.selected_sources
+ }
return self._source_columns
def _get_unambiguous_columns(self, source_columns):
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
index 5a75ee2..18848f3 100644
--- a/sqlglot/optimizer/scope.py
+++ b/sqlglot/optimizer/scope.py
@@ -226,7 +226,9 @@ class Scope:
self._ensure_collected()
columns = self._raw_columns
- external_columns = [column for scope in self.subquery_scopes for column in scope.external_columns]
+ external_columns = [
+ column for scope in self.subquery_scopes for column in scope.external_columns
+ ]
named_outputs = {e.alias_or_name for e in self.expression.expressions}
@@ -278,7 +280,11 @@ class Scope:
Returns:
dict[str, Scope]: Mapping of source alias to Scope
"""
- return {alias: scope for alias, scope in self.sources.items() if isinstance(scope, Scope) and scope.is_cte}
+ return {
+ alias: scope
+ for alias, scope in self.sources.items()
+ if isinstance(scope, Scope) and scope.is_cte
+ }
@property
def selects(self):
@@ -307,7 +313,9 @@ class Scope:
sources in the current scope.
"""
if self._external_columns is None:
- self._external_columns = [c for c in self.columns if c.table not in self.selected_sources]
+ self._external_columns = [
+ c for c in self.columns if c.table not in self.selected_sources
+ ]
return self._external_columns
@property
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index c077906..d759e86 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -229,7 +229,9 @@ def simplify_literals(expression):
operands.append(a)
if len(operands) < size:
- return functools.reduce(lambda a, b: expression.__class__(this=a, expression=b), operands)
+ return functools.reduce(
+ lambda a, b: expression.__class__(this=a, expression=b), operands
+ )
elif isinstance(expression, exp.Neg):
this = expression.this
if this.is_number:
@@ -255,6 +257,12 @@ def _simplify_binary(expression, a, b):
return TRUE if not_ else FALSE
if a == NULL:
return FALSE if not_ else TRUE
+ elif isinstance(expression, exp.NullSafeEQ):
+ if a == b:
+ return TRUE
+ elif isinstance(expression, exp.NullSafeNEQ):
+ if a == b:
+ return FALSE
elif NULL in (a, b):
return NULL
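
The `NullSafeEQ`/`NullSafeNEQ` branches fold the equal-operand cases and sit deliberately before the generic `NULL in (a, b)` fallthrough, because null-safe comparison never yields NULL. The full semantics, modeled in Python with None standing in for SQL NULL:

def null_safe_eq(a, b):
    # IS NOT DISTINCT FROM / <=>: two NULLs are equal, a single NULL is unequal.
    if a is None and b is None:
        return True
    if a is None or b is None:
        return False
    return a == b

assert null_safe_eq(None, None) is True   # plain '=' would yield NULL here
assert null_safe_eq(1, None) is False
assert null_safe_eq(1, 1) is True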
@@ -357,7 +365,7 @@ def extract_date(cast):
def extract_interval(interval):
try:
- from dateutil.relativedelta import relativedelta
+ from dateutil.relativedelta import relativedelta # type: ignore
except ModuleNotFoundError:
return None
diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py
index 11c6eba..f41a84e 100644
--- a/sqlglot/optimizer/unnest_subqueries.py
+++ b/sqlglot/optimizer/unnest_subqueries.py
@@ -89,7 +89,11 @@ def decorrelate(select, parent_select, external_columns, sequence):
return
if isinstance(predicate, exp.Binary):
- key = predicate.right if any(node is column for node, *_ in predicate.left.walk()) else predicate.left
+ key = (
+ predicate.right
+ if any(node is column for node, *_ in predicate.left.walk())
+ else predicate.left
+ )
else:
return
@@ -145,7 +149,9 @@ def decorrelate(select, parent_select, external_columns, sequence):
else:
parent_predicate = _replace(parent_predicate, "TRUE")
elif isinstance(parent_predicate, exp.All):
- parent_predicate = _replace(parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})")
+ parent_predicate = _replace(
+ parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
+ )
elif isinstance(parent_predicate, exp.Any):
if value.this in group_by:
parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
@@ -168,7 +174,9 @@ def decorrelate(select, parent_select, external_columns, sequence):
if key in group_by:
key.replace(nested)
- parent_predicate = _replace(parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)")
+ parent_predicate = _replace(
+ parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)"
+ )
elif isinstance(predicate, exp.EQ):
parent_predicate = _replace(
parent_predicate,
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index 79a1d90..bbea0e5 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -1,9 +1,13 @@
+from __future__ import annotations
+
import logging
+import typing as t
from sqlglot import exp
from sqlglot.errors import ErrorLevel, ParseError, concat_errors
-from sqlglot.helper import apply_index_offset, ensure_list, list_get
+from sqlglot.helper import apply_index_offset, ensure_collection, seq_get
from sqlglot.tokens import Token, Tokenizer, TokenType
+from sqlglot.trie import in_trie, new_trie
logger = logging.getLogger("sqlglot")
@@ -20,7 +24,15 @@ def parse_var_map(args):
)
-class Parser:
+class _Parser(type):
+ def __new__(cls, clsname, bases, attrs):
+ klass = super().__new__(cls, clsname, bases, attrs)
+ klass._show_trie = new_trie(key.split(" ") for key in klass.SHOW_PARSERS)
+ klass._set_trie = new_trie(key.split(" ") for key in klass.SET_PARSERS)
+ return klass
+
+
+class Parser(metaclass=_Parser):
"""
Parser consumes a list of tokens produced by the :class:`~sqlglot.tokens.Tokenizer`
and produces a parsed syntax tree.
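
The `_Parser` metaclass precomputes word-tries over the (possibly multi-word) `SHOW_PARSERS` and `SET_PARSERS` keys once per subclass, so dialects can register entries like "GLOBAL STATUS" and have them matched token by token. A sketch of the trie shape, assuming the nested-dict-with-terminal-marker layout of `sqlglot.trie`:

def new_trie(keys):
    trie = {}
    for key in keys:
        node = trie
        for word in key:
            node = node.setdefault(word, {})
        node[0] = True  # terminal marker: a registered key ends here
    return trie

show_trie = new_trie(key.split(" ") for key in ("GLOBAL STATUS", "GLOBAL VARIABLES"))
print(show_trie)  # {'GLOBAL': {'STATUS': {0: True}, 'VARIABLES': {0: True}}}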
@@ -45,16 +57,16 @@ class Parser:
FUNCTIONS = {
**{name: f.from_arg_list for f in exp.ALL_FUNCTIONS for name in f.sql_names()},
"DATE_TO_DATE_STR": lambda args: exp.Cast(
- this=list_get(args, 0),
+ this=seq_get(args, 0),
to=exp.DataType(this=exp.DataType.Type.TEXT),
),
"TIME_TO_TIME_STR": lambda args: exp.Cast(
- this=list_get(args, 0),
+ this=seq_get(args, 0),
to=exp.DataType(this=exp.DataType.Type.TEXT),
),
"TS_OR_DS_TO_DATE_STR": lambda args: exp.Substring(
this=exp.Cast(
- this=list_get(args, 0),
+ this=seq_get(args, 0),
to=exp.DataType(this=exp.DataType.Type.TEXT),
),
start=exp.Literal.number(1),
@@ -90,6 +102,7 @@ class Parser:
TokenType.NVARCHAR,
TokenType.TEXT,
TokenType.BINARY,
+ TokenType.VARBINARY,
TokenType.JSON,
TokenType.INTERVAL,
TokenType.TIMESTAMP,
@@ -243,6 +256,7 @@ class Parser:
EQUALITY = {
TokenType.EQ: exp.EQ,
TokenType.NEQ: exp.NEQ,
+ TokenType.NULLSAFE_EQ: exp.NullSafeEQ,
}
COMPARISON = {
@@ -298,6 +312,21 @@ class Parser:
TokenType.ANTI,
}
+ LAMBDAS = {
+ TokenType.ARROW: lambda self, expressions: self.expression(
+ exp.Lambda,
+ this=self._parse_conjunction().transform(
+ self._replace_lambda, {node.name for node in expressions}
+ ),
+ expressions=expressions,
+ ),
+ TokenType.FARROW: lambda self, expressions: self.expression(
+ exp.Kwarg,
+ this=exp.Var(this=expressions[0].name),
+ expression=self._parse_conjunction(),
+ ),
+ }
+
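
Lambda parsing becomes table-driven here: after the parameter list, `->` builds an `exp.Lambda` whose body has matching column names rewritten to lambda variables, while `=>` builds an `exp.Kwarg` for named arguments. For example (`FOO` is just a placeholder anonymous function):

import sqlglot

expr = sqlglot.parse_one("SELECT FOO(x -> x + 1)")
print(expr.sql())  # SELECT FOO(x -> x + 1)
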
COLUMN_OPERATORS = {
TokenType.DOT: None,
TokenType.DCOLON: lambda self, this, to: self.expression(
@@ -362,20 +391,30 @@ class Parser:
TokenType.DELETE: lambda self: self._parse_delete(),
TokenType.CACHE: lambda self: self._parse_cache(),
TokenType.UNCACHE: lambda self: self._parse_uncache(),
+ TokenType.USE: lambda self: self._parse_use(),
}
PRIMARY_PARSERS = {
- TokenType.STRING: lambda _, token: exp.Literal.string(token.text),
- TokenType.NUMBER: lambda _, token: exp.Literal.number(token.text),
- TokenType.STAR: lambda self, _: exp.Star(**{"except": self._parse_except(), "replace": self._parse_replace()}),
- TokenType.NULL: lambda *_: exp.Null(),
- TokenType.TRUE: lambda *_: exp.Boolean(this=True),
- TokenType.FALSE: lambda *_: exp.Boolean(this=False),
- TokenType.PARAMETER: lambda self, _: exp.Parameter(this=self._parse_var() or self._parse_primary()),
- TokenType.BIT_STRING: lambda _, token: exp.BitString(this=token.text),
- TokenType.HEX_STRING: lambda _, token: exp.HexString(this=token.text),
- TokenType.BYTE_STRING: lambda _, token: exp.ByteString(this=token.text),
+ TokenType.STRING: lambda self, token: self.expression(
+ exp.Literal, this=token.text, is_string=True
+ ),
+ TokenType.NUMBER: lambda self, token: self.expression(
+ exp.Literal, this=token.text, is_string=False
+ ),
+ TokenType.STAR: lambda self, _: self.expression(
+ exp.Star, **{"except": self._parse_except(), "replace": self._parse_replace()}
+ ),
+ TokenType.NULL: lambda self, _: self.expression(exp.Null),
+ TokenType.TRUE: lambda self, _: self.expression(exp.Boolean, this=True),
+ TokenType.FALSE: lambda self, _: self.expression(exp.Boolean, this=False),
+ TokenType.PARAMETER: lambda self, _: self.expression(
+ exp.Parameter, this=self._parse_var() or self._parse_primary()
+ ),
+ TokenType.BIT_STRING: lambda self, token: self.expression(exp.BitString, this=token.text),
+ TokenType.HEX_STRING: lambda self, token: self.expression(exp.HexString, this=token.text),
+ TokenType.BYTE_STRING: lambda self, token: self.expression(exp.ByteString, this=token.text),
TokenType.INTRODUCER: lambda self, token: self._parse_introducer(token),
+ TokenType.SESSION_PARAMETER: lambda self, _: self._parse_session_parameter(),
}
RANGE_PARSERS = {
@@ -411,16 +450,24 @@ class Parser:
TokenType.COLLATE: lambda self: self._parse_property_assignment(exp.CollateProperty),
TokenType.COMMENT: lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
TokenType.FORMAT: lambda self: self._parse_property_assignment(exp.FileFormatProperty),
- TokenType.TABLE_FORMAT: lambda self: self._parse_property_assignment(exp.TableFormatProperty),
+ TokenType.TABLE_FORMAT: lambda self: self._parse_property_assignment(
+ exp.TableFormatProperty
+ ),
TokenType.USING: lambda self: self._parse_property_assignment(exp.TableFormatProperty),
TokenType.LANGUAGE: lambda self: self._parse_property_assignment(exp.LanguageProperty),
TokenType.EXECUTE: lambda self: self._parse_execute_as(),
TokenType.DETERMINISTIC: lambda self: self.expression(
exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")
),
- TokenType.IMMUTABLE: lambda self: self.expression(exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")),
- TokenType.STABLE: lambda self: self.expression(exp.VolatilityProperty, this=exp.Literal.string("STABLE")),
- TokenType.VOLATILE: lambda self: self.expression(exp.VolatilityProperty, this=exp.Literal.string("VOLATILE")),
+ TokenType.IMMUTABLE: lambda self: self.expression(
+ exp.VolatilityProperty, this=exp.Literal.string("IMMUTABLE")
+ ),
+ TokenType.STABLE: lambda self: self.expression(
+ exp.VolatilityProperty, this=exp.Literal.string("STABLE")
+ ),
+ TokenType.VOLATILE: lambda self: self.expression(
+ exp.VolatilityProperty, this=exp.Literal.string("VOLATILE")
+ ),
}
CONSTRAINT_PARSERS = {
@@ -450,7 +497,8 @@ class Parser:
"group": lambda self: self._parse_group(),
"having": lambda self: self._parse_having(),
"qualify": lambda self: self._parse_qualify(),
- "window": lambda self: self._match(TokenType.WINDOW) and self._parse_window(self._parse_id_var(), alias=True),
+ "window": lambda self: self._match(TokenType.WINDOW)
+ and self._parse_window(self._parse_id_var(), alias=True),
"distribute": lambda self: self._parse_sort(TokenType.DISTRIBUTE_BY, exp.Distribute),
"sort": lambda self: self._parse_sort(TokenType.SORT_BY, exp.Sort),
"cluster": lambda self: self._parse_sort(TokenType.CLUSTER_BY, exp.Cluster),
@@ -459,6 +507,9 @@ class Parser:
"offset": lambda self: self._parse_offset(),
}
+ SHOW_PARSERS: t.Dict[str, t.Callable] = {}
+ SET_PARSERS: t.Dict[str, t.Callable] = {}
+
MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table)
CREATABLES = {
@@ -488,7 +539,9 @@ class Parser:
"_curr",
"_next",
"_prev",
- "_greedy_subqueries",
+ "_prev_comment",
+ "_show_trie",
+ "_set_trie",
)
def __init__(
@@ -519,7 +572,7 @@ class Parser:
self._curr = None
self._next = None
self._prev = None
- self._greedy_subqueries = False
+ self._prev_comment = None
def parse(self, raw_tokens, sql=None):
"""
@@ -533,10 +586,12 @@ class Parser:
Returns
the list of syntax trees (:class:`~sqlglot.expressions.Expression`).
"""
- return self._parse(parse_method=self.__class__._parse_statement, raw_tokens=raw_tokens, sql=sql)
+ return self._parse(
+ parse_method=self.__class__._parse_statement, raw_tokens=raw_tokens, sql=sql
+ )
def parse_into(self, expression_types, raw_tokens, sql=None):
- for expression_type in ensure_list(expression_types):
+ for expression_type in ensure_collection(expression_types):
parser = self.EXPRESSION_PARSERS.get(expression_type)
if not parser:
raise TypeError(f"No parser registered for {expression_type}")
@@ -597,6 +652,9 @@ class Parser:
def expression(self, exp_class, **kwargs):
instance = exp_class(**kwargs)
+ if self._prev_comment:
+ instance.comment = self._prev_comment
+ self._prev_comment = None
self.validate_expression(instance)
return instance
@@ -633,14 +691,16 @@ class Parser:
return index
- def _get_token(self, index):
- return list_get(self._tokens, index)
-
def _advance(self, times=1):
self._index += times
- self._curr = self._get_token(self._index)
- self._next = self._get_token(self._index + 1)
- self._prev = self._get_token(self._index - 1) if self._index > 0 else None
+ self._curr = seq_get(self._tokens, self._index)
+ self._next = seq_get(self._tokens, self._index + 1)
+ if self._index > 0:
+ self._prev = self._tokens[self._index - 1]
+ self._prev_comment = self._prev.comment
+ else:
+ self._prev = None
+ self._prev_comment = None
def _retreat(self, index):
self._advance(index - self._index)
@@ -661,6 +721,7 @@ class Parser:
expression = self._parse_expression()
expression = self._parse_set_operations(expression) if expression else self._parse_select()
+
self._parse_query_modifiers(expression)
return expression
@@ -682,7 +743,11 @@ class Parser:
)
def _parse_exists(self, not_=False):
- return self._match(TokenType.IF) and (not not_ or self._match(TokenType.NOT)) and self._match(TokenType.EXISTS)
+ return (
+ self._match(TokenType.IF)
+ and (not not_ or self._match(TokenType.NOT))
+ and self._match(TokenType.EXISTS)
+ )
def _parse_create(self):
replace = self._match(TokenType.OR) and self._match(TokenType.REPLACE)
@@ -931,7 +996,9 @@ class Parser:
return self.expression(
exp.Delete,
this=self._parse_table(schema=True),
- using=self._parse_csv(lambda: self._match(TokenType.USING) and self._parse_table(schema=True)),
+ using=self._parse_csv(
+ lambda: self._match(TokenType.USING) and self._parse_table(schema=True)
+ ),
where=self._parse_where(),
)
@@ -983,11 +1050,13 @@ class Parser:
return None
def parse_values():
- k = self._parse_var()
+ key = self._parse_var()
+ value = None
+
if self._match(TokenType.EQ):
- v = self._parse_string()
- return (k, v)
- return (k, None)
+ value = self._parse_string()
+
+ return exp.Property(this=key, value=value)
self._match_l_paren()
values = self._parse_csv(parse_values)
@@ -1019,6 +1088,8 @@ class Parser:
self.raise_error(f"{this.key} does not support CTE")
this = cte
elif self._match(TokenType.SELECT):
+ comment = self._prev_comment
+
hint = self._parse_hint()
all_ = self._match(TokenType.ALL)
distinct = self._match(TokenType.DISTINCT)
@@ -1033,7 +1104,7 @@ class Parser:
self.raise_error("Cannot specify both ALL and DISTINCT after SELECT")
limit = self._parse_limit(top=True)
- expressions = self._parse_csv(lambda: self._parse_annotation(self._parse_expression()))
+ expressions = self._parse_csv(self._parse_expression)
this = self.expression(
exp.Select,
@@ -1042,6 +1113,7 @@ class Parser:
expressions=expressions,
limit=limit,
)
+ this.comment = comment
from_ = self._parse_from()
if from_:
this.set("from", from_)
@@ -1072,8 +1144,10 @@ class Parser:
while True:
expressions.append(self._parse_cte())
- if not self._match(TokenType.COMMA):
+ if not self._match(TokenType.COMMA) and not self._match(TokenType.WITH):
break
+ else:
+ self._match(TokenType.WITH)
return self.expression(
exp.With,
@@ -1111,11 +1185,7 @@ class Parser:
if not alias and not columns:
return None
- return self.expression(
- exp.TableAlias,
- this=alias,
- columns=columns,
- )
+ return self.expression(exp.TableAlias, this=alias, columns=columns)
def _parse_subquery(self, this):
return self.expression(
@@ -1150,12 +1220,6 @@ class Parser:
if expression:
this.set(key, expression)
- def _parse_annotation(self, expression):
- if self._match(TokenType.ANNOTATION):
- return self.expression(exp.Annotation, this=self._prev.text.strip(), expression=expression)
-
- return expression
-
def _parse_hint(self):
if self._match(TokenType.HINT):
hints = self._parse_csv(self._parse_function)
@@ -1295,7 +1359,9 @@ class Parser:
if not table:
self.raise_error("Expected table name")
- this = self.expression(exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots())
+ this = self.expression(
+ exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots()
+ )
if schema:
return self._parse_schema(this=this)
@@ -1500,7 +1566,9 @@ class Parser:
if not skip_order_token and not self._match(TokenType.ORDER_BY):
return this
- return self.expression(exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered))
+ return self.expression(
+ exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered)
+ )
def _parse_sort(self, token_type, exp_class):
if not self._match(token_type):
@@ -1521,7 +1589,8 @@ class Parser:
if (
not explicitly_null_ordered
and (
- (asc and self.null_ordering == "nulls_are_small") or (desc and self.null_ordering != "nulls_are_small")
+ (asc and self.null_ordering == "nulls_are_small")
+ or (desc and self.null_ordering != "nulls_are_small")
)
and self.null_ordering != "nulls_are_last"
):
@@ -1606,6 +1675,9 @@ class Parser:
def _parse_is(self, this):
negate = self._match(TokenType.NOT)
+ if self._match(TokenType.DISTINCT_FROM):
+ klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ
+ return self.expression(klass, this=this, expression=self._parse_expression())
this = self.expression(
exp.Is,
this=this,
@@ -1653,9 +1725,13 @@ class Parser:
expression=self._parse_term(),
)
elif self._match_pair(TokenType.LT, TokenType.LT):
- this = self.expression(exp.BitwiseLeftShift, this=this, expression=self._parse_term())
+ this = self.expression(
+ exp.BitwiseLeftShift, this=this, expression=self._parse_term()
+ )
elif self._match_pair(TokenType.GT, TokenType.GT):
- this = self.expression(exp.BitwiseRightShift, this=this, expression=self._parse_term())
+ this = self.expression(
+ exp.BitwiseRightShift, this=this, expression=self._parse_term()
+ )
else:
break
@@ -1685,7 +1761,7 @@ class Parser:
)
index = self._index
- type_token = self._parse_types()
+ type_token = self._parse_types(check_func=True)
this = self._parse_column()
if type_token:
@@ -1698,7 +1774,7 @@ class Parser:
return this
- def _parse_types(self):
+ def _parse_types(self, check_func=False):
index = self._index
if not self._match_set(self.TYPE_TOKENS):
@@ -1708,10 +1784,13 @@ class Parser:
nested = type_token in self.NESTED_TYPE_TOKENS
is_struct = type_token == TokenType.STRUCT
expressions = None
+ maybe_func = False
if not nested and self._match_pair(TokenType.L_BRACKET, TokenType.R_BRACKET):
return exp.DataType(
- this=exp.DataType.Type.ARRAY, expressions=[exp.DataType.build(type_token.value)], nested=True
+ this=exp.DataType.Type.ARRAY,
+ expressions=[exp.DataType.build(type_token.value)],
+ nested=True,
)
if self._match(TokenType.L_BRACKET):
@@ -1731,6 +1810,7 @@ class Parser:
return None
self._match_r_paren()
+ maybe_func = True
if nested and self._match(TokenType.LT):
if is_struct:
@@ -1741,25 +1821,46 @@ class Parser:
if not self._match(TokenType.GT):
self.raise_error("Expecting >")
+ value = None
if type_token in self.TIMESTAMPS:
- tz = self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ
- if tz:
- return exp.DataType(
+ if self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ:
+ value = exp.DataType(
this=exp.DataType.Type.TIMESTAMPTZ,
expressions=expressions,
)
- ltz = self._match(TokenType.WITH_LOCAL_TIME_ZONE) or type_token == TokenType.TIMESTAMPLTZ
- if ltz:
- return exp.DataType(
+ elif (
+ self._match(TokenType.WITH_LOCAL_TIME_ZONE) or type_token == TokenType.TIMESTAMPLTZ
+ ):
+ value = exp.DataType(
this=exp.DataType.Type.TIMESTAMPLTZ,
expressions=expressions,
)
- self._match(TokenType.WITHOUT_TIME_ZONE)
+ elif self._match(TokenType.WITHOUT_TIME_ZONE):
+ value = exp.DataType(
+ this=exp.DataType.Type.TIMESTAMP,
+ expressions=expressions,
+ )
- return exp.DataType(
- this=exp.DataType.Type.TIMESTAMP,
- expressions=expressions,
- )
+ maybe_func = maybe_func and value is None
+
+ if value is None:
+ value = exp.DataType(
+ this=exp.DataType.Type.TIMESTAMP,
+ expressions=expressions,
+ )
+
+ if maybe_func and check_func:
+ index2 = self._index
+ peek = self._parse_string()
+
+ if not peek:
+ self._retreat(index)
+ return None
+
+ self._retreat(index2)
+
+ if value:
+ return value
return exp.DataType(
this=exp.DataType.Type[type_token.value.upper()],
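
`check_func` makes `_parse_types` conservative at expression level: after consuming something shaped like `TYPE(...)`, it peeks for a trailing string literal and, if none follows, retreats entirely so the same tokens can be re-parsed as a function call. Illustrative round-trips (outputs indicative):

import sqlglot

print(sqlglot.parse_one("SELECT DATE '2022-01-01'").sql())  # typed literal
print(sqlglot.parse_one("SELECT DATE(x)").sql())            # function call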
@@ -1826,22 +1927,29 @@ class Parser:
return exp.Literal.number(f"0.{self._prev.text}")
if self._match(TokenType.L_PAREN):
+ comment = self._prev_comment
query = self._parse_select()
if query:
expressions = [query]
else:
- expressions = self._parse_csv(lambda: self._parse_alias(self._parse_conjunction(), explicit=True))
+ expressions = self._parse_csv(
+ lambda: self._parse_alias(self._parse_conjunction(), explicit=True)
+ )
- this = list_get(expressions, 0)
+ this = seq_get(expressions, 0)
self._parse_query_modifiers(this)
self._match_r_paren()
if isinstance(this, exp.Subqueryable):
- return self._parse_set_operations(self._parse_subquery(this))
- if len(expressions) > 1:
- return self.expression(exp.Tuple, expressions=expressions)
- return self.expression(exp.Paren, this=this)
+ this = self._parse_set_operations(self._parse_subquery(this))
+ elif len(expressions) > 1:
+ this = self.expression(exp.Tuple, expressions=expressions)
+ else:
+ this = self.expression(exp.Paren, this=this)
+ if comment:
+ this.comment = comment
+ return this
return None
@@ -1894,7 +2002,8 @@ class Parser:
self.validate_expression(this, args)
else:
this = self.expression(exp.Anonymous, this=this, expressions=args)
- self._match_r_paren()
+
+ self._match_r_paren(this)
return self._parse_window(this)
def _parse_user_defined_function(self):
@@ -1920,6 +2029,18 @@ class Parser:
return self.expression(exp.Identifier, this=token.text)
+ def _parse_session_parameter(self):
+ kind = None
+ this = self._parse_id_var() or self._parse_primary()
+ if self._match(TokenType.DOT):
+ kind = this.name
+ this = self._parse_var() or self._parse_primary()
+ return self.expression(
+ exp.SessionParameter,
+ this=this,
+ kind=kind,
+ )
+
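
`_parse_session_parameter` backs tokens such as MySQL's `@@` variables, capturing an optional scope before the dot as `kind`. A hedged example against the MySQL dialect:

import sqlglot

# '@@GLOBAL.time_zone' -> exp.SessionParameter(this=time_zone, kind='GLOBAL')
expr = sqlglot.parse_one("SELECT @@GLOBAL.time_zone", read="mysql")
print(expr.sql(dialect="mysql"))
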
def _parse_udf_kwarg(self):
this = self._parse_id_var()
kind = self._parse_types()
@@ -1938,27 +2059,24 @@ class Parser:
else:
expressions = [self._parse_id_var()]
- if not self._match(TokenType.ARROW):
- self._retreat(index)
+ if self._match_set(self.LAMBDAS):
+ return self.LAMBDAS[self._prev.token_type](self, expressions)
- if self._match(TokenType.DISTINCT):
- this = self.expression(exp.Distinct, expressions=self._parse_csv(self._parse_conjunction))
- else:
- this = self._parse_conjunction()
+ self._retreat(index)
- if self._match(TokenType.IGNORE_NULLS):
- this = self.expression(exp.IgnoreNulls, this=this)
- else:
- self._match(TokenType.RESPECT_NULLS)
+ if self._match(TokenType.DISTINCT):
+ this = self.expression(
+ exp.Distinct, expressions=self._parse_csv(self._parse_conjunction)
+ )
+ else:
+ this = self._parse_conjunction()
- return self._parse_alias(self._parse_limit(self._parse_order(this)))
+ if self._match(TokenType.IGNORE_NULLS):
+ this = self.expression(exp.IgnoreNulls, this=this)
+ else:
+ self._match(TokenType.RESPECT_NULLS)
- conjunction = self._parse_conjunction().transform(self._replace_lambda, {node.name for node in expressions})
- return self.expression(
- exp.Lambda,
- this=conjunction,
- expressions=expressions,
- )
+ return self._parse_alias(self._parse_limit(self._parse_order(this)))
def _parse_schema(self, this=None):
index = self._index
@@ -1966,7 +2084,9 @@ class Parser:
self._retreat(index)
return this
- args = self._parse_csv(lambda: self._parse_constraint() or self._parse_column_def(self._parse_field(True)))
+ args = self._parse_csv(
+ lambda: self._parse_constraint() or self._parse_column_def(self._parse_field(True))
+ )
self._match_r_paren()
return self.expression(exp.Schema, this=this, expressions=args)
@@ -2104,6 +2224,7 @@ class Parser:
if not self._match(TokenType.R_BRACKET):
self.raise_error("Expected ]")
+ this.comment = self._prev_comment
return self._parse_bracket(this)
def _parse_case(self):
@@ -2124,7 +2245,9 @@ class Parser:
if not self._match(TokenType.END):
self.raise_error("Expected END after CASE", self._prev)
- return self._parse_window(self.expression(exp.Case, this=expression, ifs=ifs, default=default))
+ return self._parse_window(
+ self.expression(exp.Case, this=expression, ifs=ifs, default=default)
+ )
def _parse_if(self):
if self._match(TokenType.L_PAREN):
@@ -2331,7 +2454,9 @@ class Parser:
self._match(TokenType.BETWEEN)
return {
- "value": (self._match_set((TokenType.UNBOUNDED, TokenType.CURRENT_ROW)) and self._prev.text)
+ "value": (
+ self._match_set((TokenType.UNBOUNDED, TokenType.CURRENT_ROW)) and self._prev.text
+ )
or self._parse_bitwise(),
"side": self._match_set((TokenType.PRECEDING, TokenType.FOLLOWING)) and self._prev.text,
}
@@ -2348,7 +2473,7 @@ class Parser:
this=this,
expressions=self._parse_csv(lambda: self._parse_id_var(any_token)),
)
- self._match_r_paren()
+ self._match_r_paren(aliases)
return aliases
alias = self._parse_id_var(any_token)
@@ -2365,28 +2490,29 @@ class Parser:
return identifier
if any_token and self._curr and self._curr.token_type not in self.RESERVED_KEYWORDS:
- return self._advance() or exp.Identifier(this=self._prev.text, quoted=False)
-
- return self._match_set(tokens or self.ID_VAR_TOKENS) and exp.Identifier(this=self._prev.text, quoted=False)
+ self._advance()
+ elif not self._match_set(tokens or self.ID_VAR_TOKENS):
+ return None
+ return exp.Identifier(this=self._prev.text, quoted=False)
def _parse_string(self):
if self._match(TokenType.STRING):
- return exp.Literal.string(self._prev.text)
+ return self.PRIMARY_PARSERS[TokenType.STRING](self, self._prev)
return self._parse_placeholder()
def _parse_number(self):
if self._match(TokenType.NUMBER):
- return exp.Literal.number(self._prev.text)
+ return self.PRIMARY_PARSERS[TokenType.NUMBER](self, self._prev)
return self._parse_placeholder()
def _parse_identifier(self):
if self._match(TokenType.IDENTIFIER):
- return exp.Identifier(this=self._prev.text, quoted=True)
+ return self.expression(exp.Identifier, this=self._prev.text, quoted=True)
return self._parse_placeholder()
def _parse_var(self):
if self._match(TokenType.VAR):
- return exp.Var(this=self._prev.text)
+ return self.expression(exp.Var, this=self._prev.text)
return self._parse_placeholder()
def _parse_var_or_string(self):
@@ -2394,27 +2520,27 @@ class Parser:
def _parse_null(self):
if self._match(TokenType.NULL):
- return exp.Null()
+ return self.PRIMARY_PARSERS[TokenType.NULL](self, self._prev)
return None
def _parse_boolean(self):
if self._match(TokenType.TRUE):
- return exp.Boolean(this=True)
+ return self.PRIMARY_PARSERS[TokenType.TRUE](self, self._prev)
if self._match(TokenType.FALSE):
- return exp.Boolean(this=False)
+ return self.PRIMARY_PARSERS[TokenType.FALSE](self, self._prev)
return None
def _parse_star(self):
if self._match(TokenType.STAR):
- return exp.Star(**{"except": self._parse_except(), "replace": self._parse_replace()})
+ return self.PRIMARY_PARSERS[TokenType.STAR](self, self._prev)
return None
def _parse_placeholder(self):
if self._match(TokenType.PLACEHOLDER):
- return exp.Placeholder()
+ return self.expression(exp.Placeholder)
elif self._match(TokenType.COLON):
self._advance()
- return exp.Placeholder(this=self._prev.text)
+ return self.expression(exp.Placeholder, this=self._prev.text)
return None
def _parse_except(self):
@@ -2432,22 +2558,27 @@ class Parser:
self._match_r_paren()
return columns
- def _parse_csv(self, parse):
- parse_result = parse()
+ def _parse_csv(self, parse_method):
+ parse_result = parse_method()
items = [parse_result] if parse_result is not None else []
while self._match(TokenType.COMMA):
- parse_result = parse()
+ if parse_result and self._prev_comment is not None:
+ parse_result.comment = self._prev_comment
+
+ parse_result = parse_method()
if parse_result is not None:
items.append(parse_result)
return items
- def _parse_tokens(self, parse, expressions):
- this = parse()
+ def _parse_tokens(self, parse_method, expressions):
+ this = parse_method()
while self._match_set(expressions):
- this = self.expression(expressions[self._prev.token_type], this=this, expression=parse())
+ this = self.expression(
+ expressions[self._prev.token_type], this=this, expression=parse_method()
+ )
return this
@@ -2460,6 +2591,47 @@ class Parser:
def _parse_select_or_expression(self):
return self._parse_select() or self._parse_expression()
+ def _parse_use(self):
+ return self.expression(exp.Use, this=self._parse_id_var())
+
+ def _parse_show(self):
+ parser = self._find_parser(self.SHOW_PARSERS, self._show_trie)
+ if parser:
+ return parser(self)
+ self._advance()
+ return self.expression(exp.Show, this=self._prev.text.upper())
+
+ def _default_parse_set_item(self):
+ return self.expression(
+ exp.SetItem,
+ this=self._parse_statement(),
+ )
+
+ def _parse_set_item(self):
+ parser = self._find_parser(self.SET_PARSERS, self._set_trie)
+ return parser(self) if parser else self._default_parse_set_item()
+
+ def _parse_set(self):
+ return self.expression(exp.Set, expressions=self._parse_csv(self._parse_set_item))
+
+ def _find_parser(self, parsers, trie):
+ index = self._index
+ this = []
+ while True:
+ # The current token might span multiple words
+ curr = self._curr.text.upper()
+ key = curr.split(" ")
+ this.append(curr)
+ self._advance()
+ result, trie = in_trie(trie, key)
+ if result == 0:
+ break
+ if result == 2:
+ subparser = parsers[" ".join(this)]
+ return subparser
+ self._retreat(index)
+ return None
+
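
`_find_parser` is the consumer of those tries: it keeps advancing while the walk reports a valid prefix, dispatches on an exact key, and retreats on a dead end. A model of the lookup consistent with the trie sketch shown earlier (return codes assumed: 0 dead end, 1 prefix, 2 match):

def in_trie(trie, key):
    node = trie
    for word in key:
        if word not in node:
            return 0, trie  # dead end: caller retreats
        node = node[word]
    return (2 if 0 in node else 1), node  # 2 = registered key, 1 = keep walking
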
def _match(self, token_type):
if not self._curr:
return None
@@ -2491,13 +2663,17 @@ class Parser:
return None
- def _match_l_paren(self):
+ def _match_l_paren(self, expression=None):
if not self._match(TokenType.L_PAREN):
self.raise_error("Expecting (")
+ if expression and self._prev_comment:
+ expression.comment = self._prev_comment
- def _match_r_paren(self):
+ def _match_r_paren(self, expression=None):
if not self._match(TokenType.R_PAREN):
self.raise_error("Expecting )")
+ if expression and self._prev_comment:
+ expression.comment = self._prev_comment
def _match_text(self, *texts):
index = self._index
diff --git a/sqlglot/planner.py b/sqlglot/planner.py
index ea995d8..cd1de5e 100644
--- a/sqlglot/planner.py
+++ b/sqlglot/planner.py
@@ -72,7 +72,9 @@ class Step:
if from_:
from_ = from_.expressions
if len(from_) > 1:
- raise UnsupportedError("Multi-from statements are unsupported. Run it through the optimizer")
+ raise UnsupportedError(
+ "Multi-from statements are unsupported. Run it through the optimizer"
+ )
step = Scan.from_expression(from_[0], ctes)
else:
@@ -102,7 +104,7 @@ class Step:
continue
if operand not in operands:
operands[operand] = f"_a_{next(sequence)}"
- operand.replace(exp.column(operands[operand], step.name, quoted=True))
+ operand.replace(exp.column(operands[operand], quoted=True))
else:
projections.append(e)
@@ -117,9 +119,11 @@ class Step:
aggregate = Aggregate()
aggregate.source = step.name
aggregate.name = step.name
- aggregate.operands = tuple(alias(operand, alias_) for operand, alias_ in operands.items())
+ aggregate.operands = tuple(
+ alias(operand, alias_) for operand, alias_ in operands.items()
+ )
aggregate.aggregations = aggregations
- aggregate.group = [exp.column(e.alias_or_name, step.name, quoted=True) for e in group.expressions]
+ aggregate.group = group.expressions
aggregate.add_dependency(step)
step = aggregate
@@ -136,9 +140,6 @@ class Step:
sort.key = order.expressions
sort.add_dependency(step)
step = sort
- for k in sort.key + projections:
- for column in k.find_all(exp.Column):
- column.set("table", exp.to_identifier(step.name, quoted=True))
step.projections = projections
@@ -203,7 +204,9 @@ class Scan(Step):
alias_ = expression.alias
if not alias_:
- raise UnsupportedError("Tables/Subqueries must be aliased. Run it through the optimizer")
+ raise UnsupportedError(
+ "Tables/Subqueries must be aliased. Run it through the optimizer"
+ )
if isinstance(expression, exp.Subquery):
table = expression.this
diff --git a/sqlglot/py.typed b/sqlglot/py.typed
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/sqlglot/py.typed
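
The empty `py.typed` marker makes the package PEP 561 compliant, so type checkers will consume the inline annotations added across this release.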
diff --git a/sqlglot/schema.py b/sqlglot/schema.py
index c916330..fcf7291 100644
--- a/sqlglot/schema.py
+++ b/sqlglot/schema.py
@@ -1,44 +1,60 @@
+from __future__ import annotations
+
import abc
+import typing as t
from sqlglot import expressions as exp
-from sqlglot.errors import OptimizeError
+from sqlglot.errors import SchemaError
from sqlglot.helper import csv_reader
+from sqlglot.trie import in_trie, new_trie
+
+if t.TYPE_CHECKING:
+ from sqlglot.dataframe.sql.types import StructType
+
+ ColumnMapping = t.Union[t.Dict, str, StructType, t.List]
+
+TABLE_ARGS = ("this", "db", "catalog")
class Schema(abc.ABC):
"""Abstract base class for database schemas"""
@abc.abstractmethod
- def add_table(self, table, column_mapping=None):
+ def add_table(
+ self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
+ ) -> None:
"""
- Register or update a table. Some implementing classes may require column information to also be provided
+ Register or update a table. Some implementing classes may require column information to also be provided.
Args:
- table (sqlglot.expressions.Table|str): Table expression instance or string representing the table
- column_mapping (dict|str|sqlglot.dataframe.sql.types.StructType|list): A column mapping that describes the structure of the table
+ table: the `Table` expression instance or string representing the table.
+ column_mapping: a column mapping that describes the structure of the table.
"""
@abc.abstractmethod
- def column_names(self, table, only_visible=False):
+ def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
"""
Get the column names for a table.
+
Args:
- table (sqlglot.expressions.Table): Table expression instance
- only_visible (bool): Whether to include invisible columns
+ table: the `Table` expression instance.
+ only_visible: whether to include invisible columns.
+
Returns:
- list[str]: list of column names
+ The list of column names.
"""
@abc.abstractmethod
- def get_column_type(self, table, column):
+ def get_column_type(self, table: exp.Table | str, column: exp.Column) -> exp.DataType.Type:
"""
- Get the exp.DataType type of a column in the schema.
+ Get the :class:`sqlglot.exp.DataType` type of a column in the schema.
Args:
- table (sqlglot.expressions.Table): The source table.
- column (sqlglot.expressions.Column): The target column.
+ table: the source table.
+ column: the target column.
+
Returns:
- sqlglot.expressions.DataType.Type: The resulting column type.
+ The resulting column type.
"""
@@ -60,132 +76,179 @@ class MappingSchema(Schema):
dialect (str): The dialect to be used for custom type mappings.
"""
- def __init__(self, schema=None, visible=None, dialect=None):
+ def __init__(
+ self,
+ schema: t.Optional[t.Dict] = None,
+ visible: t.Optional[t.Dict] = None,
+ dialect: t.Optional[str] = None,
+ ) -> None:
self.schema = schema or {}
- self.visible = visible
+ self.visible = visible or {}
+ self.schema_trie = self._build_trie(self.schema)
self.dialect = dialect
- self._type_mapping_cache = {}
- self.supported_table_args = []
- self.forbidden_table_args = set()
- if self.schema:
- self._initialize_supported_args()
+ self._type_mapping_cache: t.Dict[str, exp.DataType.Type] = {}
+ self._supported_table_args: t.Tuple[str, ...] = tuple()
@classmethod
- def from_mapping_schema(cls, mapping_schema):
+ def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
+ return MappingSchema(
+ schema=mapping_schema.schema,
+ visible=mapping_schema.visible,
+ dialect=mapping_schema.dialect,
+ )
+
+ def copy(self, **kwargs) -> MappingSchema:
return MappingSchema(
- schema=mapping_schema.schema, visible=mapping_schema.visible, dialect=mapping_schema.dialect
+ **{ # type: ignore
+ "schema": self.schema.copy(),
+ "visible": self.visible.copy(),
+ "dialect": self.dialect,
+ **kwargs,
+ }
)
- def copy(self, **kwargs):
- return MappingSchema(**{"schema": self.schema.copy(), **kwargs})
+ @property
+ def supported_table_args(self):
+ if not self._supported_table_args and self.schema:
+ depth = _dict_depth(self.schema)
- def add_table(self, table, column_mapping=None):
+ if not depth or depth == 1: # {}
+ self._supported_table_args = tuple()
+ elif 2 <= depth <= 4:
+ self._supported_table_args = TABLE_ARGS[: depth - 1]
+ else:
+ raise SchemaError(f"Invalid schema shape. Depth: {depth}")
+
+ return self._supported_table_args
+
+ def add_table(
+ self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
+ ) -> None:
"""
Register or update a table. Updates are only performed if a new column mapping is provided.
Args:
- table (sqlglot.expressions.Table|str): Table expression instance or string representing the table
- column_mapping (dict|str|sqlglot.dataframe.sql.types.StructType|list): A column mapping that describes the structure of the table
+ table: the `Table` expression instance or string representing the table.
+ column_mapping: a column mapping that describes the structure of the table.
"""
- table = exp.to_table(table)
- self._validate_table(table)
+ table_ = self._ensure_table(table)
column_mapping = ensure_column_mapping(column_mapping)
- table_args = [table.text(p) for p in self.supported_table_args or self._get_table_args_from_table(table)]
- existing_column_mapping = _nested_get(
- self.schema, *zip(self.supported_table_args, table_args), raise_on_missing=False
- )
- if existing_column_mapping and not column_mapping:
+ schema = self.find_schema(table_, raise_on_missing=False)
+
+ if schema and not column_mapping:
return
+
_nested_set(
self.schema,
- [table.text(p) for p in self.supported_table_args or self._get_table_args_from_table(table)],
+ list(reversed(self.table_parts(table_))),
column_mapping,
)
- self._initialize_supported_args()
+ self.schema_trie = self._build_trie(self.schema)
- def _get_table_args_from_table(self, table):
- if table.args.get("catalog") is not None:
- return "catalog", "db", "this"
- if table.args.get("db") is not None:
- return "db", "this"
- return ("this",)
+ def _ensure_table(self, table: exp.Table | str) -> exp.Table:
+ table_ = exp.to_table(table)
- def _validate_table(self, table):
- if not self.supported_table_args and isinstance(table, exp.Table):
- return
- for forbidden in self.forbidden_table_args:
- if table.text(forbidden):
- raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}")
- for expected in self.supported_table_args:
- if not table.text(expected):
- raise ValueError(f"Table is expected to have {expected}. Received: {table.sql()} ")
+ if not table_:
+ raise SchemaError(f"Not a valid table '{table}'")
+
+ return table_
+
+ def table_parts(self, table: exp.Table) -> t.List[str]:
+ return [table.text(part) for part in TABLE_ARGS if table.text(part)]
+
+ def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
+ table_ = self._ensure_table(table)
- def column_names(self, table, only_visible=False):
- table = exp.to_table(table)
- if not isinstance(table.this, exp.Identifier):
- return fs_get(table)
+ if not isinstance(table_.this, exp.Identifier):
+ return fs_get(table) # type: ignore
- args = tuple(table.text(p) for p in self.supported_table_args)
+ schema = self.find_schema(table_)
- for forbidden in self.forbidden_table_args:
- if table.text(forbidden):
- raise ValueError(f"Schema doesn't support {forbidden}. Received: {table.sql()}")
+ if schema is None:
+ raise SchemaError(f"Could not find table schema {table}")
- columns = list(_nested_get(self.schema, *zip(self.supported_table_args, args)))
if not only_visible or not self.visible:
- return columns
+ return list(schema)
- visible = _nested_get(self.visible, *zip(self.supported_table_args, args))
- return [col for col in columns if col in visible]
+ visible = self._nested_get(self.table_parts(table_), self.visible)
+ return [col for col in schema if col in visible] # type: ignore
- def get_column_type(self, table, column):
- try:
- schema_type = self.schema.get(table.name, {}).get(column.name).upper()
+ def find_schema(
+ self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
+ ) -> t.Optional[t.Dict[str, str]]:
+ parts = self.table_parts(table)[0 : len(self.supported_table_args)]
+ value, trie = in_trie(self.schema_trie if trie is None else trie, parts)
+
+ if value == 0:
+ if raise_on_missing:
+ raise SchemaError(f"Cannot find schema for {table}.")
+ else:
+ return None
+ elif value == 1:
+ possibilities = flatten_schema(trie)
+ if len(possibilities) == 1:
+ parts.extend(possibilities[0])
+ else:
+ message = ", ".join(".".join(parts) for parts in possibilities)
+ if raise_on_missing:
+ raise SchemaError(f"Ambiguous schema for {table}: {message}.")
+ return None
+
+ return self._nested_get(parts, raise_on_missing=raise_on_missing)
+
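To illustrate the trie lookup above: an unqualified table matching several prefixes is ambiguous, while a qualified one resolves. A hedged sketch with hypothetical names:

    >>> from sqlglot import exp
    >>> schema = MappingSchema({"db1": {"x": {"a": "INT"}}, "db2": {"x": {"a": "TEXT"}}})
    >>> schema.find_schema(exp.to_table("x"), raise_on_missing=False) is None
    True
    >>> schema.find_schema(exp.to_table("db1.x"))
    {'a': 'INT'}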
+ def get_column_type(
+ self, table: exp.Table | str, column: exp.Column | str
+ ) -> exp.DataType.Type:
+ column_name = column if isinstance(column, str) else column.name
+ table_ = exp.to_table(table)
+ if table_:
+ table_schema = self.find_schema(table_)
+ schema_type = table_schema.get(column_name).upper() # type: ignore
return self._convert_type(schema_type)
- except:
- raise OptimizeError(f"Failed to get type for column {column.sql()}")
+ raise SchemaError(f"Could not convert table '{table}'")
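In the same doctest style, `get_column_type` resolves the table schema and parses the stored type string; the exact enum repr below is an assumption:

    >>> schema = MappingSchema({"x": {"a": "INT"}})
    >>> schema.get_column_type("x", "a")
    <Type.INT: 'INT'>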
- def _convert_type(self, schema_type):
+ def _convert_type(self, schema_type: str) -> exp.DataType.Type:
"""
- Convert a type represented as a string to the corresponding exp.DataType.Type object.
+        Convert a type represented as a string to the corresponding :class:`sqlglot.exp.DataType.Type` object.
+
Args:
- schema_type (str): The type we want to convert.
+ schema_type: the type we want to convert.
+
Returns:
- sqlglot.expressions.DataType.Type: The resulting expression type.
+ The resulting expression type.
"""
if schema_type not in self._type_mapping_cache:
try:
- self._type_mapping_cache[schema_type] = exp.maybe_parse(
- schema_type, into=exp.DataType, dialect=self.dialect
- ).this
+ expression = exp.maybe_parse(schema_type, into=exp.DataType, dialect=self.dialect)
+ if expression is None:
+ raise ValueError(f"Could not parse {schema_type}")
+ self._type_mapping_cache[schema_type] = expression.this
except AttributeError:
- raise OptimizeError(f"Failed to convert type {schema_type}")
+ raise SchemaError(f"Failed to convert type {schema_type}")
return self._type_mapping_cache[schema_type]
- def _initialize_supported_args(self):
- if not self.supported_table_args:
- depth = _dict_depth(self.schema)
-
- all_args = ["this", "db", "catalog"]
- if not depth or depth == 1: # {}
- self.supported_table_args = []
- elif 2 <= depth <= 4:
- self.supported_table_args = tuple(reversed(all_args[: depth - 1]))
- else:
- raise OptimizeError(f"Invalid schema shape. Depth: {depth}")
+ def _build_trie(self, schema: t.Dict):
+ return new_trie(tuple(reversed(t)) for t in flatten_schema(schema))
- self.forbidden_table_args = {"catalog", "db", "this"} - set(self.supported_table_args)
+ def _nested_get(
+ self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
+ ) -> t.Optional[t.Any]:
+ return _nested_get(
+ d or self.schema,
+ *zip(self.supported_table_args, reversed(parts)),
+ raise_on_missing=raise_on_missing,
+ )
-def ensure_schema(schema):
+def ensure_schema(schema: t.Any) -> Schema:
if isinstance(schema, Schema):
return schema
return MappingSchema(schema)
-def ensure_column_mapping(mapping):
+def ensure_column_mapping(mapping: t.Optional[ColumnMapping]):
if isinstance(mapping, dict):
return mapping
elif isinstance(mapping, str):
@@ -196,7 +259,7 @@ def ensure_column_mapping(mapping):
}
# Check if mapping looks like a DataFrame StructType
elif hasattr(mapping, "simpleString"):
- return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping}
+ return {struct_field.name: struct_field.dataType.simpleString() for struct_field in mapping} # type: ignore
elif isinstance(mapping, list):
return {x.strip(): None for x in mapping}
elif mapping is None:
@@ -204,7 +267,20 @@ def ensure_column_mapping(mapping):
raise ValueError(f"Invalid mapping provided: {type(mapping)}")
-def fs_get(table):
+def flatten_schema(schema: t.Dict, keys: t.Optional[t.List[str]] = None) -> t.List[t.List[str]]:
+ tables = []
+ keys = keys or []
+ depth = _dict_depth(schema)
+
+ for k, v in schema.items():
+ if depth >= 3:
+ tables.extend(flatten_schema(v, keys + [k]))
+ elif depth == 2:
+ tables.append(keys + [k])
+ return tables
+
+
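As a quick sketch, `flatten_schema` returns the key paths down to each table name (values hypothetical):

    >>> flatten_schema({"db": {"x": {"a": "INT"}}})
    [['db', 'x']]
    >>> flatten_schema({"c": {"db": {"x": {"a": "INT"}}}})
    [['c', 'db', 'x']]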
+def fs_get(table: exp.Table) -> t.List[str]:
name = table.this.name
if name.upper() == "READ_CSV":
@@ -214,21 +290,23 @@ def fs_get(table):
raise ValueError(f"Cannot read schema for {table}")
-def _nested_get(d, *path, raise_on_missing=True):
+def _nested_get(
+ d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
+) -> t.Optional[t.Any]:
"""
Get a value for a nested dictionary.
Args:
- d (dict): dictionary
- *path (tuple[str, str]): tuples of (name, key)
+ d: the dictionary to search.
+ *path: tuples of (name, key), where:
`key` is the key in the dictionary to get.
`name` is a string to use in the error if `key` isn't found.
Returns:
- The value or None if it doesn't exist
+ The value or None if it doesn't exist.
"""
for name, key in path:
- d = d.get(key)
+ d = d.get(key) # type: ignore
if d is None:
if raise_on_missing:
name = "table" if name == "this" else name
@@ -237,36 +315,44 @@ def _nested_get(d, *path, raise_on_missing=True):
return d
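For example, given the (name, key) pairs described above:

    >>> _nested_get({"db": {"x": {"a": "INT"}}}, ("db", "db"), ("table", "x"))
    {'a': 'INT'}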
-def _nested_set(d, keys, value):
+def _nested_set(d: t.Dict, keys: t.List[str], value: t.Any) -> t.Dict:
"""
In-place set a value for a nested dictionary
- Ex:
+ Example:
>>> _nested_set({}, ["top_key", "second_key"], "value")
{'top_key': {'second_key': 'value'}}
+
>>> _nested_set({"top_key": {"third_key": "third_value"}}, ["top_key", "second_key"], "value")
{'top_key': {'third_key': 'third_value', 'second_key': 'value'}}
- d (dict): dictionary
- keys (Iterable[str]): ordered iterable of keys that makeup path to value
- value (Any): The value to set in the dictionary for the given key path
+ Args:
+ d: dictionary to update.
+        keys: the keys that make up the path to `value`.
+ value: the value to set in the dictionary for the given key path.
+
+ Returns:
+ The (possibly) updated dictionary.
"""
if not keys:
- return
+ return d
+
if len(keys) == 1:
d[keys[0]] = value
- return
+ return d
+
subd = d
for key in keys[:-1]:
if key not in subd:
subd = subd.setdefault(key, {})
else:
subd = subd[key]
+
subd[keys[-1]] = value
return d
-def _dict_depth(d):
+def _dict_depth(d: t.Dict) -> int:
"""
Get the nesting depth of a dictionary.
diff --git a/sqlglot/time.py b/sqlglot/time.py
index 729b50d..97726b3 100644
--- a/sqlglot/time.py
+++ b/sqlglot/time.py
@@ -1,9 +1,13 @@
-# the generic time format is based on python time.strftime
+import typing as t
+
+# The generic time format is based on python time.strftime.
# https://docs.python.org/3/library/time.html#time.strftime
from sqlglot.trie import in_trie, new_trie
-def format_time(string, mapping, trie=None):
+def format_time(
+ string: str, mapping: t.Dict[str, str], trie: t.Optional[t.Dict] = None
+) -> t.Optional[str]:
"""
Converts a time string given a mapping.
@@ -11,11 +15,16 @@ def format_time(string, mapping, trie=None):
>>> format_time("%Y", {"%Y": "YYYY"})
'YYYY'
- mapping: Dictionary of time format to target time format
- trie: Optional trie, can be passed in for performance
+ Args:
+ mapping: dictionary of time format to target time format.
+ trie: optional trie, can be passed in for performance.
+
+ Returns:
+ The converted time string.
"""
if not string:
return None
+
start = 0
end = 1
size = len(string)
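A sketch of the converted signature in use, assuming unmapped characters pass through unchanged; prebuilding the trie is optional and pays off when converting many strings with the same mapping:

    >>> from sqlglot.time import format_time
    >>> from sqlglot.trie import new_trie
    >>> mapping = {"%Y": "yyyy", "%m": "MM", "%d": "dd"}
    >>> format_time("%Y-%m-%d", mapping, new_trie(mapping))
    'yyyy-MM-dd'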
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index 766c01a..95d84d6 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -1,3 +1,6 @@
+from __future__ import annotations
+
+import typing as t
from enum import auto
from sqlglot.helper import AutoName
@@ -27,6 +30,7 @@ class TokenType(AutoName):
NOT = auto()
EQ = auto()
NEQ = auto()
+ NULLSAFE_EQ = auto()
AND = auto()
OR = auto()
AMP = auto()
@@ -36,12 +40,14 @@ class TokenType(AutoName):
TILDA = auto()
ARROW = auto()
DARROW = auto()
+ FARROW = auto()
+ HASH = auto()
HASH_ARROW = auto()
DHASH_ARROW = auto()
LR_ARROW = auto()
- ANNOTATION = auto()
DOLLAR = auto()
PARAMETER = auto()
+ SESSION_PARAMETER = auto()
SPACE = auto()
BREAK = auto()
@@ -73,7 +79,7 @@ class TokenType(AutoName):
NVARCHAR = auto()
TEXT = auto()
BINARY = auto()
- BYTEA = auto()
+ VARBINARY = auto()
JSON = auto()
TIMESTAMP = auto()
TIMESTAMPTZ = auto()
@@ -142,6 +148,7 @@ class TokenType(AutoName):
DESCRIBE = auto()
DETERMINISTIC = auto()
DISTINCT = auto()
+ DISTINCT_FROM = auto()
DISTRIBUTE_BY = auto()
DIV = auto()
DROP = auto()
@@ -238,6 +245,7 @@ class TokenType(AutoName):
RETURNS = auto()
RIGHT = auto()
RLIKE = auto()
+ ROLLBACK = auto()
ROLLUP = auto()
ROW = auto()
ROWS = auto()
@@ -287,37 +295,49 @@ class TokenType(AutoName):
class Token:
- __slots__ = ("token_type", "text", "line", "col")
+ __slots__ = ("token_type", "text", "line", "col", "comment")
@classmethod
- def number(cls, number):
+ def number(cls, number: int) -> Token:
+ """Returns a NUMBER token with `number` as its text."""
return cls(TokenType.NUMBER, str(number))
@classmethod
- def string(cls, string):
+ def string(cls, string: str) -> Token:
+ """Returns a STRING token with `string` as its text."""
return cls(TokenType.STRING, string)
@classmethod
- def identifier(cls, identifier):
+ def identifier(cls, identifier: str) -> Token:
+ """Returns an IDENTIFIER token with `identifier` as its text."""
return cls(TokenType.IDENTIFIER, identifier)
@classmethod
- def var(cls, var):
+ def var(cls, var: str) -> Token:
+ """Returns an VAR token with `var` as its text."""
return cls(TokenType.VAR, var)
- def __init__(self, token_type, text, line=1, col=1):
+ def __init__(
+ self,
+ token_type: TokenType,
+ text: str,
+ line: int = 1,
+ col: int = 1,
+ comment: t.Optional[str] = None,
+ ) -> None:
self.token_type = token_type
self.text = text
self.line = line
self.col = max(col - len(text), 1)
+ self.comment = comment
- def __repr__(self):
+ def __repr__(self) -> str:
attributes = ", ".join(f"{k}: {getattr(self, k)}" for k in self.__slots__)
return f"<Token {attributes}>"
class _Tokenizer(type):
- def __new__(cls, clsname, bases, attrs):
+ def __new__(cls, clsname, bases, attrs): # type: ignore
klass = super().__new__(cls, clsname, bases, attrs)
klass._QUOTES = cls._delimeter_list_to_dict(klass.QUOTES)
@@ -325,27 +345,29 @@ class _Tokenizer(type):
klass._HEX_STRINGS = cls._delimeter_list_to_dict(klass.HEX_STRINGS)
klass._BYTE_STRINGS = cls._delimeter_list_to_dict(klass.BYTE_STRINGS)
klass._IDENTIFIERS = cls._delimeter_list_to_dict(klass.IDENTIFIERS)
+ klass._ESCAPES = set(klass.ESCAPES)
klass._COMMENTS = dict(
- (comment, None) if isinstance(comment, str) else (comment[0], comment[1]) for comment in klass.COMMENTS
+ (comment, None) if isinstance(comment, str) else (comment[0], comment[1])
+ for comment in klass.COMMENTS
)
klass.KEYWORD_TRIE = new_trie(
key.upper()
- for key, value in {
+ for key in {
**klass.KEYWORDS,
**{comment: TokenType.COMMENT for comment in klass._COMMENTS},
**{quote: TokenType.QUOTE for quote in klass._QUOTES},
**{bit_string: TokenType.BIT_STRING for bit_string in klass._BIT_STRINGS},
**{hex_string: TokenType.HEX_STRING for hex_string in klass._HEX_STRINGS},
**{byte_string: TokenType.BYTE_STRING for byte_string in klass._BYTE_STRINGS},
- }.items()
+ }
if " " in key or any(single in key for single in klass.SINGLE_TOKENS)
)
return klass
@staticmethod
- def _delimeter_list_to_dict(list):
+ def _delimeter_list_to_dict(list: t.List[str | t.Tuple[str, str]]) -> t.Dict[str, str]:
return dict((item, item) if isinstance(item, str) else (item[0], item[1]) for item in list)
@@ -375,26 +397,26 @@ class Tokenizer(metaclass=_Tokenizer):
"*": TokenType.STAR,
"~": TokenType.TILDA,
"?": TokenType.PLACEHOLDER,
- "#": TokenType.ANNOTATION,
"@": TokenType.PARAMETER,
# used for breaking a var like x'y' but nothing else
# the token type doesn't matter
"'": TokenType.QUOTE,
"`": TokenType.IDENTIFIER,
'"': TokenType.IDENTIFIER,
+ "#": TokenType.HASH,
}
- QUOTES = ["'"]
+ QUOTES: t.List[t.Tuple[str, str] | str] = ["'"]
- BIT_STRINGS = []
+ BIT_STRINGS: t.List[str | t.Tuple[str, str]] = []
- HEX_STRINGS = []
+ HEX_STRINGS: t.List[str | t.Tuple[str, str]] = []
- BYTE_STRINGS = []
+ BYTE_STRINGS: t.List[str | t.Tuple[str, str]] = []
- IDENTIFIERS = ['"']
+ IDENTIFIERS: t.List[str | t.Tuple[str, str]] = ['"']
- ESCAPE = "'"
+ ESCAPES = ["'"]
KEYWORDS = {
"/*+": TokenType.HINT,
@@ -406,8 +428,10 @@ class Tokenizer(metaclass=_Tokenizer):
"<=": TokenType.LTE,
"<>": TokenType.NEQ,
"!=": TokenType.NEQ,
+ "<=>": TokenType.NULLSAFE_EQ,
"->": TokenType.ARROW,
"->>": TokenType.DARROW,
+ "=>": TokenType.FARROW,
"#>": TokenType.HASH_ARROW,
"#>>": TokenType.DHASH_ARROW,
"<->": TokenType.LR_ARROW,
@@ -454,6 +478,7 @@ class Tokenizer(metaclass=_Tokenizer):
"DESCRIBE": TokenType.DESCRIBE,
"DETERMINISTIC": TokenType.DETERMINISTIC,
"DISTINCT": TokenType.DISTINCT,
+ "DISTINCT FROM": TokenType.DISTINCT_FROM,
"DISTRIBUTE BY": TokenType.DISTRIBUTE_BY,
"DIV": TokenType.DIV,
"DROP": TokenType.DROP,
@@ -543,6 +568,7 @@ class Tokenizer(metaclass=_Tokenizer):
"RETURNS": TokenType.RETURNS,
"RIGHT": TokenType.RIGHT,
"RLIKE": TokenType.RLIKE,
+ "ROLLBACK": TokenType.ROLLBACK,
"ROLLUP": TokenType.ROLLUP,
"ROW": TokenType.ROW,
"ROWS": TokenType.ROWS,
@@ -622,8 +648,9 @@ class Tokenizer(metaclass=_Tokenizer):
"TEXT": TokenType.TEXT,
"CLOB": TokenType.TEXT,
"BINARY": TokenType.BINARY,
- "BLOB": TokenType.BINARY,
- "BYTEA": TokenType.BINARY,
+ "BLOB": TokenType.VARBINARY,
+ "BYTEA": TokenType.VARBINARY,
+ "VARBINARY": TokenType.VARBINARY,
"TIMESTAMP": TokenType.TIMESTAMP,
"TIMESTAMPTZ": TokenType.TIMESTAMPTZ,
"TIMESTAMPLTZ": TokenType.TIMESTAMPLTZ,
@@ -655,13 +682,13 @@ class Tokenizer(metaclass=_Tokenizer):
TokenType.SET,
TokenType.SHOW,
TokenType.TRUNCATE,
- TokenType.USE,
TokenType.VACUUM,
+ TokenType.ROLLBACK,
}
# handle numeric literals like in hive (3L = BIGINT)
- NUMERIC_LITERALS = {}
- ENCODE = None
+ NUMERIC_LITERALS: t.Dict[str, str] = {}
+ ENCODE: t.Optional[str] = None
COMMENTS = ["--", ("/*", "*/")]
KEYWORD_TRIE = None # autofilled
@@ -674,33 +701,39 @@ class Tokenizer(metaclass=_Tokenizer):
"_current",
"_line",
"_col",
+ "_comment",
"_char",
"_end",
"_peek",
+ "_prev_token_line",
+ "_prev_token_comment",
"_prev_token_type",
+ "_replace_backslash",
)
- def __init__(self):
- """
- Tokenizer consumes a sql string and produces an array of :class:`~sqlglot.tokens.Token`
- """
+ def __init__(self) -> None:
+ self._replace_backslash = "\\" in self._ESCAPES # type: ignore
self.reset()
- def reset(self):
+ def reset(self) -> None:
self.sql = ""
self.size = 0
- self.tokens = []
+ self.tokens: t.List[Token] = []
self._start = 0
self._current = 0
self._line = 1
self._col = 1
+ self._comment = None
self._char = None
self._end = None
self._peek = None
+ self._prev_token_line = -1
+ self._prev_token_comment = None
self._prev_token_type = None
- def tokenize(self, sql):
+ def tokenize(self, sql: str) -> t.List[Token]:
+ """Returns a list of tokens corresponding to the SQL string `sql`."""
self.reset()
self.sql = sql
self.size = len(sql)
@@ -712,14 +745,14 @@ class Tokenizer(metaclass=_Tokenizer):
if not self._char:
break
- white_space = self.WHITE_SPACE.get(self._char)
- identifier_end = self._IDENTIFIERS.get(self._char)
+ white_space = self.WHITE_SPACE.get(self._char) # type: ignore
+ identifier_end = self._IDENTIFIERS.get(self._char) # type: ignore
if white_space:
if white_space == TokenType.BREAK:
self._col = 1
self._line += 1
- elif self._char.isdigit():
+ elif self._char.isdigit(): # type:ignore
self._scan_number()
elif identifier_end:
self._scan_identifier(identifier_end)
@@ -727,38 +760,51 @@ class Tokenizer(metaclass=_Tokenizer):
self._scan_keywords()
return self.tokens
- def _chars(self, size):
+ def _chars(self, size: int) -> str:
if size == 1:
- return self._char
+ return self._char # type: ignore
start = self._current - 1
end = start + size
if end <= self.size:
return self.sql[start:end]
return ""
- def _advance(self, i=1):
+ def _advance(self, i: int = 1) -> None:
self._col += i
self._current += i
- self._end = self._current >= self.size
- self._char = self.sql[self._current - 1]
- self._peek = self.sql[self._current] if self._current < self.size else ""
+ self._end = self._current >= self.size # type: ignore
+ self._char = self.sql[self._current - 1] # type: ignore
+ self._peek = self.sql[self._current] if self._current < self.size else "" # type: ignore
@property
- def _text(self):
+ def _text(self) -> str:
return self.sql[self._start : self._current]
- def _add(self, token_type, text=None):
- self._prev_token_type = token_type
- self.tokens.append(Token(token_type, self._text if text is None else text, self._line, self._col))
+ def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None:
+ self._prev_token_line = self._line
+ self._prev_token_comment = self._comment
+ self._prev_token_type = token_type # type: ignore
+ self.tokens.append(
+ Token(
+ token_type,
+ self._text if text is None else text,
+ self._line,
+ self._col,
+ self._comment,
+ )
+ )
+ self._comment = None
- if token_type in self.COMMANDS and (len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON):
+ if token_type in self.COMMANDS and (
+ len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON
+ ):
self._start = self._current
while not self._end and self._peek != ";":
self._advance()
if self._start < self._current:
self._add(TokenType.STRING)
- def _scan_keywords(self):
+ def _scan_keywords(self) -> None:
size = 0
word = None
chars = self._text
@@ -771,7 +817,7 @@ class Tokenizer(metaclass=_Tokenizer):
if skip:
result = 1
else:
- result, trie = in_trie(trie, char.upper())
+ result, trie = in_trie(trie, char.upper()) # type: ignore
if result == 0:
break
@@ -793,15 +839,11 @@ class Tokenizer(metaclass=_Tokenizer):
else:
skip = True
else:
- chars = None
+ chars = None # type: ignore
if not word:
if self._char in self.SINGLE_TOKENS:
- token = self.SINGLE_TOKENS[self._char]
- if token == TokenType.ANNOTATION:
- self._scan_annotation()
- return
- self._add(token)
+ self._add(self.SINGLE_TOKENS[self._char]) # type: ignore
return
self._scan_var()
return
@@ -816,31 +858,41 @@ class Tokenizer(metaclass=_Tokenizer):
self._advance(size - 1)
self._add(self.KEYWORDS[word.upper()])
- def _scan_comment(self, comment_start):
- if comment_start not in self._COMMENTS:
+ def _scan_comment(self, comment_start: str) -> bool:
+ if comment_start not in self._COMMENTS: # type: ignore
return False
- comment_end = self._COMMENTS[comment_start]
+ comment_start_line = self._line
+ comment_start_size = len(comment_start)
+ comment_end = self._COMMENTS[comment_start] # type: ignore
if comment_end:
comment_end_size = len(comment_end)
while not self._end and self._chars(comment_end_size) != comment_end:
self._advance()
+
+ self._comment = self._text[comment_start_size : -comment_end_size + 1] # type: ignore
self._advance(comment_end_size - 1)
else:
- while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK:
+ while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK: # type: ignore
self._advance()
- return True
+ self._comment = self._text[comment_start_size:] # type: ignore
- def _scan_annotation(self):
- while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK and self._peek != ",":
- self._advance()
- self._add(TokenType.ANNOTATION, self._text[1:])
+    # A leading comment is attached to the succeeding token, whilst a trailing comment is attached to the
+    # preceding one. If both types of comment can be attached to a token, the trailing one is discarded in favour of the leading one.
- def _scan_number(self):
+ if comment_start_line == self._prev_token_line:
+ if self._prev_token_comment is None:
+ self.tokens[-1].comment = self._comment
+
+ self._comment = None
+
+ return True
+
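A sketch of the attachment rule in practice, using the default tokenizer; the exact whitespace captured in the comment text is an assumption:

    >>> from sqlglot.tokens import Tokenizer
    >>> tokens = Tokenizer().tokenize("SELECT 1 -- the answer")
    >>> tokens[-1].comment
    ' the answer'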
+ def _scan_number(self) -> None:
if self._char == "0":
- peek = self._peek.upper()
+ peek = self._peek.upper() # type: ignore
if peek == "B":
return self._scan_bits()
elif peek == "X":
@@ -850,7 +902,7 @@ class Tokenizer(metaclass=_Tokenizer):
scientific = 0
while True:
- if self._peek.isdigit():
+ if self._peek.isdigit(): # type: ignore
self._advance()
elif self._peek == "." and not decimal:
decimal = True
@@ -858,25 +910,25 @@ class Tokenizer(metaclass=_Tokenizer):
elif self._peek in ("-", "+") and scientific == 1:
scientific += 1
self._advance()
- elif self._peek.upper() == "E" and not scientific:
+ elif self._peek.upper() == "E" and not scientific: # type: ignore
scientific += 1
self._advance()
- elif self._peek.isalpha():
+ elif self._peek.isalpha(): # type: ignore
self._add(TokenType.NUMBER)
literal = []
- while self._peek.isalpha():
- literal.append(self._peek.upper())
+ while self._peek.isalpha(): # type: ignore
+ literal.append(self._peek.upper()) # type: ignore
self._advance()
- literal = "".join(literal)
- token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal))
+ literal = "".join(literal) # type: ignore
+ token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal)) # type: ignore
if token_type:
self._add(TokenType.DCOLON, "::")
- return self._add(token_type, literal)
+ return self._add(token_type, literal) # type: ignore
return self._advance(-len(literal))
else:
return self._add(TokenType.NUMBER)
- def _scan_bits(self):
+ def _scan_bits(self) -> None:
self._advance()
value = self._extract_value()
try:
@@ -884,7 +936,7 @@ class Tokenizer(metaclass=_Tokenizer):
except ValueError:
self._add(TokenType.IDENTIFIER)
- def _scan_hex(self):
+ def _scan_hex(self) -> None:
self._advance()
value = self._extract_value()
try:
@@ -892,9 +944,9 @@ class Tokenizer(metaclass=_Tokenizer):
except ValueError:
self._add(TokenType.IDENTIFIER)
- def _extract_value(self):
+ def _extract_value(self) -> str:
while True:
- char = self._peek.strip()
+ char = self._peek.strip() # type: ignore
if char and char not in self.SINGLE_TOKENS:
self._advance()
else:
@@ -902,31 +954,30 @@ class Tokenizer(metaclass=_Tokenizer):
return self._text
- def _scan_string(self, quote):
- quote_end = self._QUOTES.get(quote)
+ def _scan_string(self, quote: str) -> bool:
+ quote_end = self._QUOTES.get(quote) # type: ignore
if quote_end is None:
return False
self._advance(len(quote))
text = self._extract_string(quote_end)
-
- text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text
- text = text.replace("\\\\", "\\") if self.ESCAPE == "\\" else text
+ text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text # type: ignore
+ text = text.replace("\\\\", "\\") if self._replace_backslash else text
self._add(TokenType.STRING, text)
return True
# X'1234, b'0110', E'\\\\\' etc.
- def _scan_formatted_string(self, string_start):
- if string_start in self._HEX_STRINGS:
- delimiters = self._HEX_STRINGS
+ def _scan_formatted_string(self, string_start: str) -> bool:
+ if string_start in self._HEX_STRINGS: # type: ignore
+ delimiters = self._HEX_STRINGS # type: ignore
token_type = TokenType.HEX_STRING
base = 16
- elif string_start in self._BIT_STRINGS:
- delimiters = self._BIT_STRINGS
+ elif string_start in self._BIT_STRINGS: # type: ignore
+ delimiters = self._BIT_STRINGS # type: ignore
token_type = TokenType.BIT_STRING
base = 2
- elif string_start in self._BYTE_STRINGS:
- delimiters = self._BYTE_STRINGS
+ elif string_start in self._BYTE_STRINGS: # type: ignore
+ delimiters = self._BYTE_STRINGS # type: ignore
token_type = TokenType.BYTE_STRING
base = None
else:
@@ -942,11 +993,13 @@ class Tokenizer(metaclass=_Tokenizer):
try:
self._add(token_type, f"{int(text, base)}")
except:
- raise RuntimeError(f"Numeric string contains invalid characters from {self._line}:{self._start}")
+ raise RuntimeError(
+ f"Numeric string contains invalid characters from {self._line}:{self._start}"
+ )
return True
- def _scan_identifier(self, identifier_end):
+ def _scan_identifier(self, identifier_end: str) -> None:
while self._peek != identifier_end:
if self._end:
raise RuntimeError(f"Missing {identifier_end} from {self._line}:{self._start}")
@@ -954,9 +1007,9 @@ class Tokenizer(metaclass=_Tokenizer):
self._advance()
self._add(TokenType.IDENTIFIER, self._text[1:-1])
- def _scan_var(self):
+ def _scan_var(self) -> None:
while True:
- char = self._peek.strip()
+ char = self._peek.strip() # type: ignore
if char and char not in self.SINGLE_TOKENS:
self._advance()
else:
@@ -967,12 +1020,12 @@ class Tokenizer(metaclass=_Tokenizer):
else self.KEYWORDS.get(self._text.upper(), TokenType.VAR)
)
- def _extract_string(self, delimiter):
+ def _extract_string(self, delimiter: str) -> str:
text = ""
delim_size = len(delimiter)
while True:
- if self._char == self.ESCAPE and self._peek == delimiter:
+ if self._char in self._ESCAPES and self._peek == delimiter: # type: ignore
text += delimiter
self._advance(2)
else:
@@ -983,7 +1036,7 @@ class Tokenizer(metaclass=_Tokenizer):
if self._end:
raise RuntimeError(f"Missing {delimiter} from {self._line}:{self._start}")
- text += self._char
+ text += self._char # type: ignore
self._advance()
return text
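Since the default escape set is just the single quote, a doubled quote still escapes the string delimiter; a doctest-style sketch:

    >>> from sqlglot.tokens import Tokenizer
    >>> [token] = Tokenizer().tokenize("'it''s'")
    >>> token.text
    "it's"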
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py
index 014ae00..412b881 100644
--- a/sqlglot/transforms.py
+++ b/sqlglot/transforms.py
@@ -1,7 +1,14 @@
+from __future__ import annotations
+
+import typing as t
+
+if t.TYPE_CHECKING:
+ from sqlglot.generator import Generator
+
from sqlglot import expressions as exp
-def unalias_group(expression):
+def unalias_group(expression: exp.Expression) -> exp.Expression:
"""
Replace references to select aliases in GROUP BY clauses.
@@ -9,6 +16,12 @@ def unalias_group(expression):
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
+
+ Args:
+ expression: the expression that will be transformed.
+
+ Returns:
+ The transformed expression.
"""
if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
aliased_selects = {
@@ -30,19 +43,20 @@ def unalias_group(expression):
return expression
-def preprocess(transforms, to_sql):
+def preprocess(
+ transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
+ to_sql: t.Callable[[Generator, exp.Expression], str],
+) -> t.Callable[[Generator, exp.Expression], str]:
"""
- Create a new transform function that can be used a value in `Generator.TRANSFORMS`
- to convert expressions to SQL.
+ Creates a new transform by chaining a sequence of transformations and converts the resulting
+ expression to SQL, using an appropriate `Generator.TRANSFORMS` function.
Args:
- transforms (list[(exp.Expression) -> exp.Expression]):
- Sequence of transform functions. These will be called in order.
- to_sql ((sqlglot.generator.Generator, exp.Expression) -> str):
- Final transform that converts the resulting expression to a SQL string.
+ transforms: sequence of transform functions. These will be called in order.
+ to_sql: final transform that converts the resulting expression to a SQL string.
+
Returns:
- (sqlglot.generator.Generator, exp.Expression) -> str:
- Function that can be used as a generator transform.
+ Function that can be used as a generator transform.
"""
def _to_sql(self, expression):
@@ -54,12 +68,10 @@ def preprocess(transforms, to_sql):
return _to_sql
-def delegate(attr):
+def delegate(attr: str) -> t.Callable:
"""
- Create a new method that delegates to `attr`.
-
- This is useful for creating `Generator.TRANSFORMS` functions that delegate
- to existing generator methods.
+ Create a new method that delegates to `attr`. This is useful for creating `Generator.TRANSFORMS`
+ functions that delegate to existing generator methods.
"""
def _transform(self, *args, **kwargs):
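A hedged sketch of how these two helpers compose inside a custom generator, following the pattern the docstrings above describe; `MyGenerator` is hypothetical and not part of this diff:

    from sqlglot import exp, transforms
    from sqlglot.generator import Generator

    class MyGenerator(Generator):
        TRANSFORMS = {
            **Generator.TRANSFORMS,
            # Rewrite GROUP BY aliases first, then emit SQL via the stock select_sql.
            exp.Select: transforms.preprocess(
                [transforms.unalias_group], transforms.delegate("select_sql")
            ),
        }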
diff --git a/sqlglot/trie.py b/sqlglot/trie.py
index a234107..fa2aaf1 100644
--- a/sqlglot/trie.py
+++ b/sqlglot/trie.py
@@ -1,5 +1,26 @@
-def new_trie(keywords):
- trie = {}
+import typing as t
+
+key = t.Sequence[t.Hashable]
+
+
+def new_trie(keywords: t.Iterable[key]) -> t.Dict:
+ """
+ Creates a new trie out of a collection of keywords.
+
+    The trie is represented as nested dictionaries keyed by either single-character
+    strings, or by 0, which is used to designate that a keyword ends at that node.
+
+ Example:
+ >>> new_trie(["bla", "foo", "blab"])
+ {'b': {'l': {'a': {0: True, 'b': {0: True}}}}, 'f': {'o': {'o': {0: True}}}}
+
+ Args:
+ keywords: the keywords to create the trie from.
+
+ Returns:
+ The trie corresponding to `keywords`.
+ """
+ trie: t.Dict = {}
for key in keywords:
current = trie
@@ -11,7 +32,28 @@ def new_trie(keywords):
return trie
-def in_trie(trie, key):
+def in_trie(trie: t.Dict, key: key) -> t.Tuple[int, t.Dict]:
+ """
+ Checks whether a key is in a trie.
+
+ Examples:
+ >>> in_trie(new_trie(["cat"]), "bob")
+ (0, {'c': {'a': {'t': {0: True}}}})
+
+ >>> in_trie(new_trie(["cat"]), "ca")
+ (1, {'t': {0: True}})
+
+ >>> in_trie(new_trie(["cat"]), "cat")
+ (2, {0: True})
+
+ Args:
+ trie: the trie to be searched.
+ key: the target key.
+
+ Returns:
+ A pair `(value, subtrie)`, where `subtrie` is the sub-trie we get at the point where the search stops, and `value`
+        is either 0 (the search was unsuccessful), 1 (`key` is a prefix of a keyword in `trie`) or 2 (`key` is in `trie`).
+ """
if not key:
return (0, trie)
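Because `in_trie` returns the sub-trie where the search stopped, lookups can be resumed incrementally, which is how the keyword scanner in tokens.py uses it:

    >>> trie = new_trie(["cat", "catalog"])
    >>> value, node = in_trie(trie, "cat")
    >>> value
    2
    >>> in_trie(node, "a")[0]
    1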