author    Daniel Baumann <daniel.baumann@progress-linux.org>  2022-11-19 14:50:39 +0000
committer Daniel Baumann <daniel.baumann@progress-linux.org>  2022-11-19 14:50:39 +0000
commit    f2981e8e4d28233864f1ca06ecec45ab80bf9eae (patch)
tree      b70cb633916830138ce3424aa361f0bbaff02be2 /sqlglot
parent    Releasing debian version 10.0.1-1. (diff)
Merging upstream version 10.0.8.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
 sqlglot/__init__.py                      |   2
 sqlglot/dataframe/sql/column.py          |   5
 sqlglot/dataframe/sql/dataframe.py       |   8
 sqlglot/dataframe/sql/functions.py       |  18
 sqlglot/dataframe/sql/session.py         |   8
 sqlglot/dialects/__init__.py             |   1
 sqlglot/dialects/bigquery.py             |  11
 sqlglot/dialects/dialect.py              |  16
 sqlglot/dialects/drill.py                | 174
 sqlglot/dialects/duckdb.py               |   4
 sqlglot/dialects/hive.py                 |  20
 sqlglot/dialects/mysql.py                |  76
 sqlglot/dialects/oracle.py               |   1
 sqlglot/dialects/postgres.py             |  38
 sqlglot/dialects/presto.py               |  38
 sqlglot/dialects/snowflake.py            |   2
 sqlglot/dialects/sqlite.py               |   5
 sqlglot/dialects/tsql.py                 |   2
 sqlglot/diff.py                          |  55
 sqlglot/errors.py                        |   4
 sqlglot/executor/__init__.py             |  23
 sqlglot/executor/context.py              |  47
 sqlglot/executor/env.py                  | 162
 sqlglot/executor/python.py               | 287
 sqlglot/executor/table.py                |  43
 sqlglot/expressions.py                   | 128
 sqlglot/generator.py                     |  42
 sqlglot/helper.py                        |  45
 sqlglot/optimizer/annotate_types.py      |  26
 sqlglot/optimizer/canonicalize.py        |  48
 sqlglot/optimizer/eliminate_joins.py     |  13
 sqlglot/optimizer/optimize_joins.py      |   4
 sqlglot/optimizer/optimizer.py           |   4
 sqlglot/optimizer/pushdown_predicates.py |   4
 sqlglot/optimizer/qualify_columns.py     |   4
 sqlglot/optimizer/qualify_tables.py      |  14
 sqlglot/optimizer/simplify.py            |   6
 sqlglot/optimizer/unnest_subqueries.py   |   2
 sqlglot/parser.py                        | 403
 sqlglot/planner.py                       | 227
 sqlglot/schema.py                        | 215
 sqlglot/tokens.py                        |  58
 42 files changed, 1587 insertions(+), 706 deletions(-)
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py
index 6e67b19..50e2d9c 100644
--- a/sqlglot/__init__.py
+++ b/sqlglot/__init__.py
@@ -30,7 +30,7 @@ from sqlglot.parser import Parser
from sqlglot.schema import MappingSchema
from sqlglot.tokens import Tokenizer, TokenType
-__version__ = "10.0.1"
+__version__ = "10.0.8"
pretty = False
diff --git a/sqlglot/dataframe/sql/column.py b/sqlglot/dataframe/sql/column.py
index f9e1c5b..22075e9 100644
--- a/sqlglot/dataframe/sql/column.py
+++ b/sqlglot/dataframe/sql/column.py
@@ -260,7 +260,10 @@ class Column:
"""
if isinstance(dataType, DataType):
dataType = dataType.simpleString()
- new_expression = exp.Cast(this=self.column_expression, to=dataType)
+ new_expression = exp.Cast(
+ this=self.column_expression,
+ to=sqlglot.parse_one(dataType, into=exp.DataType, read="spark"), # type: ignore
+ )
return Column(new_expression)
def startswith(self, value: t.Union[str, Column]) -> Column:
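A minimal sketch of the effect (example mine, not part of the patch): cast() now parses the type string into a real DataType node, so nested Spark types survive instead of being embedded as raw strings:

    import sqlglot
    from sqlglot import expressions as exp

    dt = sqlglot.parse_one("array<int>", into=exp.DataType, read="spark")
    cast = exp.Cast(this=sqlglot.parse_one("x"), to=dt)
    # cast.sql() renders roughly as: CAST(x AS ARRAY<INT>)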
diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py
index 40cd6c9..548c322 100644
--- a/sqlglot/dataframe/sql/dataframe.py
+++ b/sqlglot/dataframe/sql/dataframe.py
@@ -314,7 +314,13 @@ class DataFrame:
replacement_mapping[exp.to_identifier(original_alias_name)] = exp.to_identifier( # type: ignore
cache_table_name
)
- sqlglot.schema.add_table(cache_table_name, select_expression.named_selects)
+ sqlglot.schema.add_table(
+ cache_table_name,
+ {
+ expression.alias_or_name: expression.type.name
+ for expression in select_expression.expressions
+ },
+ )
cache_storage_level = select_expression.args["cache_storage_level"]
options = [
exp.Literal.string("storageLevel"),
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py
index dbfb06f..1ee361a 100644
--- a/sqlglot/dataframe/sql/functions.py
+++ b/sqlglot/dataframe/sql/functions.py
@@ -757,11 +757,15 @@ def concat_ws(sep: str, *cols: ColumnOrName) -> Column:
def decode(col: ColumnOrName, charset: str) -> Column:
- return Column.invoke_anonymous_function(col, "DECODE", lit(charset))
+ return Column.invoke_expression_over_column(
+ col, glotexp.Decode, charset=glotexp.Literal.string(charset)
+ )
def encode(col: ColumnOrName, charset: str) -> Column:
- return Column.invoke_anonymous_function(col, "ENCODE", lit(charset))
+ return Column.invoke_expression_over_column(
+ col, glotexp.Encode, charset=glotexp.Literal.string(charset)
+ )
def format_number(col: ColumnOrName, d: int) -> Column:
@@ -867,11 +871,11 @@ def bin(col: ColumnOrName) -> Column:
def hex(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "HEX")
+ return Column.invoke_expression_over_column(col, glotexp.Hex)
def unhex(col: ColumnOrName) -> Column:
- return Column.invoke_anonymous_function(col, "UNHEX")
+ return Column.invoke_expression_over_column(col, glotexp.Unhex)
def length(col: ColumnOrName) -> Column:
@@ -939,11 +943,7 @@ def array_join(
def concat(*cols: ColumnOrName) -> Column:
- if len(cols) == 1:
- return Column.invoke_anonymous_function(cols[0], "CONCAT")
- return Column.invoke_anonymous_function(
- cols[0], "CONCAT", *[Column.ensure_col(x).expression for x in cols[1:]]
- )
+ return Column.invoke_expression_over_column(None, glotexp.Concat, expressions=cols)
def array_position(col: ColumnOrName, value: ColumnOrLiteral) -> Column:
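All of these changes follow one pattern: anonymous function calls are replaced with typed expression nodes (glotexp.Decode, glotexp.Hex, glotexp.Concat, ...), which dialect generators know how to rewrite. A small sketch, assuming the module's usual `functions as F` import:

    from sqlglot.dataframe.sql import functions as F

    c = F.concat(F.col("a"), F.col("b"))
    # c now wraps an exp.Concat node instead of an anonymous CONCAT(...) call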
diff --git a/sqlglot/dataframe/sql/session.py b/sqlglot/dataframe/sql/session.py
index 8cb16ef..c4a22c6 100644
--- a/sqlglot/dataframe/sql/session.py
+++ b/sqlglot/dataframe/sql/session.py
@@ -88,14 +88,14 @@ class SparkSession:
"expressions": sel_columns,
"from": exp.From(
expressions=[
- exp.Subquery(
- this=exp.Values(expressions=data_expressions),
+ exp.Values(
+ expressions=data_expressions,
alias=exp.TableAlias(
this=exp.to_identifier(self._auto_incrementing_name),
columns=[exp.to_identifier(col_name) for col_name in column_mapping],
),
- )
- ]
+ ),
+ ],
),
}
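Net effect on the generated SQL (my reading of the change, names illustrative): the VALUES clause is aliased directly, `FROM VALUES (...) AS t(c1, c2)`, instead of being wrapped in a subquery first, `FROM (VALUES (...)) AS t(c1, c2)`.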
diff --git a/sqlglot/dialects/__init__.py b/sqlglot/dialects/__init__.py
index 0816831..2e42e7d 100644
--- a/sqlglot/dialects/__init__.py
+++ b/sqlglot/dialects/__init__.py
@@ -2,6 +2,7 @@ from sqlglot.dialects.bigquery import BigQuery
from sqlglot.dialects.clickhouse import ClickHouse
from sqlglot.dialects.databricks import Databricks
from sqlglot.dialects.dialect import Dialect, Dialects
+from sqlglot.dialects.drill import Drill
from sqlglot.dialects.duckdb import DuckDB
from sqlglot.dialects.hive import Hive
from sqlglot.dialects.mysql import MySQL
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 5bbff9d..4550d65 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -119,6 +119,8 @@ class BigQuery(Dialect):
"UNKNOWN": TokenType.NULL,
"WINDOW": TokenType.WINDOW,
"NOT DETERMINISTIC": TokenType.VOLATILE,
+ "BEGIN": TokenType.COMMAND,
+ "BEGIN TRANSACTION": TokenType.BEGIN,
}
KEYWORDS.pop("DIV")
@@ -204,6 +206,15 @@ class BigQuery(Dialect):
EXPLICIT_UNION = True
+ def transaction_sql(self, *_):
+ return "BEGIN TRANSACTION"
+
+ def commit_sql(self, *_):
+ return "COMMIT TRANSACTION"
+
+ def rollback_sql(self, *_):
+ return "ROLLBACK TRANSACTION"
+
def in_unnest_op(self, unnest):
return self.sql(unnest)
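A sketch of what the new hooks produce (example mine; assumes the parser-side transaction support added elsewhere in this patch):

    import sqlglot

    sqlglot.transpile("BEGIN TRANSACTION", read="bigquery", write="bigquery")
    # expected: ['BEGIN TRANSACTION']
    sqlglot.transpile("COMMIT", read="bigquery", write="bigquery")
    # expected: ['COMMIT TRANSACTION']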
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 3af08bb..8c497ab 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -32,6 +32,7 @@ class Dialects(str, Enum):
TRINO = "trino"
TSQL = "tsql"
DATABRICKS = "databricks"
+ DRILL = "drill"
class _Dialect(type):
@@ -362,3 +363,18 @@ def parse_date_delta(exp_class, unit_mapping=None):
return exp_class(this=this, expression=expression, unit=unit)
return inner_func
+
+
+def locate_to_strposition(args):
+ return exp.StrPosition(
+ this=seq_get(args, 1),
+ substr=seq_get(args, 0),
+ position=seq_get(args, 2),
+ )
+
+
+def strposition_to_local_sql(self, expression):
+ args = self.format_args(
+ expression.args.get("substr"), expression.this, expression.args.get("position")
+ )
+ return f"LOCATE({args})"
diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py
new file mode 100644
index 0000000..eb420aa
--- /dev/null
+++ b/sqlglot/dialects/drill.py
@@ -0,0 +1,174 @@
+from __future__ import annotations
+
+import re
+
+from sqlglot import exp, generator, parser, tokens
+from sqlglot.dialects.dialect import (
+ Dialect,
+ create_with_partitions_sql,
+ format_time_lambda,
+ no_pivot_sql,
+ no_trycast_sql,
+ rename_func,
+ str_position_sql,
+)
+from sqlglot.dialects.postgres import _lateral_sql
+
+
+def _to_timestamp(args):
+ # TO_TIMESTAMP accepts either a single double argument or (text, text)
+ if len(args) == 1 and args[0].is_number:
+ return exp.UnixToTime.from_arg_list(args)
+ return format_time_lambda(exp.StrToTime, "drill")(args)
+
+
+def _str_to_time_sql(self, expression):
+ return f"STRPTIME({self.sql(expression, 'this')}, {self.format_time(expression)})"
+
+
+def _ts_or_ds_to_date_sql(self, expression):
+ time_format = self.format_time(expression)
+ if time_format and time_format not in (Drill.time_format, Drill.date_format):
+ return f"CAST({_str_to_time_sql(self, expression)} AS DATE)"
+ return f"CAST({self.sql(expression, 'this')} AS DATE)"
+
+
+def _date_add_sql(kind):
+ def func(self, expression):
+ this = self.sql(expression, "this")
+ unit = expression.text("unit").upper() or "DAY"
+ expression = self.sql(expression, "expression")
+ return f"DATE_{kind}({this}, INTERVAL '{expression}' {unit})"
+
+ return func
+
+
+def if_sql(self, expression):
+ """
+    Drill requires backticks around certain SQL reserved words, IF being one of them. This function
+    adds backticks around the IF keyword.
+ Args:
+ self: The Drill dialect
+ expression: The input IF expression
+
+ Returns: The expression with IF in backticks.
+
+ """
+ expressions = self.format_args(
+ expression.this, expression.args.get("true"), expression.args.get("false")
+ )
+ return f"`IF`({expressions})"
+
+
+def _str_to_date(self, expression):
+ this = self.sql(expression, "this")
+ time_format = self.format_time(expression)
+ if time_format == Drill.date_format:
+ return f"CAST({this} AS DATE)"
+ return f"TO_DATE({this}, {time_format})"
+
+
+class Drill(Dialect):
+ normalize_functions = None
+ null_ordering = "nulls_are_last"
+ date_format = "'yyyy-MM-dd'"
+ dateint_format = "'yyyyMMdd'"
+ time_format = "'yyyy-MM-dd HH:mm:ss'"
+
+ time_mapping = {
+ "y": "%Y",
+ "Y": "%Y",
+ "YYYY": "%Y",
+ "yyyy": "%Y",
+ "YY": "%y",
+ "yy": "%y",
+ "MMMM": "%B",
+ "MMM": "%b",
+ "MM": "%m",
+ "M": "%-m",
+ "dd": "%d",
+ "d": "%-d",
+ "HH": "%H",
+ "H": "%-H",
+ "hh": "%I",
+ "h": "%-I",
+ "mm": "%M",
+ "m": "%-M",
+ "ss": "%S",
+ "s": "%-S",
+ "SSSSSS": "%f",
+ "a": "%p",
+ "DD": "%j",
+ "D": "%-j",
+ "E": "%a",
+ "EE": "%a",
+ "EEE": "%a",
+ "EEEE": "%A",
+ "''T''": "T",
+ }
+
+ class Tokenizer(tokens.Tokenizer):
+ QUOTES = ["'"]
+ IDENTIFIERS = ["`"]
+ ESCAPES = ["\\"]
+ ENCODE = "utf-8"
+
+ class Parser(parser.Parser):
+ STRICT_CAST = False
+
+ FUNCTIONS = {
+ **parser.Parser.FUNCTIONS,
+ "TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list,
+ "TO_CHAR": format_time_lambda(exp.TimeToStr, "drill"),
+ }
+
+ class Generator(generator.Generator):
+ TYPE_MAPPING = {
+ **generator.Generator.TYPE_MAPPING,
+ exp.DataType.Type.INT: "INTEGER",
+ exp.DataType.Type.SMALLINT: "INTEGER",
+ exp.DataType.Type.TINYINT: "INTEGER",
+ exp.DataType.Type.BINARY: "VARBINARY",
+ exp.DataType.Type.TEXT: "VARCHAR",
+ exp.DataType.Type.NCHAR: "VARCHAR",
+ exp.DataType.Type.TIMESTAMPLTZ: "TIMESTAMP",
+ exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
+ exp.DataType.Type.DATETIME: "TIMESTAMP",
+ }
+
+ ROOT_PROPERTIES = {exp.PartitionedByProperty}
+
+ TRANSFORMS = {
+ **generator.Generator.TRANSFORMS,
+ exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
+ exp.Lateral: _lateral_sql,
+ exp.ArrayContains: rename_func("REPEATED_CONTAINS"),
+ exp.ArraySize: rename_func("REPEATED_COUNT"),
+ exp.Create: create_with_partitions_sql,
+ exp.DateAdd: _date_add_sql("ADD"),
+ exp.DateStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
+ exp.DateSub: _date_add_sql("SUB"),
+ exp.DateToDi: lambda self, e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.dateint_format}) AS INT)",
+ exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.dateint_format})",
+ exp.If: if_sql,
+ exp.ILike: lambda self, e: f" {self.sql(e, 'this')} `ILIKE` {self.sql(e, 'expression')}",
+ exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
+ exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}",
+ exp.Pivot: no_pivot_sql,
+ exp.RegexpLike: rename_func("REGEXP_MATCHES"),
+ exp.StrPosition: str_position_sql,
+ exp.StrToDate: _str_to_date,
+ exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
+ exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)",
+ exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
+ exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
+ exp.TryCast: no_trycast_sql,
+ exp.TsOrDsAdd: lambda self, e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), INTERVAL '{self.sql(e, 'expression')}' DAY)",
+ exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
+ exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
+ }
+
+ def normalize_func(self, name):
+ return name if re.match(exp.SAFE_IDENTIFIER_RE, name) else f"`{name}`"
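With the dialect registered in sqlglot/dialects/__init__.py above, it is usable like any other; a minimal sketch exercising the TYPE_MAPPING defined here (example mine):

    import sqlglot

    sqlglot.transpile("SELECT CAST(x AS TEXT)", write="drill")
    # expected: ['SELECT CAST(x AS VARCHAR)']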
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index 781edff..f1da72b 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -55,13 +55,13 @@ def _array_sort_sql(self, expression):
def _sort_array_sql(self, expression):
this = self.sql(expression, "this")
- if expression.args.get("asc") == exp.FALSE:
+ if expression.args.get("asc") == exp.false():
return f"ARRAY_REVERSE_SORT({this})"
return f"ARRAY_SORT({this})"
def _sort_array_reverse(args):
- return exp.SortArray(this=seq_get(args, 0), asc=exp.FALSE)
+ return exp.SortArray(this=seq_get(args, 0), asc=exp.false())
def _struct_pack_sql(self, expression):
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index ed7357c..cff7139 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -7,16 +7,19 @@ from sqlglot.dialects.dialect import (
create_with_partitions_sql,
format_time_lambda,
if_sql,
+ locate_to_strposition,
no_ilike_sql,
no_recursive_cte_sql,
no_safe_divide_sql,
no_trycast_sql,
rename_func,
+ strposition_to_local_sql,
struct_extract_sql,
var_map_sql,
)
from sqlglot.helper import seq_get
from sqlglot.parser import parse_var_map
+from sqlglot.tokens import TokenType
# (FuncType, Multiplier)
DATE_DELTA_INTERVAL = {
@@ -181,6 +184,15 @@ class Hive(Dialect):
"F": "FLOAT",
"BD": "DECIMAL",
}
+ KEYWORDS = {
+ **tokens.Tokenizer.KEYWORDS,
+ "ADD ARCHIVE": TokenType.COMMAND,
+ "ADD ARCHIVES": TokenType.COMMAND,
+ "ADD FILE": TokenType.COMMAND,
+ "ADD FILES": TokenType.COMMAND,
+ "ADD JAR": TokenType.COMMAND,
+ "ADD JARS": TokenType.COMMAND,
+ }
class Parser(parser.Parser):
STRICT_CAST = False
@@ -210,11 +222,7 @@ class Hive(Dialect):
"DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True),
"GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list,
- "LOCATE": lambda args: exp.StrPosition(
- this=seq_get(args, 1),
- substr=seq_get(args, 0),
- position=seq_get(args, 2),
- ),
+ "LOCATE": locate_to_strposition,
"LOG": (
lambda args: exp.Log.from_arg_list(args)
if len(args) > 1
@@ -272,7 +280,7 @@ class Hive(Dialect):
exp.SchemaCommentProperty: lambda self, e: self.naked_property(e),
exp.SetAgg: rename_func("COLLECT_SET"),
exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))",
- exp.StrPosition: lambda self, e: f"LOCATE({self.format_args(e.args.get('substr'), e.this, e.args.get('position'))})",
+ exp.StrPosition: strposition_to_local_sql,
exp.StrToDate: _str_to_date,
exp.StrToTime: _str_to_time,
exp.StrToUnix: _str_to_unix,
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index e742640..93a60f4 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -5,10 +5,12 @@ import typing as t
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import (
Dialect,
+ locate_to_strposition,
no_ilike_sql,
no_paren_current_date_sql,
no_tablesample_sql,
no_trycast_sql,
+ strposition_to_local_sql,
)
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@@ -120,6 +122,7 @@ class MySQL(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
+ "START": TokenType.BEGIN,
"SEPARATOR": TokenType.SEPARATOR,
"_ARMSCII8": TokenType.INTRODUCER,
"_ASCII": TokenType.INTRODUCER,
@@ -172,13 +175,18 @@ class MySQL(Dialect):
COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SET, TokenType.SHOW}
class Parser(parser.Parser):
- STRICT_CAST = False
+ FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA}
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"DATE_ADD": _date_add(exp.DateAdd),
"DATE_SUB": _date_add(exp.DateSub),
"STR_TO_DATE": _str_to_date,
+ "LOCATE": locate_to_strposition,
+ "INSTR": lambda args: exp.StrPosition(substr=seq_get(args, 1), this=seq_get(args, 0)),
+ "LEFT": lambda args: exp.Substring(
+ this=seq_get(args, 0), start=exp.Literal.number(1), length=seq_get(args, 1)
+ ),
}
FUNCTION_PARSERS = {
@@ -264,6 +272,7 @@ class MySQL(Dialect):
"CHARACTER SET": lambda self: self._parse_set_item_charset("CHARACTER SET"),
"CHARSET": lambda self: self._parse_set_item_charset("CHARACTER SET"),
"NAMES": lambda self: self._parse_set_item_names(),
+ "TRANSACTION": lambda self: self._parse_set_transaction(),
}
PROFILE_TYPES = {
@@ -278,39 +287,48 @@ class MySQL(Dialect):
"SWAPS",
}
+ TRANSACTION_CHARACTERISTICS = {
+ "ISOLATION LEVEL REPEATABLE READ",
+ "ISOLATION LEVEL READ COMMITTED",
+ "ISOLATION LEVEL READ UNCOMMITTED",
+ "ISOLATION LEVEL SERIALIZABLE",
+ "READ WRITE",
+ "READ ONLY",
+ }
+
def _parse_show_mysql(self, this, target=False, full=None, global_=None):
if target:
if isinstance(target, str):
- self._match_text(target)
+ self._match_text_seq(target)
target_id = self._parse_id_var()
else:
target_id = None
- log = self._parse_string() if self._match_text("IN") else None
+ log = self._parse_string() if self._match_text_seq("IN") else None
if this in {"BINLOG EVENTS", "RELAYLOG EVENTS"}:
- position = self._parse_number() if self._match_text("FROM") else None
+ position = self._parse_number() if self._match_text_seq("FROM") else None
db = None
else:
position = None
- db = self._parse_id_var() if self._match_text("FROM") else None
+ db = self._parse_id_var() if self._match_text_seq("FROM") else None
- channel = self._parse_id_var() if self._match_text("FOR", "CHANNEL") else None
+ channel = self._parse_id_var() if self._match_text_seq("FOR", "CHANNEL") else None
- like = self._parse_string() if self._match_text("LIKE") else None
+ like = self._parse_string() if self._match_text_seq("LIKE") else None
where = self._parse_where()
if this == "PROFILE":
- types = self._parse_csv(self._parse_show_profile_type)
- query = self._parse_number() if self._match_text("FOR", "QUERY") else None
- offset = self._parse_number() if self._match_text("OFFSET") else None
- limit = self._parse_number() if self._match_text("LIMIT") else None
+ types = self._parse_csv(lambda: self._parse_var_from_options(self.PROFILE_TYPES))
+ query = self._parse_number() if self._match_text_seq("FOR", "QUERY") else None
+ offset = self._parse_number() if self._match_text_seq("OFFSET") else None
+ limit = self._parse_number() if self._match_text_seq("LIMIT") else None
else:
types, query = None, None
offset, limit = self._parse_oldstyle_limit()
- mutex = True if self._match_text("MUTEX") else None
- mutex = False if self._match_text("STATUS") else mutex
+ mutex = True if self._match_text_seq("MUTEX") else None
+ mutex = False if self._match_text_seq("STATUS") else mutex
return self.expression(
exp.Show,
@@ -331,16 +349,16 @@ class MySQL(Dialect):
**{"global": global_},
)
- def _parse_show_profile_type(self):
- for type_ in self.PROFILE_TYPES:
- if self._match_text(*type_.split(" ")):
- return exp.Var(this=type_)
+ def _parse_var_from_options(self, options):
+ for option in options:
+ if self._match_text_seq(*option.split(" ")):
+ return exp.Var(this=option)
return None
def _parse_oldstyle_limit(self):
limit = None
offset = None
- if self._match_text("LIMIT"):
+ if self._match_text_seq("LIMIT"):
parts = self._parse_csv(self._parse_number)
if len(parts) == 1:
limit = parts[0]
@@ -353,6 +371,9 @@ class MySQL(Dialect):
return self._parse_set_item_assignment(kind=None)
def _parse_set_item_assignment(self, kind):
+ if kind in {"GLOBAL", "SESSION"} and self._match_text_seq("TRANSACTION"):
+ return self._parse_set_transaction(global_=kind == "GLOBAL")
+
left = self._parse_primary() or self._parse_id_var()
if not self._match(TokenType.EQ):
self.raise_error("Expected =")
@@ -381,7 +402,7 @@ class MySQL(Dialect):
def _parse_set_item_names(self):
charset = self._parse_string() or self._parse_id_var()
- if self._match_text("COLLATE"):
+ if self._match_text_seq("COLLATE"):
collate = self._parse_string() or self._parse_id_var()
else:
collate = None
@@ -392,6 +413,18 @@ class MySQL(Dialect):
kind="NAMES",
)
+ def _parse_set_transaction(self, global_=False):
+ self._match_text_seq("TRANSACTION")
+ characteristics = self._parse_csv(
+ lambda: self._parse_var_from_options(self.TRANSACTION_CHARACTERISTICS)
+ )
+ return self.expression(
+ exp.SetItem,
+ expressions=characteristics,
+ kind="TRANSACTION",
+ **{"global": global_},
+ )
+
class Generator(generator.Generator):
NULL_ORDERING_SUPPORTED = False
@@ -411,6 +444,7 @@ class MySQL(Dialect):
exp.Trim: _trim_sql,
exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
+ exp.StrPosition: strposition_to_local_sql,
}
ROOT_PROPERTIES = {
@@ -481,9 +515,11 @@ class MySQL(Dialect):
kind = self.sql(expression, "kind")
kind = f"{kind} " if kind else ""
this = self.sql(expression, "this")
+ expressions = self.expressions(expression)
collate = self.sql(expression, "collate")
collate = f" COLLATE {collate}" if collate else ""
- return f"{kind}{this}{collate}"
+ global_ = "GLOBAL " if expression.args.get("global") else ""
+ return f"{global_}{kind}{this}{expressions}{collate}"
def set_sql(self, expression):
return f"SET {self.expressions(expression)}"
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
index 3bc1109..870d2b9 100644
--- a/sqlglot/dialects/oracle.py
+++ b/sqlglot/dialects/oracle.py
@@ -91,6 +91,7 @@ class Oracle(Dialect):
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
+ "START": TokenType.BEGIN,
"TOP": TokenType.TOP,
"VARCHAR2": TokenType.VARCHAR,
"NVARCHAR2": TokenType.NVARCHAR,
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 553a73b..4353164 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -164,11 +164,34 @@ class Postgres(Dialect):
BIT_STRINGS = [("b'", "'"), ("B'", "'")]
HEX_STRINGS = [("x'", "'"), ("X'", "'")]
BYTE_STRINGS = [("e'", "'"), ("E'", "'")]
+
+ CREATABLES = (
+ "AGGREGATE",
+ "CAST",
+ "CONVERSION",
+ "COLLATION",
+ "DEFAULT CONVERSION",
+ "CONSTRAINT",
+ "DOMAIN",
+ "EXTENSION",
+ "FOREIGN",
+ "FUNCTION",
+ "OPERATOR",
+ "POLICY",
+ "ROLE",
+ "RULE",
+ "SEQUENCE",
+ "TEXT",
+ "TRIGGER",
+ "TYPE",
+ "UNLOGGED",
+ "USER",
+ )
+
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"ALWAYS": TokenType.ALWAYS,
"BY DEFAULT": TokenType.BY_DEFAULT,
- "COMMENT ON": TokenType.COMMENT_ON,
"IDENTITY": TokenType.IDENTITY,
"GENERATED": TokenType.GENERATED,
"DOUBLE PRECISION": TokenType.DOUBLE,
@@ -176,6 +199,19 @@ class Postgres(Dialect):
"SERIAL": TokenType.SERIAL,
"SMALLSERIAL": TokenType.SMALLSERIAL,
"UUID": TokenType.UUID,
+ "TEMP": TokenType.TEMPORARY,
+ "BEGIN TRANSACTION": TokenType.BEGIN,
+ "BEGIN": TokenType.COMMAND,
+ "COMMENT ON": TokenType.COMMAND,
+ "DECLARE": TokenType.COMMAND,
+ "DO": TokenType.COMMAND,
+ "REFRESH": TokenType.COMMAND,
+ "REINDEX": TokenType.COMMAND,
+ "RESET": TokenType.COMMAND,
+ "REVOKE": TokenType.COMMAND,
+ "GRANT": TokenType.COMMAND,
+ **{f"CREATE {kind}": TokenType.COMMAND for kind in CREATABLES},
+ **{f"DROP {kind}": TokenType.COMMAND for kind in CREATABLES},
}
QUOTES = ["'", "$$"]
SINGLE_TOKENS = {
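These keywords now tokenize as opaque commands, so such statements parse instead of erroring (sketch, example mine):

    import sqlglot

    sqlglot.parse_one("CREATE EXTENSION IF NOT EXISTS hstore", read="postgres")
    # expected: an exp.Command node; the statement text is preserved, not modeled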
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index 11ea778..9d5cc11 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -12,6 +12,7 @@ from sqlglot.dialects.dialect import (
struct_extract_sql,
)
from sqlglot.dialects.mysql import MySQL
+from sqlglot.errors import UnsupportedError
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@@ -61,8 +62,18 @@ def _initcap_sql(self, expression):
return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))"
+def _decode_sql(self, expression):
+ _ensure_utf8(expression.args.get("charset"))
+ return f"FROM_UTF8({self.sql(expression, 'this')})"
+
+
+def _encode_sql(self, expression):
+ _ensure_utf8(expression.args.get("charset"))
+ return f"TO_UTF8({self.sql(expression, 'this')})"
+
+
def _no_sort_array(self, expression):
- if expression.args.get("asc") == exp.FALSE:
+ if expression.args.get("asc") == exp.false():
comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END"
else:
comparator = None
@@ -72,7 +83,7 @@ def _no_sort_array(self, expression):
def _schema_sql(self, expression):
if isinstance(expression.parent, exp.Property):
- columns = ", ".join(f"'{c.text('this')}'" for c in expression.expressions)
+ columns = ", ".join(f"'{c.name}'" for c in expression.expressions)
return f"ARRAY[{columns}]"
for schema in expression.parent.find_all(exp.Schema):
@@ -106,6 +117,11 @@ def _ts_or_ds_add_sql(self, expression):
return f"DATE_ADD({unit}, {e}, DATE_PARSE(SUBSTR({this}, 1, 10), {Presto.date_format}))"
+def _ensure_utf8(charset):
+ if charset.name.lower() != "utf-8":
+ raise UnsupportedError(f"Unsupported charset {charset}")
+
+
class Presto(Dialect):
index_offset = 1
null_ordering = "nulls_are_last"
@@ -115,6 +131,7 @@ class Presto(Dialect):
class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
+ "START": TokenType.BEGIN,
"ROW": TokenType.STRUCT,
}
@@ -140,6 +157,14 @@ class Presto(Dialect):
"STRPOS": exp.StrPosition.from_arg_list,
"TO_UNIXTIME": exp.TimeToUnix.from_arg_list,
"APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
+ "FROM_HEX": exp.Unhex.from_arg_list,
+ "TO_HEX": exp.Hex.from_arg_list,
+ "TO_UTF8": lambda args: exp.Encode(
+ this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
+ ),
+ "FROM_UTF8": lambda args: exp.Decode(
+ this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
+ ),
}
class Generator(generator.Generator):
@@ -187,7 +212,10 @@ class Presto(Dialect):
exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""",
exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.date_format}) AS DATE)",
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.dateint_format}) AS INT)",
+ exp.Decode: _decode_sql,
exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)",
+ exp.Encode: _encode_sql,
+ exp.Hex: rename_func("TO_HEX"),
exp.If: if_sql,
exp.ILike: no_ilike_sql,
exp.Initcap: _initcap_sql,
@@ -212,7 +240,13 @@ class Presto(Dialect):
exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
exp.TsOrDsAdd: _ts_or_ds_add_sql,
exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
+ exp.Unhex: rename_func("FROM_HEX"),
exp.UnixToStr: lambda self, e: f"DATE_FORMAT(FROM_UNIXTIME({self.sql(e, 'this')}), {self.format_time(e)})",
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)",
}
+
+ def transaction_sql(self, expression):
+ modes = expression.args.get("modes")
+ modes = f" {', '.join(modes)}" if modes else ""
+ return f"START TRANSACTION{modes}"
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index d1aaded..a96bd80 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -148,6 +148,7 @@ class Snowflake(Dialect):
**parser.Parser.FUNCTION_PARSERS,
"DATE_PART": _parse_date_part,
}
+ FUNCTION_PARSERS.pop("TRIM")
FUNC_TOKENS = {
*parser.Parser.FUNC_TOKENS,
@@ -203,6 +204,7 @@ class Snowflake(Dialect):
exp.StrPosition: rename_func("POSITION"),
exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}",
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'value')}",
+ exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})",
}
TYPE_MAPPING = {
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index 8c9fb76..87b98a5 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -63,3 +63,8 @@ class SQLite(Dialect):
exp.TableSample: no_tablesample_sql,
exp.TryCast: no_trycast_sql,
}
+
+ def transaction_sql(self, expression):
+ this = expression.this
+ this = f" {this}" if this else ""
+ return f"BEGIN{this} TRANSACTION"
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index a233d4b..d3b83de 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -248,7 +248,7 @@ class TSQL(Dialect):
def _parse_convert(self, strict):
to = self._parse_types()
self._match(TokenType.COMMA)
- this = self._parse_column()
+ this = self._parse_conjunction()
# Retrieve length of datatype and override to default if not specified
if seq_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES:
diff --git a/sqlglot/diff.py b/sqlglot/diff.py
index 2d959ab..758ad1b 100644
--- a/sqlglot/diff.py
+++ b/sqlglot/diff.py
@@ -1,3 +1,6 @@
+from __future__ import annotations
+
+import typing as t
from collections import defaultdict
from dataclasses import dataclass
from heapq import heappop, heappush
@@ -6,6 +9,10 @@ from sqlglot import Dialect
from sqlglot import expressions as exp
from sqlglot.helper import ensure_collection
+if t.TYPE_CHECKING:
+ T = t.TypeVar("T")
+ Edit = t.Union[Insert, Remove, Move, Update, Keep]
+
@dataclass(frozen=True)
class Insert:
@@ -44,7 +51,7 @@ class Keep:
target: exp.Expression
-def diff(source, target):
+def diff(source: exp.Expression, target: exp.Expression) -> t.List[Edit]:
"""
Returns the list of changes between the source and the target expressions.
@@ -89,25 +96,25 @@ class ChangeDistiller:
Chawathe et al. described in http://ilpubs.stanford.edu:8090/115/1/1995-46.pdf.
"""
- def __init__(self, f=0.6, t=0.6):
+ def __init__(self, f: float = 0.6, t: float = 0.6) -> None:
self.f = f
self.t = t
self._sql_generator = Dialect().generator()
- def diff(self, source, target):
+ def diff(self, source: exp.Expression, target: exp.Expression) -> t.List[Edit]:
self._source = source
self._target = target
self._source_index = {id(n[0]): n[0] for n in source.bfs()}
self._target_index = {id(n[0]): n[0] for n in target.bfs()}
self._unmatched_source_nodes = set(self._source_index)
self._unmatched_target_nodes = set(self._target_index)
- self._bigram_histo_cache = {}
+ self._bigram_histo_cache: t.Dict[int, t.DefaultDict[str, int]] = {}
matching_set = self._compute_matching_set()
return self._generate_edit_script(matching_set)
- def _generate_edit_script(self, matching_set):
- edit_script = []
+ def _generate_edit_script(self, matching_set: t.Set[t.Tuple[int, int]]) -> t.List[Edit]:
+ edit_script: t.List[Edit] = []
for removed_node_id in self._unmatched_source_nodes:
edit_script.append(Remove(self._source_index[removed_node_id]))
for inserted_node_id in self._unmatched_target_nodes:
@@ -125,7 +132,9 @@ class ChangeDistiller:
return edit_script
- def _generate_move_edits(self, source, target, matching_set):
+ def _generate_move_edits(
+ self, source: exp.Expression, target: exp.Expression, matching_set: t.Set[t.Tuple[int, int]]
+ ) -> t.List[Move]:
source_args = [id(e) for e in _expression_only_args(source)]
target_args = [id(e) for e in _expression_only_args(target)]
@@ -138,7 +147,7 @@ class ChangeDistiller:
return move_edits
- def _compute_matching_set(self):
+ def _compute_matching_set(self) -> t.Set[t.Tuple[int, int]]:
leaves_matching_set = self._compute_leaf_matching_set()
matching_set = leaves_matching_set.copy()
@@ -183,8 +192,8 @@ class ChangeDistiller:
return matching_set
- def _compute_leaf_matching_set(self):
- candidate_matchings = []
+ def _compute_leaf_matching_set(self) -> t.Set[t.Tuple[int, int]]:
+ candidate_matchings: t.List[t.Tuple[float, int, exp.Expression, exp.Expression]] = []
source_leaves = list(_get_leaves(self._source))
target_leaves = list(_get_leaves(self._target))
for source_leaf in source_leaves:
@@ -216,7 +225,7 @@ class ChangeDistiller:
return matching_set
- def _dice_coefficient(self, source, target):
+ def _dice_coefficient(self, source: exp.Expression, target: exp.Expression) -> float:
source_histo = self._bigram_histo(source)
target_histo = self._bigram_histo(target)
@@ -231,13 +240,13 @@ class ChangeDistiller:
return 2 * overlap_len / total_grams
- def _bigram_histo(self, expression):
+ def _bigram_histo(self, expression: exp.Expression) -> t.DefaultDict[str, int]:
if id(expression) in self._bigram_histo_cache:
return self._bigram_histo_cache[id(expression)]
expression_str = self._sql_generator.generate(expression)
count = max(0, len(expression_str) - 1)
- bigram_histo = defaultdict(int)
+ bigram_histo: t.DefaultDict[str, int] = defaultdict(int)
for i in range(count):
bigram_histo[expression_str[i : i + 2]] += 1
@@ -245,7 +254,7 @@ class ChangeDistiller:
return bigram_histo
-def _get_leaves(expression):
+def _get_leaves(expression: exp.Expression) -> t.Generator[exp.Expression, None, None]:
has_child_exprs = False
for a in expression.args.values():
@@ -258,7 +267,7 @@ def _get_leaves(expression):
yield expression
-def _is_same_type(source, target):
+def _is_same_type(source: exp.Expression, target: exp.Expression) -> bool:
if type(source) is type(target):
if isinstance(source, exp.Join):
return source.args.get("side") == target.args.get("side")
@@ -271,15 +280,17 @@ def _is_same_type(source, target):
return False
-def _expression_only_args(expression):
- args = []
+def _expression_only_args(expression: exp.Expression) -> t.List[exp.Expression]:
+ args: t.List[t.Union[exp.Expression, t.List]] = []
if expression:
for a in expression.args.values():
args.extend(ensure_collection(a))
return [a for a in args if isinstance(a, exp.Expression)]
-def _lcs(seq_a, seq_b, equal):
+def _lcs(
+ seq_a: t.Sequence[T], seq_b: t.Sequence[T], equal: t.Callable[[T, T], bool]
+) -> t.Sequence[t.Optional[T]]:
"""Calculates the longest common subsequence"""
len_a = len(seq_a)
@@ -289,14 +300,14 @@ def _lcs(seq_a, seq_b, equal):
for i in range(len_a + 1):
for j in range(len_b + 1):
if i == 0 or j == 0:
- lcs_result[i][j] = []
+ lcs_result[i][j] = [] # type: ignore
elif equal(seq_a[i - 1], seq_b[j - 1]):
- lcs_result[i][j] = lcs_result[i - 1][j - 1] + [seq_a[i - 1]]
+ lcs_result[i][j] = lcs_result[i - 1][j - 1] + [seq_a[i - 1]] # type: ignore
else:
lcs_result[i][j] = (
lcs_result[i - 1][j]
- if len(lcs_result[i - 1][j]) > len(lcs_result[i][j - 1])
+ if len(lcs_result[i - 1][j]) > len(lcs_result[i][j - 1]) # type: ignore
else lcs_result[i][j - 1]
)
- return lcs_result[len_a][len_b]
+ return lcs_result[len_a][len_b] # type: ignore
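The annotations match the public entry point; a minimal usage sketch (example mine):

    from sqlglot import parse_one
    from sqlglot.diff import diff

    edits = diff(parse_one("SELECT a + b"), parse_one("SELECT a - b"))
    # a list of Insert/Remove/Move/Update/Keep edits describing the tree delta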
diff --git a/sqlglot/errors.py b/sqlglot/errors.py
index 2ef908f..23a08bd 100644
--- a/sqlglot/errors.py
+++ b/sqlglot/errors.py
@@ -37,6 +37,10 @@ class SchemaError(SqlglotError):
pass
+class ExecuteError(SqlglotError):
+ pass
+
+
def concat_errors(errors: t.Sequence[t.Any], maximum: int) -> str:
msg = [str(e) for e in errors[:maximum]]
remaining = len(errors) - maximum
diff --git a/sqlglot/executor/__init__.py b/sqlglot/executor/__init__.py
index e765616..04621b5 100644
--- a/sqlglot/executor/__init__.py
+++ b/sqlglot/executor/__init__.py
@@ -1,20 +1,23 @@
import logging
import time
-from sqlglot import parse_one
+from sqlglot import maybe_parse
+from sqlglot.errors import ExecuteError
from sqlglot.executor.python import PythonExecutor
+from sqlglot.executor.table import Table, ensure_tables
from sqlglot.optimizer import optimize
from sqlglot.planner import Plan
+from sqlglot.schema import ensure_schema
logger = logging.getLogger("sqlglot")
-def execute(sql, schema, read=None):
+def execute(sql, schema=None, read=None, tables=None):
"""
Run a sql query against data.
Args:
- sql (str): a sql statement
+ sql (str|sqlglot.Expression): a sql statement
schema (dict|sqlglot.optimizer.Schema): database schema.
This can either be an instance of `sqlglot.optimizer.Schema` or a mapping in one of
the following forms:
@@ -23,10 +26,20 @@ def execute(sql, schema, read=None):
3. {catalog: {db: {table: {col: type}}}}
read (str): the SQL dialect to apply during parsing
(eg. "spark", "hive", "presto", "mysql").
+ tables (dict): additional tables to register.
Returns:
sqlglot.executor.Table: Simple columnar data structure.
"""
- expression = parse_one(sql, read=read)
+ tables = ensure_tables(tables)
+ if not schema:
+ schema = {
+ name: {column: type(table[0][column]).__name__ for column in table.columns}
+ for name, table in tables.mapping.items()
+ }
+ schema = ensure_schema(schema)
+ if tables.supported_table_args and tables.supported_table_args != schema.supported_table_args:
+ raise ExecuteError("Tables must support the same table args as schema")
+ expression = maybe_parse(sql, dialect=read)
now = time.time()
expression = optimize(expression, schema, leave_tables_isolated=True)
logger.debug("Optimization finished: %f", time.time() - now)
@@ -34,6 +47,6 @@ def execute(sql, schema, read=None):
plan = Plan(expression)
logger.debug("Logical Plan: %s", plan)
now = time.time()
- result = PythonExecutor().execute(plan)
+ result = PythonExecutor(tables=tables).execute(plan)
logger.debug("Query finished: %f", time.time() - now)
return result
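Usage sketch for the new `tables` argument (table contents mine, mirroring the updated docstring):

    from sqlglot.executor import execute

    result = execute(
        "SELECT a, SUM(b) AS s FROM t GROUP BY a",
        tables={"t": [{"a": "x", "b": 1}, {"a": "x", "b": 2}]},
    )
    # schema is inferred from the rows; result is a columnar Table, e.g. rows [('x', 3)]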
diff --git a/sqlglot/executor/context.py b/sqlglot/executor/context.py
index 393347b..e9ff75b 100644
--- a/sqlglot/executor/context.py
+++ b/sqlglot/executor/context.py
@@ -1,5 +1,12 @@
+from __future__ import annotations
+
+import typing as t
+
from sqlglot.executor.env import ENV
+if t.TYPE_CHECKING:
+ from sqlglot.executor.table import Table, TableIter
+
class Context:
"""
@@ -12,14 +19,14 @@ class Context:
evaluation of aggregation functions.
"""
- def __init__(self, tables, env=None):
+ def __init__(self, tables: t.Dict[str, Table], env: t.Optional[t.Dict] = None) -> None:
"""
Args
- tables (dict): table_name -> Table, representing the scope of the current execution context
- env (Optional[dict]): dictionary of functions within the execution context
+ tables: representing the scope of the current execution context.
+ env: dictionary of functions within the execution context.
"""
self.tables = tables
- self._table = None
+ self._table: t.Optional[Table] = None
self.range_readers = {name: table.range_reader for name, table in self.tables.items()}
self.row_readers = {name: table.reader for name, table in tables.items()}
self.env = {**(env or {}), "scope": self.row_readers}
@@ -31,7 +38,7 @@ class Context:
return tuple(self.eval(code) for code in codes)
@property
- def table(self):
+ def table(self) -> Table:
if self._table is None:
self._table = list(self.tables.values())[0]
for other in self.tables.values():
@@ -41,8 +48,12 @@ class Context:
raise Exception(f"Rows are different.")
return self._table
+ def add_columns(self, *columns: str) -> None:
+ for table in self.tables.values():
+ table.add_columns(*columns)
+
@property
- def columns(self):
+ def columns(self) -> t.Tuple:
return self.table.columns
def __iter__(self):
@@ -52,35 +63,39 @@ class Context:
reader = table[i]
yield reader, self
- def table_iter(self, table):
+ def table_iter(self, table: str) -> t.Generator[t.Tuple[TableIter, Context], None, None]:
self.env["scope"] = self.row_readers
for reader in self.tables[table]:
yield reader, self
- def sort(self, key):
- table = self.table
+ def filter(self, condition) -> None:
+ rows = [reader.row for reader, _ in self if self.eval(condition)]
- def sort_key(row):
- table.reader.row = row
+ for table in self.tables.values():
+ table.rows = rows
+
+ def sort(self, key) -> None:
+ def sort_key(row: t.Tuple) -> t.Tuple:
+ self.set_row(row)
return self.eval_tuple(key)
- table.rows.sort(key=sort_key)
+ self.table.rows.sort(key=sort_key)
- def set_row(self, row):
+ def set_row(self, row: t.Tuple) -> None:
for table in self.tables.values():
table.reader.row = row
self.env["scope"] = self.row_readers
- def set_index(self, index):
+ def set_index(self, index: int) -> None:
for table in self.tables.values():
table[index]
self.env["scope"] = self.row_readers
- def set_range(self, start, end):
+ def set_range(self, start: int, end: int) -> None:
for name in self.tables:
self.range_readers[name].range = range(start, end)
self.env["scope"] = self.range_readers
- def __contains__(self, table):
+ def __contains__(self, table: str) -> bool:
return table in self.tables
diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py
index bbe6c81..ed80cc9 100644
--- a/sqlglot/executor/env.py
+++ b/sqlglot/executor/env.py
@@ -1,7 +1,10 @@
import datetime
+import inspect
import re
import statistics
+from functools import wraps
+from sqlglot import exp
from sqlglot.helper import PYTHON_VERSION
@@ -16,20 +19,153 @@ class reverse_key:
return other.obj < self.obj
+def filter_nulls(func):
+ @wraps(func)
+ def _func(values):
+ return func(v for v in values if v is not None)
+
+ return _func
+
+
+def null_if_any(*required):
+ """
+ Decorator that makes a function return `None` if any of the `required` arguments are `None`.
+
+ This also supports decoration with no arguments, e.g.:
+
+ @null_if_any
+ def foo(a, b): ...
+
+ In which case all arguments are required.
+ """
+ f = None
+ if len(required) == 1 and callable(required[0]):
+ f = required[0]
+ required = ()
+
+ def decorator(func):
+ if required:
+ required_indices = [
+ i for i, param in enumerate(inspect.signature(func).parameters) if param in required
+ ]
+
+ def predicate(*args):
+ return any(args[i] is None for i in required_indices)
+
+ else:
+
+ def predicate(*args):
+ return any(a is None for a in args)
+
+ @wraps(func)
+ def _func(*args):
+ if predicate(*args):
+ return None
+ return func(*args)
+
+ return _func
+
+ if f:
+ return decorator(f)
+
+ return decorator
+
+
+@null_if_any("substr", "this")
+def str_position(substr, this, position=None):
+ position = position - 1 if position is not None else position
+ return this.find(substr, position) + 1
+
+
+@null_if_any("this")
+def substring(this, start=None, length=None):
+ if start is None:
+ return this
+ elif start == 0:
+ return ""
+ elif start < 0:
+ start = len(this) + start
+ else:
+ start -= 1
+
+ end = None if length is None else start + length
+
+ return this[start:end]
+
+
+@null_if_any
+def cast(this, to):
+ if to == exp.DataType.Type.DATE:
+ return datetime.date.fromisoformat(this)
+ if to == exp.DataType.Type.DATETIME:
+ return datetime.datetime.fromisoformat(this)
+ if to in exp.DataType.TEXT_TYPES:
+ return str(this)
+ if to in {exp.DataType.Type.FLOAT, exp.DataType.Type.DOUBLE}:
+ return float(this)
+ if to in exp.DataType.NUMERIC_TYPES:
+ return int(this)
+ raise NotImplementedError(f"Casting to '{to}' not implemented.")
+
+
+def ordered(this, desc, nulls_first):
+ if desc:
+ return reverse_key(this)
+ return this
+
+
+@null_if_any
+def interval(this, unit):
+ if unit == "DAY":
+ return datetime.timedelta(days=float(this))
+ raise NotImplementedError
+
+
ENV = {
"__builtins__": {},
- "datetime": datetime,
- "locals": locals,
- "re": re,
- "bool": bool,
- "float": float,
- "int": int,
- "str": str,
- "desc": reverse_key,
- "SUM": sum,
- "AVG": statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean, # type: ignore
- "COUNT": lambda acc: sum(1 for e in acc if e is not None),
- "MAX": max,
- "MIN": min,
+ "exp": exp,
+ # aggs
+ "SUM": filter_nulls(sum),
+ "AVG": filter_nulls(statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean), # type: ignore
+ "COUNT": filter_nulls(lambda acc: sum(1 for _ in acc)),
+ "MAX": filter_nulls(max),
+ "MIN": filter_nulls(min),
+ # scalar functions
+ "ABS": null_if_any(lambda this: abs(this)),
+ "ADD": null_if_any(lambda e, this: e + this),
+ "BETWEEN": null_if_any(lambda this, low, high: low <= this and this <= high),
+ "BITWISEAND": null_if_any(lambda this, e: this & e),
+ "BITWISELEFTSHIFT": null_if_any(lambda this, e: this << e),
+ "BITWISEOR": null_if_any(lambda this, e: this | e),
+ "BITWISERIGHTSHIFT": null_if_any(lambda this, e: this >> e),
+ "BITWISEXOR": null_if_any(lambda this, e: this ^ e),
+ "CAST": cast,
+ "COALESCE": lambda *args: next((a for a in args if a is not None), None),
+ "CONCAT": null_if_any(lambda *args: "".join(args)),
+ "CONCATWS": null_if_any(lambda this, *args: this.join(args)),
+ "DIV": null_if_any(lambda e, this: e / this),
+ "EQ": null_if_any(lambda this, e: this == e),
+ "EXTRACT": null_if_any(lambda this, e: getattr(e, this)),
+ "GT": null_if_any(lambda this, e: this > e),
+ "GTE": null_if_any(lambda this, e: this >= e),
+ "IFNULL": lambda e, alt: alt if e is None else e,
+ "IF": lambda predicate, true, false: true if predicate else false,
+ "INTDIV": null_if_any(lambda e, this: e // this),
+ "INTERVAL": interval,
+ "LIKE": null_if_any(
+ lambda this, e: bool(re.match(e.replace("_", ".").replace("%", ".*"), this))
+ ),
+ "LOWER": null_if_any(lambda arg: arg.lower()),
+ "LT": null_if_any(lambda this, e: this < e),
+ "LTE": null_if_any(lambda this, e: this <= e),
+ "MOD": null_if_any(lambda e, this: e % this),
+ "MUL": null_if_any(lambda e, this: e * this),
+ "NEQ": null_if_any(lambda this, e: this != e),
+ "ORD": null_if_any(ord),
+ "ORDERED": ordered,
"POW": pow,
+ "STRPOSITION": str_position,
+ "SUB": null_if_any(lambda e, this: e - this),
+ "SUBSTRING": substring,
+ "UPPER": null_if_any(lambda arg: arg.upper()),
}
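How the two decorators behave, in brief (illustrative only):

    ADD = null_if_any(lambda e, this: e + this)
    ADD(1, None)       # None: any NULL operand nullifies the result
    ADD(1, 2)          # 3

    SUM = filter_nulls(sum)
    SUM([1, None, 2])  # 3: aggregates skip NULLs instead of failing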
diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py
index 7d1db32..cb2543c 100644
--- a/sqlglot/executor/python.py
+++ b/sqlglot/executor/python.py
@@ -5,16 +5,18 @@ import math
from sqlglot import exp, generator, planner, tokens
from sqlglot.dialects.dialect import Dialect, inline_array_sql
+from sqlglot.errors import ExecuteError
from sqlglot.executor.context import Context
from sqlglot.executor.env import ENV
-from sqlglot.executor.table import Table
-from sqlglot.helper import csv_reader
+from sqlglot.executor.table import RowReader, Table
+from sqlglot.helper import csv_reader, subclasses
class PythonExecutor:
- def __init__(self, env=None):
- self.generator = Python().generator(identify=True)
+ def __init__(self, env=None, tables=None):
+ self.generator = Python().generator(identify=True, comments=False)
self.env = {**ENV, **(env or {})}
+ self.tables = tables or {}
def execute(self, plan):
running = set()
@@ -24,36 +26,41 @@ class PythonExecutor:
while queue:
node = queue.pop()
- context = self.context(
- {
- name: table
- for dep in node.dependencies
- for name, table in contexts[dep].tables.items()
- }
- )
- running.add(node)
-
- if isinstance(node, planner.Scan):
- contexts[node] = self.scan(node, context)
- elif isinstance(node, planner.Aggregate):
- contexts[node] = self.aggregate(node, context)
- elif isinstance(node, planner.Join):
- contexts[node] = self.join(node, context)
- elif isinstance(node, planner.Sort):
- contexts[node] = self.sort(node, context)
- else:
- raise NotImplementedError
-
- running.remove(node)
- finished.add(node)
-
- for dep in node.dependents:
- if dep not in running and all(d in contexts for d in dep.dependencies):
- queue.add(dep)
-
- for dep in node.dependencies:
- if all(d in finished for d in dep.dependents):
- contexts.pop(dep)
+ try:
+ context = self.context(
+ {
+ name: table
+ for dep in node.dependencies
+ for name, table in contexts[dep].tables.items()
+ }
+ )
+ running.add(node)
+
+ if isinstance(node, planner.Scan):
+ contexts[node] = self.scan(node, context)
+ elif isinstance(node, planner.Aggregate):
+ contexts[node] = self.aggregate(node, context)
+ elif isinstance(node, planner.Join):
+ contexts[node] = self.join(node, context)
+ elif isinstance(node, planner.Sort):
+ contexts[node] = self.sort(node, context)
+ elif isinstance(node, planner.SetOperation):
+ contexts[node] = self.set_operation(node, context)
+ else:
+ raise NotImplementedError
+
+ running.remove(node)
+ finished.add(node)
+
+ for dep in node.dependents:
+ if dep not in running and all(d in contexts for d in dep.dependencies):
+ queue.add(dep)
+
+ for dep in node.dependencies:
+ if all(d in finished for d in dep.dependents):
+ contexts.pop(dep)
+ except Exception as e:
+ raise ExecuteError(f"Step '{node.id}' failed: {e}") from e
root = plan.root
return contexts[root].tables[root.name]
@@ -76,38 +83,43 @@ class PythonExecutor:
return Context(tables, env=self.env)
def table(self, expressions):
- return Table(expression.alias_or_name for expression in expressions)
+ return Table(
+ expression.alias_or_name if isinstance(expression, exp.Expression) else expression
+ for expression in expressions
+ )
def scan(self, step, context):
source = step.source
- if isinstance(source, exp.Expression):
+ if source and isinstance(source, exp.Expression):
source = source.name or source.alias
condition = self.generate(step.condition)
projections = self.generate_tuple(step.projections)
- if source in context:
+ if source is None:
+ context, table_iter = self.static()
+ elif source in context:
if not projections and not condition:
return self.context({step.name: context.tables[source]})
table_iter = context.table_iter(source)
- else:
+ elif isinstance(step.source, exp.Table) and isinstance(step.source.this, exp.ReadCSV):
table_iter = self.scan_csv(step)
+ context = next(table_iter)
+ else:
+ context, table_iter = self.scan_table(step)
if projections:
sink = self.table(step.projections)
else:
- sink = None
-
- for reader, ctx in table_iter:
- if sink is None:
- sink = Table(reader.columns)
+ sink = self.table(context.columns)
- if condition and not ctx.eval(condition):
+ for reader in table_iter:
+ if condition and not context.eval(condition):
continue
if projections:
- sink.append(ctx.eval_tuple(projections))
+ sink.append(context.eval_tuple(projections))
else:
sink.append(reader.row)
@@ -116,14 +128,23 @@ class PythonExecutor:
return self.context({step.name: sink})
+ def static(self):
+ return self.context({}), [RowReader(())]
+
+ def scan_table(self, step):
+ table = self.tables.find(step.source)
+ context = self.context({step.source.alias_or_name: table})
+ return context, iter(table)
+
def scan_csv(self, step):
- source = step.source
- alias = source.alias
+ alias = step.source.alias
+ source = step.source.this
with csv_reader(source) as reader:
columns = next(reader)
table = Table(columns)
context = self.context({alias: table})
+ yield context
types = []
for row in reader:
@@ -134,7 +155,7 @@ class PythonExecutor:
except (ValueError, SyntaxError):
types.append(str)
context.set_row(tuple(t(v) for t, v in zip(types, row)))
- yield context.table.reader, context
+ yield context.table.reader
def join(self, step, context):
source = step.name
@@ -160,16 +181,19 @@ class PythonExecutor:
for name, column_range in column_ranges.items()
}
)
+ condition = self.generate(join["condition"])
+ if condition:
+ source_context.filter(condition)
condition = self.generate(step.condition)
projections = self.generate_tuple(step.projections)
- if not condition or not projections:
+ if not condition and not projections:
return source_context
sink = self.table(step.projections if projections else source_context.columns)
- for reader, ctx in join_context:
+ for reader, ctx in source_context:
if condition and not ctx.eval(condition):
continue
@@ -181,7 +205,15 @@ class PythonExecutor:
if len(sink) >= step.limit:
break
- return self.context({step.name: sink})
+ if projections:
+ return self.context({step.name: sink})
+ else:
+ return self.context(
+ {
+ name: Table(table.columns, sink.rows, table.column_range)
+ for name, table in source_context.tables.items()
+ }
+ )
def nested_loop_join(self, _join, source_context, join_context):
table = Table(source_context.columns + join_context.columns)
@@ -195,6 +227,8 @@ class PythonExecutor:
def hash_join(self, join, source_context, join_context):
source_key = self.generate_tuple(join["source_key"])
join_key = self.generate_tuple(join["join_key"])
+ left = join.get("side") == "LEFT"
+ right = join.get("side") == "RIGHT"
results = collections.defaultdict(lambda: ([], []))
@@ -204,28 +238,47 @@ class PythonExecutor:
results[ctx.eval_tuple(join_key)][1].append(reader.row)
table = Table(source_context.columns + join_context.columns)
+ nulls = [(None,) * len(join_context.columns if left else source_context.columns)]
for a_group, b_group in results.values():
+ if left:
+ b_group = b_group or nulls
+ elif right:
+ a_group = a_group or nulls
+
for a_row, b_row in itertools.product(a_group, b_group):
table.append(a_row + b_row)
return table
def aggregate(self, step, context):
- source = step.source
- group_by = self.generate_tuple(step.group)
+ group_by = self.generate_tuple(step.group.values())
aggregations = self.generate_tuple(step.aggregations)
operands = self.generate_tuple(step.operands)
if operands:
- source_table = context.tables[source]
- operand_table = Table(source_table.columns + self.table(step.operands).columns)
+ operand_table = Table(self.table(step.operands).columns)
for reader, ctx in context:
- operand_table.append(reader.row + ctx.eval_tuple(operands))
+ operand_table.append(ctx.eval_tuple(operands))
+
+ for i, (a, b) in enumerate(zip(context.table.rows, operand_table.rows)):
+ context.table.rows[i] = a + b
+
+ width = len(context.columns)
+ context.add_columns(*operand_table.columns)
+
+ operand_table = Table(
+ context.columns,
+ context.table.rows,
+ range(width, width + len(operand_table.columns)),
+ )
context = self.context(
- {None: operand_table, **{table: operand_table for table in context.tables}}
+ {
+ None: operand_table,
+ **context.tables,
+ }
)
context.sort(group_by)
@@ -233,25 +286,22 @@ class PythonExecutor:
group = None
start = 0
end = 1
- length = len(context.tables[source])
- table = self.table(step.group + step.aggregations)
+ length = len(context.table)
+ table = self.table(list(step.group) + step.aggregations)
for i in range(length):
context.set_index(i)
key = context.eval_tuple(group_by)
group = key if group is None else group
end += 1
-
+ if key != group:
+ context.set_range(start, end - 2)
+ table.append(group + context.eval_tuple(aggregations))
+ group = key
+ start = end - 2
if i == length - 1:
context.set_range(start, end - 1)
- elif key != group:
- context.set_range(start, end - 2)
- else:
- continue
-
- table.append(group + context.eval_tuple(aggregations))
- group = key
- start = end - 2
+ table.append(group + context.eval_tuple(aggregations))
context = self.context({step.name: table, **{name: table for name in context.tables}})
@@ -262,60 +312,77 @@ class PythonExecutor:
def sort(self, step, context):
projections = self.generate_tuple(step.projections)
- sink = self.table(step.projections)
+ projection_columns = [p.alias_or_name for p in step.projections]
+ all_columns = list(context.columns) + projection_columns
+ sink = self.table(all_columns)
for reader, ctx in context:
- sink.append(ctx.eval_tuple(projections))
+ sink.append(reader.row + ctx.eval_tuple(projections))
- context = self.context(
+ sort_ctx = self.context(
{
None: sink,
**{table: sink for table in context.tables},
}
)
- context.sort(self.generate_tuple(step.key))
+ sort_ctx.sort(self.generate_tuple(step.key))
if not math.isinf(step.limit):
- context.table.rows = context.table.rows[0 : step.limit]
+ sort_ctx.table.rows = sort_ctx.table.rows[0 : step.limit]
- return self.context({step.name: context.table})
+ output = Table(
+ projection_columns,
+ rows=[r[len(context.columns) : len(all_columns)] for r in sort_ctx.table.rows],
+ )
+ return self.context({step.name: output})
+ def set_operation(self, step, context):
+ left = context.tables[step.left]
+ right = context.tables[step.right]
-def _cast_py(self, expression):
- to = expression.args["to"].this
- this = self.sql(expression, "this")
+ sink = self.table(left.columns)
+
+ if issubclass(step.op, exp.Intersect):
+ sink.rows = list(set(left.rows).intersection(set(right.rows)))
+ elif issubclass(step.op, exp.Except):
+ sink.rows = list(set(left.rows).difference(set(right.rows)))
+ elif issubclass(step.op, exp.Union) and step.distinct:
+ sink.rows = list(set(left.rows).union(set(right.rows)))
+ else:
+ sink.rows = left.rows + right.rows
- if to == exp.DataType.Type.DATE:
- return f"datetime.date.fromisoformat({this})"
- if to == exp.DataType.Type.TEXT:
- return f"str({this})"
- raise NotImplementedError
+ return self.context({step.name: sink})
-def _column_py(self, expression):
- table = self.sql(expression, "table") or None
+def _ordered_py(self, expression):
this = self.sql(expression, "this")
- return f"scope[{table}][{this}]"
+ desc = "True" if expression.args.get("desc") else "False"
+ nulls_first = "True" if expression.args.get("nulls_first") else "False"
+ return f"ORDERED({this}, {desc}, {nulls_first})"
-def _interval_py(self, expression):
- this = self.sql(expression, "this")
- unit = expression.text("unit").upper()
- if unit == "DAY":
- return f"datetime.timedelta(days=float({this}))"
- raise NotImplementedError
+def _rename(self, e):
+ try:
+ if "expressions" in e.args:
+ this = self.sql(e, "this")
+ this = f"{this}, " if this else ""
+ return f"{e.key.upper()}({this}{self.expressions(e)})"
+ return f"{e.key.upper()}({self.format_args(*e.args.values())})"
+ except Exception as ex:
+ raise Exception(f"Could not rename {repr(e)}") from ex
-def _like_py(self, expression):
+def _case_sql(self, expression):
this = self.sql(expression, "this")
- expression = self.sql(expression, "expression")
- return f"""bool(re.match({expression}.replace("_", ".").replace("%", ".*"), {this}))"""
+ chain = self.sql(expression, "default") or "None"
+ for e in reversed(expression.args["ifs"]):
+ true = self.sql(e, "true")
+ condition = self.sql(e, "this")
+ condition = f"{this} = ({condition})" if this else condition
+ chain = f"{true} if {condition} else ({chain})"
-def _ordered_py(self, expression):
- this = self.sql(expression, "this")
- desc = expression.args.get("desc")
- return f"desc({this})" if desc else this
+ return chain
class Python(Dialect):
@@ -324,32 +391,22 @@ class Python(Dialect):
class Generator(generator.Generator):
TRANSFORMS = {
+ **{klass: _rename for klass in subclasses(exp.__name__, exp.Binary)},
+ **{klass: _rename for klass in exp.ALL_FUNCTIONS},
+ exp.Case: _case_sql,
exp.Alias: lambda self, e: self.sql(e.this),
exp.Array: inline_array_sql,
exp.And: lambda self, e: self.binary(e, "and"),
+ exp.Between: _rename,
exp.Boolean: lambda self, e: "True" if e.this else "False",
- exp.Cast: _cast_py,
- exp.Column: _column_py,
- exp.EQ: lambda self, e: self.binary(e, "=="),
+ exp.Cast: lambda self, e: f"CAST({self.sql(e.this)}, exp.DataType.Type.{e.args['to']})",
+ exp.Column: lambda self, e: f"scope[{self.sql(e, 'table') or None}][{self.sql(e.this)}]",
+ exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})",
exp.In: lambda self, e: f"{self.sql(e, 'this')} in {self.expressions(e)}",
- exp.Interval: _interval_py,
exp.Is: lambda self, e: self.binary(e, "is"),
- exp.Like: _like_py,
exp.Not: lambda self, e: f"not {self.sql(e.this)}",
exp.Null: lambda *_: "None",
exp.Or: lambda self, e: self.binary(e, "or"),
exp.Ordered: _ordered_py,
exp.Star: lambda *_: "1",
}
-
- def case_sql(self, expression):
- this = self.sql(expression, "this")
- chain = self.sql(expression, "default") or "None"
-
- for e in reversed(expression.args["ifs"]):
- true = self.sql(e, "true")
- condition = self.sql(e, "this")
- condition = f"{this} = ({condition})" if this else condition
- chain = f"{true} if {condition} else ({chain})"
-
- return chain
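
Note: the Python dialect now lowers CASE through the module-level _case_sql,
folding WHEN/THEN branches right-to-left into nested conditional expressions.
A minimal standalone sketch of that folding (hypothetical helper, not part of
the patch):

    def case_to_python(ifs, default="None"):
        # fold (condition, result) pairs from the last WHEN backwards
        chain = default
        for condition, true in reversed(ifs):
            chain = f"{true} if {condition} else ({chain})"
        return chain

    case_to_python([("x > 1", "'a'"), ("x > 0", "'b'")], "'c'")
    # -> "'a' if x > 1 else ('b' if x > 0 else ('c'))"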
diff --git a/sqlglot/executor/table.py b/sqlglot/executor/table.py
index 6796740..f1b5b54 100644
--- a/sqlglot/executor/table.py
+++ b/sqlglot/executor/table.py
@@ -1,14 +1,27 @@
+from __future__ import annotations
+
+from sqlglot.helper import dict_depth
+from sqlglot.schema import AbstractMappingSchema
+
+
class Table:
def __init__(self, columns, rows=None, column_range=None):
self.columns = tuple(columns)
self.column_range = column_range
self.reader = RowReader(self.columns, self.column_range)
-
self.rows = rows or []
if rows:
assert len(rows[0]) == len(self.columns)
self.range_reader = RangeReader(self)
+ def add_columns(self, *columns: str) -> None:
+ self.columns += columns
+ if self.column_range:
+ self.column_range = range(
+ self.column_range.start, self.column_range.stop + len(columns)
+ )
+ self.reader = RowReader(self.columns, self.column_range)
+
def append(self, row):
assert len(row) == len(self.columns)
self.rows.append(row)
@@ -87,3 +100,31 @@ class RowReader:
def __getitem__(self, column):
return self.row[self.columns[column]]
+
+
+class Tables(AbstractMappingSchema[Table]):
+ pass
+
+
+def ensure_tables(d: dict | None) -> Tables:
+ return Tables(_ensure_tables(d))
+
+
+def _ensure_tables(d: dict | None) -> dict:
+ if not d:
+ return {}
+
+ depth = dict_depth(d)
+
+ if depth > 1:
+ return {k: _ensure_tables(v) for k, v in d.items()}
+
+ result = {}
+ for name, table in d.items():
+ if isinstance(table, Table):
+ result[name] = table
+ else:
+ columns = tuple(table[0]) if table else ()
+ rows = [tuple(row[c] for c in columns) for row in table]
+ result[name] = Table(columns=columns, rows=rows)
+ return result
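
Note: ensure_tables normalizes arbitrarily nested dicts (e.g. db -> table ->
rows) into Table objects, using dict_depth from the helper.py changes below.
A sketch of the accepted input shape, assuming list-of-dict rows:

    from sqlglot.executor.table import ensure_tables

    tables = ensure_tables({"db": {"t": [{"a": 1, "b": "x"}, {"a": 2, "b": "y"}]}})
    # conceptually: {"db": {"t": Table(columns=("a", "b"), rows=[(1, "x"), (2, "y")])}}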
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 57a2c88..beafca8 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -641,9 +641,11 @@ class Set(Expression):
class SetItem(Expression):
arg_types = {
- "this": True,
+ "this": False,
+ "expressions": False,
"kind": False,
"collate": False, # MySQL SET NAMES statement
+ "global": False,
}
@@ -787,6 +789,7 @@ class Drop(Expression):
"exists": False,
"temporary": False,
"materialized": False,
+ "cascade": False,
}
@@ -1073,6 +1076,18 @@ class FileFormatProperty(Property):
pass
+class DistKeyProperty(Property):
+ pass
+
+
+class SortKeyProperty(Property):
+ pass
+
+
+class DistStyleProperty(Property):
+ pass
+
+
class LocationProperty(Property):
pass
@@ -1130,6 +1145,9 @@ class Properties(Expression):
"LOCATION": LocationProperty,
"PARTITIONED_BY": PartitionedByProperty,
"TABLE_FORMAT": TableFormatProperty,
+ "DISTKEY": DistKeyProperty,
+ "DISTSTYLE": DistStyleProperty,
+ "SORTKEY": SortKeyProperty,
}
@classmethod
@@ -1356,7 +1374,7 @@ class Var(Expression):
class Schema(Expression):
- arg_types = {"this": False, "expressions": True}
+ arg_types = {"this": False, "expressions": False}
class Select(Subqueryable):
@@ -1741,7 +1759,7 @@ class Select(Subqueryable):
)
if join_alias:
- join.set("this", alias_(join.args["this"], join_alias, table=True))
+ join.set("this", alias_(join.this, join_alias, table=True))
return _apply_list_builder(
join,
instance=self,
@@ -1884,6 +1902,7 @@ class Subquery(DerivedTable, Unionable):
arg_types = {
"this": True,
"alias": False,
+ "with": False,
**QUERY_MODIFIERS,
}
@@ -2025,6 +2044,31 @@ class DataType(Expression):
NULL = auto()
UNKNOWN = auto() # Sentinel value, useful for type annotation
+ TEXT_TYPES = {
+ Type.CHAR,
+ Type.NCHAR,
+ Type.VARCHAR,
+ Type.NVARCHAR,
+ Type.TEXT,
+ }
+
+ NUMERIC_TYPES = {
+ Type.INT,
+ Type.TINYINT,
+ Type.SMALLINT,
+ Type.BIGINT,
+ Type.FLOAT,
+ Type.DOUBLE,
+ }
+
+ TEMPORAL_TYPES = {
+ Type.TIMESTAMP,
+ Type.TIMESTAMPTZ,
+ Type.TIMESTAMPLTZ,
+ Type.DATE,
+ Type.DATETIME,
+ }
+
@classmethod
def build(cls, dtype, **kwargs) -> DataType:
return DataType(
@@ -2054,16 +2098,25 @@ class Exists(SubqueryPredicate):
pass
-# Commands to interact with the databases or engines
-# These expressions don't truly parse the expression and consume
-# whatever exists as a string until the end or a semicolon
+# Commands to interact with databases or engines. For most command
+# expressions, we parse whatever comes after the command's name as a string.
class Command(Expression):
arg_types = {"this": True, "expression": False}
-# Binary Expressions
-# (ADD a b)
-# (FROM table selects)
+class Transaction(Command):
+ arg_types = {"this": False, "modes": False}
+
+
+class Commit(Command):
+ arg_types = {} # type: ignore
+
+
+class Rollback(Command):
+ arg_types = {"savepoint": False}
+
+
+# Binary expressions like (ADD a b)
class Binary(Expression):
arg_types = {"this": True, "expression": True}
@@ -2215,7 +2268,7 @@ class Not(Unary, Condition):
class Paren(Unary, Condition):
- pass
+ arg_types = {"this": True, "with": False}
class Neg(Unary):
@@ -2428,6 +2481,10 @@ class Cast(Func):
return self.args["to"]
+class Collate(Binary):
+ pass
+
+
class TryCast(Cast):
pass
@@ -2442,13 +2499,17 @@ class Coalesce(Func):
is_var_len_args = True
-class ConcatWs(Func):
- arg_types = {"expressions": False}
+class Concat(Func):
+ arg_types = {"expressions": True}
is_var_len_args = True
+class ConcatWs(Concat):
+ _sql_names = ["CONCAT_WS"]
+
+
class Count(AggFunc):
- pass
+ arg_types = {"this": False}
class CurrentDate(Func):
@@ -2556,10 +2617,18 @@ class Day(Func):
pass
+class Decode(Func):
+ arg_types = {"this": True, "charset": True}
+
+
class DiToDate(Func):
pass
+class Encode(Func):
+ arg_types = {"this": True, "charset": True}
+
+
class Exp(Func):
pass
@@ -2581,6 +2650,10 @@ class GroupConcat(Func):
arg_types = {"this": True, "separator": False}
+class Hex(Func):
+ pass
+
+
class If(Func):
arg_types = {"this": True, "true": True, "false": False}
@@ -2641,7 +2714,7 @@ class Log10(Func):
class Lower(Func):
- pass
+ _sql_names = ["LOWER", "LCASE"]
class Map(Func):
@@ -2686,6 +2759,12 @@ class ApproxQuantile(Quantile):
arg_types = {"this": True, "quantile": True, "accuracy": False}
+class ReadCSV(Func):
+ _sql_names = ["READ_CSV"]
+ is_var_len_args = True
+ arg_types = {"this": True, "expressions": False}
+
+
class Reduce(Func):
arg_types = {"this": True, "initial": True, "merge": True, "finish": True}
@@ -2804,8 +2883,8 @@ class TimeStrToUnix(Func):
class Trim(Func):
arg_types = {
"this": True,
- "position": False,
"expression": False,
+ "position": False,
"collation": False,
}
@@ -2826,6 +2905,10 @@ class TsOrDiToDi(Func):
pass
+class Unhex(Func):
+ pass
+
+
class UnixToStr(Func):
arg_types = {"this": True, "format": False}
@@ -2843,7 +2926,7 @@ class UnixToTimeStr(Func):
class Upper(Func):
- pass
+ _sql_names = ["UPPER", "UCASE"]
class Variance(AggFunc):
@@ -3701,6 +3784,19 @@ def replace_placeholders(expression, *args, **kwargs):
return expression.transform(_replace_placeholders, iter(args), **kwargs)
+def true():
+ return Boolean(this=True)
+
+
+def false():
+ return Boolean(this=False)
+
+
+def null():
+ return Null()
+
+
+# TODO: deprecate this
TRUE = Boolean(this=True)
FALSE = Boolean(this=False)
NULL = Null()
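
Note: exp.TRUE/FALSE/NULL are shared singletons, so a pass that mutates one in
place can leak state across unrelated trees; the new factories hand out fresh
nodes instead (hence the sweeping exp.TRUE -> exp.true() changes in the
optimizer below). A quick sketch, assuming Expression's structural equality:

    from sqlglot import exp

    a, b = exp.true(), exp.true()
    assert a == b       # structurally equal Boolean nodes
    assert a is not b   # distinct objects, unlike the shared exp.TRUE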
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index 11d9073..ffb34eb 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -67,7 +67,7 @@ class Generator:
exp.LocationProperty: lambda self, e: self.naked_property(e),
exp.ReturnsProperty: lambda self, e: self.naked_property(e),
exp.ExecuteAsProperty: lambda self, e: self.naked_property(e),
- exp.VolatilityProperty: lambda self, e: self.sql(e.name),
+ exp.VolatilityProperty: lambda self, e: e.name,
}
# Whether 'CREATE ... TRANSIENT ... TABLE' is allowed
@@ -94,6 +94,9 @@ class Generator:
ROOT_PROPERTIES = {
exp.ReturnsProperty,
exp.LanguageProperty,
+ exp.DistStyleProperty,
+ exp.DistKeyProperty,
+ exp.SortKeyProperty,
}
WITH_PROPERTIES = {
@@ -241,7 +244,7 @@ class Generator:
if not NEWLINE_RE.search(comment):
return f"{sql} --{comment.rstrip()}" if single_line else f"{sql} /*{comment}*/"
- return f"/*{comment}*/\n{sql}"
+ return f"/*{comment}*/\n{sql}" if sql else f" /*{comment}*/"
def wrap(self, expression):
this_sql = self.indent(
@@ -475,7 +478,8 @@ class Generator:
exists_sql = " IF EXISTS " if expression.args.get("exists") else " "
temporary = " TEMPORARY" if expression.args.get("temporary") else ""
materialized = " MATERIALIZED" if expression.args.get("materialized") else ""
- return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}"
+ cascade = " CASCADE" if expression.args.get("cascade") else ""
+ return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{cascade}"
def except_sql(self, expression):
return self.prepend_ctes(
@@ -915,13 +919,15 @@ class Generator:
def subquery_sql(self, expression):
alias = self.sql(expression, "alias")
- return self.query_modifiers(
+ sql = self.query_modifiers(
expression,
self.wrap(expression),
self.expressions(expression, key="pivots", sep=" "),
f" AS {alias}" if alias else "",
)
+ return self.prepend_ctes(expression, sql)
+
def qualify_sql(self, expression):
this = self.indent(self.sql(expression, "this"))
return f"{self.seg('QUALIFY')}{self.sep()}{this}"
@@ -1111,9 +1117,12 @@ class Generator:
def paren_sql(self, expression):
if isinstance(expression.unnest(), exp.Select):
- return self.wrap(expression)
- sql = self.seg(self.indent(self.sql(expression, "this")), sep="")
- return f"({sql}{self.seg(')', sep='')}"
+ sql = self.wrap(expression)
+ else:
+ sql = self.seg(self.indent(self.sql(expression, "this")), sep="")
+ sql = f"({sql}{self.seg(')', sep='')}"
+
+ return self.prepend_ctes(expression, sql)
def neg_sql(self, expression):
return f"-{self.sql(expression, 'this')}"
@@ -1173,9 +1182,23 @@ class Generator:
zone = self.sql(expression, "this")
return f"CURRENT_DATE({zone})" if zone else "CURRENT_DATE"
+ def collate_sql(self, expression):
+ return self.binary(expression, "COLLATE")
+
def command_sql(self, expression):
return f"{self.sql(expression, 'this').upper()} {expression.text('expression').strip()}"
+ def transaction_sql(self, *_):
+ return "BEGIN"
+
+ def commit_sql(self, *_):
+ return "COMMIT"
+
+ def rollback_sql(self, expression):
+ savepoint = expression.args.get("savepoint")
+ savepoint = f" TO {savepoint}" if savepoint else ""
+ return f"ROLLBACK{savepoint}"
+
def distinct_sql(self, expression):
this = self.expressions(expression, flat=True)
this = f" {this}" if this else ""
@@ -1193,10 +1216,7 @@ class Generator:
def intdiv_sql(self, expression):
return self.sql(
exp.Cast(
- this=exp.Div(
- this=expression.args["this"],
- expression=expression.args["expression"],
- ),
+ this=exp.Div(this=expression.this, expression=expression.expression),
to=exp.DataType(this=exp.DataType.Type.INT),
)
)
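
Note: together with the parser changes below, the new transaction_sql,
commit_sql and rollback_sql hooks should round-trip transaction statements
roughly as follows (outputs assume the default dialect):

    import sqlglot

    sqlglot.transpile("BEGIN IMMEDIATE TRANSACTION")  # ["BEGIN"]
    sqlglot.transpile("COMMIT WORK")                  # ["COMMIT"]
    sqlglot.transpile("ROLLBACK TO SAVEPOINT s1")     # ["ROLLBACK TO s1"]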
diff --git a/sqlglot/helper.py b/sqlglot/helper.py
index 379c2e7..8c5808d 100644
--- a/sqlglot/helper.py
+++ b/sqlglot/helper.py
@@ -11,7 +11,8 @@ from copy import copy
from enum import Enum
if t.TYPE_CHECKING:
- from sqlglot.expressions import Expression, Table
+ from sqlglot import exp
+ from sqlglot.expressions import Expression
T = t.TypeVar("T")
E = t.TypeVar("E", bound=Expression)
@@ -150,7 +151,7 @@ def apply_index_offset(expressions: t.List[E], offset: int) -> t.List[E]:
if expression.is_int:
expression = expression.copy()
logger.warning("Applying array index offset (%s)", offset)
- expression.args["this"] = str(int(expression.args["this"]) + offset)
+ expression.args["this"] = str(int(expression.this) + offset)
return [expression]
return expressions
@@ -228,19 +229,18 @@ def open_file(file_name: str) -> t.TextIO:
@contextmanager
-def csv_reader(table: Table) -> t.Any:
+def csv_reader(read_csv: exp.ReadCSV) -> t.Any:
"""
Returns a csv reader given the expression `READ_CSV(name, ['delimiter', '|', ...])`.
Args:
- table: a `Table` expression with an anonymous function `READ_CSV` in it.
+ read_csv: a `ReadCSV` function call
Yields:
A Python csv reader.
"""
- file, *args = table.this.expressions
- file = file.name
- file = open_file(file)
+ args = read_csv.expressions
+ file = open_file(read_csv.name)
delimiter = ","
args = iter(arg.name for arg in args)
@@ -354,3 +354,34 @@ def flatten(values: t.Iterable[t.Iterable[t.Any] | t.Any]) -> t.Generator[t.Any,
yield from flatten(value)
else:
yield value
+
+
+def dict_depth(d: t.Dict) -> int:
+ """
+ Get the nesting depth of a dictionary.
+
+ For example:
+ >>> dict_depth(None)
+ 0
+ >>> dict_depth({})
+ 1
+ >>> dict_depth({"a": "b"})
+ 1
+ >>> dict_depth({"a": {}})
+ 2
+ >>> dict_depth({"a": {"b": {}}})
+ 3
+
+ Args:
+ d (dict): dictionary
+ Returns:
+ int: depth
+ """
+ try:
+ return 1 + dict_depth(next(iter(d.values())))
+ except AttributeError:
+ # d doesn't have attribute "values"
+ return 0
+ except StopIteration:
+ # d.values() returns an empty sequence
+ return 1
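
Note: csv_reader now takes the ReadCSV node directly rather than the enclosing
Table expression. A usage sketch, assuming a local people.csv exists:

    from sqlglot import exp, parse_one
    from sqlglot.helper import csv_reader

    query = parse_one("SELECT * FROM READ_CSV('people.csv', 'delimiter', '|')")
    with csv_reader(query.find(exp.ReadCSV)) as reader:
        header = next(reader)  # first row of the file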
diff --git a/sqlglot/optimizer/annotate_types.py b/sqlglot/optimizer/annotate_types.py
index 96331e2..191ea52 100644
--- a/sqlglot/optimizer/annotate_types.py
+++ b/sqlglot/optimizer/annotate_types.py
@@ -245,23 +245,31 @@ class TypeAnnotator:
def annotate(self, expression):
if isinstance(expression, self.TRAVERSABLES):
for scope in traverse_scope(expression):
- subscope_selects = {
- name: {select.alias_or_name: select for select in source.selects}
- for name, source in scope.sources.items()
- if isinstance(source, Scope)
- }
-
+ selects = {}
+ for name, source in scope.sources.items():
+ if not isinstance(source, Scope):
+ continue
+ if isinstance(source.expression, exp.Values):
+ selects[name] = {
+ alias: column
+ for alias, column in zip(
+ source.expression.alias_column_names,
+ source.expression.expressions[0].expressions,
+ )
+ }
+ else:
+ selects[name] = {
+ select.alias_or_name: select for select in source.expression.selects
+ }
# First annotate the current scope's column references
for col in scope.columns:
source = scope.sources[col.table]
if isinstance(source, exp.Table):
col.type = self.schema.get_column_type(source, col)
else:
- col.type = subscope_selects[col.table][col.name].type
-
+ col.type = selects[col.table][col.name].type
# Then (possibly) annotate the remaining expressions in the scope
self._maybe_annotate(scope.expression)
-
return self._maybe_annotate(expression) # This takes care of non-traversable expressions
def _maybe_annotate(self, expression):
diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py
new file mode 100644
index 0000000..9b3d98a
--- /dev/null
+++ b/sqlglot/optimizer/canonicalize.py
@@ -0,0 +1,48 @@
+import itertools
+
+from sqlglot import exp
+
+
+def canonicalize(expression: exp.Expression) -> exp.Expression:
+ """Converts a sql expression into a standard form.
+
+ This method relies on annotate_types because many of the
+ conversions rely on type inference.
+
+ Args:
+ expression: The expression to canonicalize.
+ """
+ exp.replace_children(expression, canonicalize)
+ expression = add_text_to_concat(expression)
+ expression = coerce_type(expression)
+ return expression
+
+
+def add_text_to_concat(node: exp.Expression) -> exp.Expression:
+ if isinstance(node, exp.Add) and node.type in exp.DataType.TEXT_TYPES:
+ node = exp.Concat(this=node.this, expression=node.expression)
+ return node
+
+
+def coerce_type(node: exp.Expression) -> exp.Expression:
+ if isinstance(node, exp.Binary):
+ _coerce_date(node.left, node.right)
+ elif isinstance(node, exp.Between):
+ _coerce_date(node.this, node.args["low"])
+ elif isinstance(node, exp.Extract):
+ if node.expression.type not in exp.DataType.TEMPORAL_TYPES:
+ _replace_cast(node.expression, "datetime")
+ return node
+
+
+def _coerce_date(a: exp.Expression, b: exp.Expression) -> None:
+ for a, b in itertools.permutations([a, b]):
+ if a.type == exp.DataType.Type.DATE and b.type != exp.DataType.Type.DATE:
+ _replace_cast(b, "date")
+
+
+def _replace_cast(node: exp.Expression, to: str) -> None:
+ data_type = exp.DataType.build(to)
+ cast = exp.Cast(this=node.copy(), to=data_type)
+ cast.type = data_type
+ node.replace(cast)
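
Note: canonicalize only does useful work on a typed tree, which is why
annotate_types is scheduled immediately before it in optimizer.py below. A
sketch of the text-concat rewrite, assuming the annotator infers TEXT for both
columns:

    import sqlglot
    from sqlglot.optimizer.annotate_types import annotate_types
    from sqlglot.optimizer.canonicalize import canonicalize

    typed = annotate_types(
        sqlglot.parse_one("SELECT t.a + t.b FROM t"),
        schema={"t": {"a": "TEXT", "b": "TEXT"}},
    )
    canonicalize(typed).sql()  # 'SELECT CONCAT(t.a, t.b) FROM t'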
diff --git a/sqlglot/optimizer/eliminate_joins.py b/sqlglot/optimizer/eliminate_joins.py
index 29621af..de4e011 100644
--- a/sqlglot/optimizer/eliminate_joins.py
+++ b/sqlglot/optimizer/eliminate_joins.py
@@ -128,8 +128,8 @@ def join_condition(join):
Tuple of (source key, join key, remaining predicate)
"""
name = join.this.alias_or_name
- on = join.args.get("on") or exp.TRUE
- on = on.copy()
+ on = (join.args.get("on") or exp.true()).copy()
+ on = on if isinstance(on, exp.And) else exp.and_(on, exp.true())
source_key = []
join_key = []
@@ -141,7 +141,7 @@ def join_condition(join):
#
# should pull y.b as the join key and x.a as the source key
if normalized(on):
- for condition in on.flatten() if isinstance(on, exp.And) else [on]:
+ for condition in on.flatten():
if isinstance(condition, exp.EQ):
left, right = condition.unnest_operands()
left_tables = exp.column_table_names(left)
@@ -150,13 +150,12 @@ def join_condition(join):
if name in left_tables and name not in right_tables:
join_key.append(left)
source_key.append(right)
- condition.replace(exp.TRUE)
+ condition.replace(exp.true())
elif name in right_tables and name not in left_tables:
join_key.append(right)
source_key.append(left)
- condition.replace(exp.TRUE)
+ condition.replace(exp.true())
on = simplify(on)
- remaining_condition = None if on == exp.TRUE else on
-
+ remaining_condition = None if on == exp.true() else on
return source_key, join_key, remaining_condition
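
Note: join_condition now pads the ON clause into a conjunction up front, so the
flatten() loop no longer needs an isinstance special case. A usage sketch
(expression reprs abbreviated):

    from sqlglot import parse_one
    from sqlglot.optimizer.eliminate_joins import join_condition

    join = parse_one("SELECT * FROM x JOIN y ON y.a = x.a AND y.b > 1").args["joins"][0]
    source_key, join_key, remaining = join_condition(join)
    # source_key -> [x.a], join_key -> [y.a], remaining -> y.b > 1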
diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py
index 40e4ab1..fd69832 100644
--- a/sqlglot/optimizer/optimize_joins.py
+++ b/sqlglot/optimizer/optimize_joins.py
@@ -29,7 +29,7 @@ def optimize_joins(expression):
if isinstance(on, exp.Connector):
for predicate in on.flatten():
if name in exp.column_table_names(predicate):
- predicate.replace(exp.TRUE)
+ predicate.replace(exp.true())
join.on(predicate, copy=False)
expression = reorder_joins(expression)
@@ -70,6 +70,6 @@ def normalize(expression):
def other_table_names(join, exclude):
return [
name
- for name in (exp.column_table_names(join.args.get("on") or exp.TRUE))
+ for name in (exp.column_table_names(join.args.get("on") or exp.true()))
if name != exclude
]
diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py
index b2ed062..d0e38cd 100644
--- a/sqlglot/optimizer/optimizer.py
+++ b/sqlglot/optimizer/optimizer.py
@@ -1,4 +1,6 @@
import sqlglot
+from sqlglot.optimizer.annotate_types import annotate_types
+from sqlglot.optimizer.canonicalize import canonicalize
from sqlglot.optimizer.eliminate_ctes import eliminate_ctes
from sqlglot.optimizer.eliminate_joins import eliminate_joins
from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
@@ -28,6 +30,8 @@ RULES = (
merge_subqueries,
eliminate_joins,
eliminate_ctes,
+ annotate_types,
+ canonicalize,
quote_identities,
)
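
Note: annotate_types and canonicalize join the rule list just before
quote_identities, so type-dependent rewrites now run in the default pipeline.
A sketch (unqualified columns are fine here because qualification runs first):

    import sqlglot
    from sqlglot.optimizer import optimize

    optimized = optimize(
        sqlglot.parse_one("SELECT a + b FROM t"),
        schema={"t": {"a": "TEXT", "b": "TEXT"}},
    )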
diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py
index 6364f65..f92e5c3 100644
--- a/sqlglot/optimizer/pushdown_predicates.py
+++ b/sqlglot/optimizer/pushdown_predicates.py
@@ -64,11 +64,11 @@ def pushdown_cnf(predicates, scope, scope_ref_count):
for predicate in predicates:
for node in nodes_for_predicate(predicate, scope, scope_ref_count).values():
if isinstance(node, exp.Join):
- predicate.replace(exp.TRUE)
+ predicate.replace(exp.true())
node.on(predicate, copy=False)
break
if isinstance(node, exp.Select):
- predicate.replace(exp.TRUE)
+ predicate.replace(exp.true())
node.where(replace_aliases(node, predicate), copy=False)
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index 69fe2b8..e6e6dc9 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -382,9 +382,7 @@ class _Resolver:
raise OptimizeError(str(e)) from e
if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
- values_alias = source.expression.parent
- if hasattr(values_alias, "alias_column_names"):
- return values_alias.alias_column_names
+ return source.expression.alias_column_names
# Otherwise, if referencing another scope, return that scope's named selects
return source.expression.named_selects
diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py
index 0e467d3..5d8e0d9 100644
--- a/sqlglot/optimizer/qualify_tables.py
+++ b/sqlglot/optimizer/qualify_tables.py
@@ -1,10 +1,11 @@
import itertools
from sqlglot import alias, exp
+from sqlglot.helper import csv_reader
from sqlglot.optimizer.scope import traverse_scope
-def qualify_tables(expression, db=None, catalog=None):
+def qualify_tables(expression, db=None, catalog=None, schema=None):
"""
Rewrite sqlglot AST to have fully qualified tables.
@@ -18,6 +19,7 @@ def qualify_tables(expression, db=None, catalog=None):
expression (sqlglot.Expression): expression to qualify
db (str): Database name
catalog (str): Catalog name
+ schema: A schema to populate
Returns:
sqlglot.Expression: qualified expression
"""
@@ -41,7 +43,7 @@ def qualify_tables(expression, db=None, catalog=None):
source.set("catalog", exp.to_identifier(catalog))
if not source.alias:
- source.replace(
+ source = source.replace(
alias(
source.copy(),
source.this if identifier else f"_q_{next(sequence)}",
@@ -49,4 +51,12 @@ def qualify_tables(expression, db=None, catalog=None):
)
)
+ if schema and isinstance(source.this, exp.ReadCSV):
+ with csv_reader(source.this) as reader:
+ header = next(reader)
+ columns = next(reader)
+ schema.add_table(
+ source, {k: type(v).__name__ for k, v in zip(header, columns)}
+ )
+
return expression
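
Note: when handed a schema, qualify_tables now peeks at the header plus first
data row of a READ_CSV source and registers inferred column types (Python type
names, so typically "str"). A sketch, assuming a local people.csv with a header
row and a MappingSchema as rewritten in this version's schema.py:

    from sqlglot import parse_one
    from sqlglot.optimizer.qualify_tables import qualify_tables
    from sqlglot.schema import MappingSchema

    schema = MappingSchema()
    qualify_tables(parse_one("SELECT * FROM READ_CSV('people.csv')"), schema=schema)
    # schema now maps the generated alias to {column_name: python_type_name}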
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index d759e86..c432c59 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -189,11 +189,11 @@ def absorb_and_eliminate(expression):
# absorb
if is_complement(b, aa):
- aa.replace(exp.TRUE if kind == exp.And else exp.FALSE)
+ aa.replace(exp.true() if kind == exp.And else exp.false())
elif is_complement(b, ab):
- ab.replace(exp.TRUE if kind == exp.And else exp.FALSE)
+ ab.replace(exp.true() if kind == exp.And else exp.false())
elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
- a.replace(exp.FALSE if kind == exp.And else exp.TRUE)
+ a.replace(exp.false() if kind == exp.And else exp.true())
elif isinstance(b, kind):
# eliminate
rhs = b.unnest_operands()
diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py
index f41a84e..dbd680b 100644
--- a/sqlglot/optimizer/unnest_subqueries.py
+++ b/sqlglot/optimizer/unnest_subqueries.py
@@ -169,7 +169,7 @@ def decorrelate(select, parent_select, external_columns, sequence):
select.parent.replace(alias)
for key, column, predicate in keys:
- predicate.replace(exp.TRUE)
+ predicate.replace(exp.true())
nested = exp.column(key_aliases[key], table_alias)
if key in group_by:
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index bbea0e5..5b93510 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -141,26 +141,29 @@ class Parser(metaclass=_Parser):
ID_VAR_TOKENS = {
TokenType.VAR,
- TokenType.ALTER,
TokenType.ALWAYS,
TokenType.ANTI,
TokenType.APPLY,
+ TokenType.AUTO_INCREMENT,
TokenType.BEGIN,
TokenType.BOTH,
TokenType.BUCKET,
TokenType.CACHE,
- TokenType.CALL,
+ TokenType.CASCADE,
TokenType.COLLATE,
+ TokenType.COMMAND,
TokenType.COMMIT,
TokenType.CONSTRAINT,
+ TokenType.CURRENT_TIME,
TokenType.DEFAULT,
TokenType.DELETE,
TokenType.DESCRIBE,
TokenType.DETERMINISTIC,
+ TokenType.DISTKEY,
+ TokenType.DISTSTYLE,
TokenType.EXECUTE,
TokenType.ENGINE,
TokenType.ESCAPE,
- TokenType.EXPLAIN,
TokenType.FALSE,
TokenType.FIRST,
TokenType.FOLLOWING,
@@ -182,7 +185,6 @@ class Parser(metaclass=_Parser):
TokenType.NATURAL,
TokenType.NEXT,
TokenType.ONLY,
- TokenType.OPTIMIZE,
TokenType.OPTIONS,
TokenType.ORDINALITY,
TokenType.PARTITIONED_BY,
@@ -199,6 +201,7 @@ class Parser(metaclass=_Parser):
TokenType.SEMI,
TokenType.SET,
TokenType.SHOW,
+ TokenType.SORTKEY,
TokenType.STABLE,
TokenType.STORED,
TokenType.TABLE,
@@ -207,7 +210,6 @@ class Parser(metaclass=_Parser):
TokenType.TRANSIENT,
TokenType.TOP,
TokenType.TRAILING,
- TokenType.TRUNCATE,
TokenType.TRUE,
TokenType.UNBOUNDED,
TokenType.UNIQUE,
@@ -217,6 +219,7 @@ class Parser(metaclass=_Parser):
TokenType.VOLATILE,
*SUBQUERY_PREDICATES,
*TYPE_TOKENS,
+ *NO_PAREN_FUNCTIONS,
}
TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.NATURAL, TokenType.APPLY}
@@ -231,6 +234,7 @@ class Parser(metaclass=_Parser):
TokenType.FILTER,
TokenType.FIRST,
TokenType.FORMAT,
+ TokenType.IDENTIFIER,
TokenType.ISNULL,
TokenType.OFFSET,
TokenType.PRIMARY_KEY,
@@ -242,6 +246,7 @@ class Parser(metaclass=_Parser):
TokenType.RIGHT,
TokenType.DATE,
TokenType.DATETIME,
+ TokenType.TABLE,
TokenType.TIMESTAMP,
TokenType.TIMESTAMPTZ,
*TYPE_TOKENS,
@@ -277,6 +282,7 @@ class Parser(metaclass=_Parser):
TokenType.DASH: exp.Sub,
TokenType.PLUS: exp.Add,
TokenType.MOD: exp.Mod,
+ TokenType.COLLATE: exp.Collate,
}
FACTOR = {
@@ -391,7 +397,10 @@ class Parser(metaclass=_Parser):
TokenType.DELETE: lambda self: self._parse_delete(),
TokenType.CACHE: lambda self: self._parse_cache(),
TokenType.UNCACHE: lambda self: self._parse_uncache(),
- TokenType.USE: lambda self: self._parse_use(),
+ TokenType.USE: lambda self: self.expression(exp.Use, this=self._parse_id_var()),
+ TokenType.BEGIN: lambda self: self._parse_transaction(),
+ TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(),
+ TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(),
}
PRIMARY_PARSERS = {
@@ -402,7 +411,8 @@ class Parser(metaclass=_Parser):
exp.Literal, this=token.text, is_string=False
),
TokenType.STAR: lambda self, _: self.expression(
- exp.Star, **{"except": self._parse_except(), "replace": self._parse_replace()}
+ exp.Star,
+ **{"except": self._parse_except(), "replace": self._parse_replace()},
),
TokenType.NULL: lambda self, _: self.expression(exp.Null),
TokenType.TRUE: lambda self, _: self.expression(exp.Boolean, this=True),
@@ -446,6 +456,9 @@ class Parser(metaclass=_Parser):
TokenType.PARTITIONED_BY: lambda self: self._parse_partitioned_by(),
TokenType.SCHEMA_COMMENT: lambda self: self._parse_schema_comment(),
TokenType.STORED: lambda self: self._parse_stored(),
+ TokenType.DISTKEY: lambda self: self._parse_distkey(),
+ TokenType.DISTSTYLE: lambda self: self._parse_diststyle(),
+ TokenType.SORTKEY: lambda self: self._parse_sortkey(),
TokenType.RETURNS: lambda self: self._parse_returns(),
TokenType.COLLATE: lambda self: self._parse_property_assignment(exp.CollateProperty),
TokenType.COMMENT: lambda self: self._parse_property_assignment(exp.SchemaCommentProperty),
@@ -471,7 +484,9 @@ class Parser(metaclass=_Parser):
}
CONSTRAINT_PARSERS = {
- TokenType.CHECK: lambda self: self._parse_check(),
+ TokenType.CHECK: lambda self: self.expression(
+ exp.Check, this=self._parse_wrapped(self._parse_conjunction)
+ ),
TokenType.FOREIGN_KEY: lambda self: self._parse_foreign_key(),
TokenType.UNIQUE: lambda self: self._parse_unique(),
}
@@ -521,6 +536,8 @@ class Parser(metaclass=_Parser):
TokenType.SCHEMA,
}
+ TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"}
+
STRICT_CAST = True
__slots__ = (
@@ -740,6 +757,7 @@ class Parser(metaclass=_Parser):
kind=kind,
temporary=temporary,
materialized=materialized,
+ cascade=self._match(TokenType.CASCADE),
)
def _parse_exists(self, not_=False):
@@ -777,7 +795,11 @@ class Parser(metaclass=_Parser):
expression = self._parse_select_or_expression()
elif create_token.token_type == TokenType.INDEX:
this = self._parse_index()
- elif create_token.token_type in (TokenType.TABLE, TokenType.VIEW, TokenType.SCHEMA):
+ elif create_token.token_type in (
+ TokenType.TABLE,
+ TokenType.VIEW,
+ TokenType.SCHEMA,
+ ):
this = self._parse_table(schema=True)
properties = self._parse_properties()
if self._match(TokenType.ALIAS):
@@ -834,7 +856,38 @@ class Parser(metaclass=_Parser):
return self.expression(
exp.FileFormatProperty,
this=exp.Literal.string("FORMAT"),
- value=exp.Literal.string(self._parse_var().name),
+ value=exp.Literal.string(self._parse_var_or_string().name),
+ )
+
+ def _parse_distkey(self):
+ self._match_l_paren()
+ this = exp.Literal.string("DISTKEY")
+ value = exp.Literal.string(self._parse_var().name)
+ self._match_r_paren()
+ return self.expression(
+ exp.DistKeyProperty,
+ this=this,
+ value=value,
+ )
+
+ def _parse_sortkey(self):
+ self._match_l_paren()
+ this = exp.Literal.string("SORTKEY")
+ value = exp.Literal.string(self._parse_var().name)
+ self._match_r_paren()
+ return self.expression(
+ exp.SortKeyProperty,
+ this=this,
+ value=value,
+ )
+
+ def _parse_diststyle(self):
+ this = exp.Literal.string("DISTSTYLE")
+ value = exp.Literal.string(self._parse_var().name)
+ return self.expression(
+ exp.DistStyleProperty,
+ this=this,
+ value=value,
)
def _parse_auto_increment(self):
@@ -842,7 +895,7 @@ class Parser(metaclass=_Parser):
return self.expression(
exp.AutoIncrementProperty,
this=exp.Literal.string("AUTO_INCREMENT"),
- value=self._parse_var() or self._parse_number(),
+ value=self._parse_number(),
)
def _parse_schema_comment(self):
@@ -898,13 +951,10 @@ class Parser(metaclass=_Parser):
while True:
if self._match(TokenType.WITH):
- self._match_l_paren()
- properties.extend(self._parse_csv(lambda: self._parse_property()))
- self._match_r_paren()
+ properties.extend(self._parse_wrapped_csv(self._parse_property))
elif self._match(TokenType.PROPERTIES):
- self._match_l_paren()
properties.extend(
- self._parse_csv(
+ self._parse_wrapped_csv(
lambda: self.expression(
exp.AnonymousProperty,
this=self._parse_string(),
@@ -912,25 +962,24 @@ class Parser(metaclass=_Parser):
)
)
)
- self._match_r_paren()
else:
identified_property = self._parse_property()
if not identified_property:
break
properties.append(identified_property)
+
if properties:
return self.expression(exp.Properties, expressions=properties)
return None
def _parse_describe(self):
self._match(TokenType.TABLE)
-
return self.expression(exp.Describe, this=self._parse_id_var())
def _parse_insert(self):
overwrite = self._match(TokenType.OVERWRITE)
local = self._match(TokenType.LOCAL)
- if self._match_text("DIRECTORY"):
+ if self._match_text_seq("DIRECTORY"):
this = self.expression(
exp.Directory,
this=self._parse_var_or_string(),
@@ -954,27 +1003,27 @@ class Parser(metaclass=_Parser):
if not self._match_pair(TokenType.ROW, TokenType.FORMAT):
return None
- self._match_text("DELIMITED")
+ self._match_text_seq("DELIMITED")
kwargs = {}
- if self._match_text("FIELDS", "TERMINATED", "BY"):
+ if self._match_text_seq("FIELDS", "TERMINATED", "BY"):
kwargs["fields"] = self._parse_string()
- if self._match_text("ESCAPED", "BY"):
+ if self._match_text_seq("ESCAPED", "BY"):
kwargs["escaped"] = self._parse_string()
- if self._match_text("COLLECTION", "ITEMS", "TERMINATED", "BY"):
+ if self._match_text_seq("COLLECTION", "ITEMS", "TERMINATED", "BY"):
kwargs["collection_items"] = self._parse_string()
- if self._match_text("MAP", "KEYS", "TERMINATED", "BY"):
+ if self._match_text_seq("MAP", "KEYS", "TERMINATED", "BY"):
kwargs["map_keys"] = self._parse_string()
- if self._match_text("LINES", "TERMINATED", "BY"):
+ if self._match_text_seq("LINES", "TERMINATED", "BY"):
kwargs["lines"] = self._parse_string()
- if self._match_text("NULL", "DEFINED", "AS"):
+ if self._match_text_seq("NULL", "DEFINED", "AS"):
kwargs["null"] = self._parse_string()
return self.expression(exp.RowFormat, **kwargs)
def _parse_load_data(self):
local = self._match(TokenType.LOCAL)
- self._match_text("INPATH")
+ self._match_text_seq("INPATH")
inpath = self._parse_string()
overwrite = self._match(TokenType.OVERWRITE)
self._match_pair(TokenType.INTO, TokenType.TABLE)
@@ -986,8 +1035,8 @@ class Parser(metaclass=_Parser):
overwrite=overwrite,
inpath=inpath,
partition=self._parse_partition(),
- input_format=self._match_text("INPUTFORMAT") and self._parse_string(),
- serde=self._match_text("SERDE") and self._parse_string(),
+ input_format=self._match_text_seq("INPUTFORMAT") and self._parse_string(),
+ serde=self._match_text_seq("SERDE") and self._parse_string(),
)
def _parse_delete(self):
@@ -996,9 +1045,7 @@ class Parser(metaclass=_Parser):
return self.expression(
exp.Delete,
this=self._parse_table(schema=True),
- using=self._parse_csv(
- lambda: self._match(TokenType.USING) and self._parse_table(schema=True)
- ),
+ using=self._parse_csv(lambda: self._match(TokenType.USING) and self._parse_table()),
where=self._parse_where(),
)
@@ -1029,12 +1076,7 @@ class Parser(metaclass=_Parser):
options = []
if self._match(TokenType.OPTIONS):
- self._match_l_paren()
- k = self._parse_string()
- self._match(TokenType.EQ)
- v = self._parse_string()
- options = [k, v]
- self._match_r_paren()
+ options = self._parse_wrapped_csv(self._parse_string, sep=TokenType.EQ)
self._match(TokenType.ALIAS)
return self.expression(
@@ -1050,27 +1092,13 @@ class Parser(metaclass=_Parser):
return None
def parse_values():
- key = self._parse_var()
- value = None
-
- if self._match(TokenType.EQ):
- value = self._parse_string()
-
- return exp.Property(this=key, value=value)
-
- self._match_l_paren()
- values = self._parse_csv(parse_values)
- self._match_r_paren()
+ props = self._parse_csv(self._parse_var_or_string, sep=TokenType.EQ)
+ return exp.Property(this=seq_get(props, 0), value=seq_get(props, 1))
- return self.expression(
- exp.Partition,
- this=values,
- )
+ return self.expression(exp.Partition, this=self._parse_wrapped_csv(parse_values))
def _parse_value(self):
- self._match_l_paren()
- expressions = self._parse_csv(self._parse_conjunction)
- self._match_r_paren()
+ expressions = self._parse_wrapped_csv(self._parse_conjunction)
return self.expression(exp.Tuple, expressions=expressions)
def _parse_select(self, nested=False, table=False):
@@ -1124,10 +1152,11 @@ class Parser(metaclass=_Parser):
self._match_r_paren()
this = self._parse_subquery(this)
elif self._match(TokenType.VALUES):
- this = self.expression(exp.Values, expressions=self._parse_csv(self._parse_value))
- alias = self._parse_table_alias()
- if alias:
- this = self.expression(exp.Subquery, this=this, alias=alias)
+ this = self.expression(
+ exp.Values,
+ expressions=self._parse_csv(self._parse_value),
+ alias=self._parse_table_alias(),
+ )
else:
this = None
@@ -1140,7 +1169,6 @@ class Parser(metaclass=_Parser):
recursive = self._match(TokenType.RECURSIVE)
expressions = []
-
while True:
expressions.append(self._parse_cte())
@@ -1149,11 +1177,7 @@ class Parser(metaclass=_Parser):
else:
self._match(TokenType.WITH)
- return self.expression(
- exp.With,
- expressions=expressions,
- recursive=recursive,
- )
+ return self.expression(exp.With, expressions=expressions, recursive=recursive)
def _parse_cte(self):
alias = self._parse_table_alias()
@@ -1163,13 +1187,9 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.ALIAS):
self.raise_error("Expected AS in CTE")
- self._match_l_paren()
- expression = self._parse_statement()
- self._match_r_paren()
-
return self.expression(
exp.CTE,
- this=expression,
+ this=self._parse_wrapped(self._parse_statement),
alias=alias,
)
@@ -1223,7 +1243,7 @@ class Parser(metaclass=_Parser):
def _parse_hint(self):
if self._match(TokenType.HINT):
hints = self._parse_csv(self._parse_function)
- if not self._match(TokenType.HINT):
+ if not self._match_pair(TokenType.STAR, TokenType.SLASH):
self.raise_error("Expected */ after HINT")
return self.expression(exp.Hint, expressions=hints)
return None
@@ -1259,26 +1279,18 @@ class Parser(metaclass=_Parser):
columns = self._parse_csv(self._parse_id_var)
elif self._match(TokenType.L_PAREN):
columns = self._parse_csv(self._parse_id_var)
- self._match(TokenType.R_PAREN)
+ self._match_r_paren()
expression = self.expression(
exp.Lateral,
this=this,
view=view,
outer=outer,
- alias=self.expression(
- exp.TableAlias,
- this=table_alias,
- columns=columns,
- ),
+ alias=self.expression(exp.TableAlias, this=table_alias, columns=columns),
)
if outer_apply or cross_apply:
- return self.expression(
- exp.Join,
- this=expression,
- side=None if cross_apply else "LEFT",
- )
+ return self.expression(exp.Join, this=expression, side=None if cross_apply else "LEFT")
return expression
@@ -1387,12 +1399,8 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.UNNEST):
return None
- self._match_l_paren()
- expressions = self._parse_csv(self._parse_column)
- self._match_r_paren()
-
+ expressions = self._parse_wrapped_csv(self._parse_column)
ordinality = bool(self._match(TokenType.WITH) and self._match(TokenType.ORDINALITY))
-
alias = self._parse_table_alias()
if alias and self.unnest_column_only:
@@ -1402,10 +1410,7 @@ class Parser(metaclass=_Parser):
alias.set("this", None)
return self.expression(
- exp.Unnest,
- expressions=expressions,
- ordinality=ordinality,
- alias=alias,
+ exp.Unnest, expressions=expressions, ordinality=ordinality, alias=alias
)
def _parse_derived_table_values(self):
@@ -1418,13 +1423,7 @@ class Parser(metaclass=_Parser):
if is_derived:
self._match_r_paren()
- alias = self._parse_table_alias()
-
- return self.expression(
- exp.Values,
- expressions=expressions,
- alias=alias,
- )
+ return self.expression(exp.Values, expressions=expressions, alias=self._parse_table_alias())
def _parse_table_sample(self):
if not self._match(TokenType.TABLE_SAMPLE):
@@ -1460,9 +1459,7 @@ class Parser(metaclass=_Parser):
self._match_r_paren()
if self._match(TokenType.SEED):
- self._match_l_paren()
- seed = self._parse_number()
- self._match_r_paren()
+ seed = self._parse_wrapped(self._parse_number)
return self.expression(
exp.TableSample,
@@ -1513,12 +1510,7 @@ class Parser(metaclass=_Parser):
self._match_r_paren()
- return self.expression(
- exp.Pivot,
- expressions=expressions,
- field=field,
- unpivot=unpivot,
- )
+ return self.expression(exp.Pivot, expressions=expressions, field=field, unpivot=unpivot)
def _parse_where(self, skip_where_token=False):
if not skip_where_token and not self._match(TokenType.WHERE):
@@ -1539,11 +1531,7 @@ class Parser(metaclass=_Parser):
def _parse_grouping_sets(self):
if not self._match(TokenType.GROUPING_SETS):
return None
-
- self._match_l_paren()
- grouping_sets = self._parse_csv(self._parse_grouping_set)
- self._match_r_paren()
- return grouping_sets
+ return self._parse_wrapped_csv(self._parse_grouping_set)
def _parse_grouping_set(self):
if self._match(TokenType.L_PAREN):
@@ -1573,7 +1561,6 @@ class Parser(metaclass=_Parser):
def _parse_sort(self, token_type, exp_class):
if not self._match(token_type):
return None
-
return self.expression(exp_class, expressions=self._parse_csv(self._parse_ordered))
def _parse_ordered(self):
@@ -1602,9 +1589,12 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.TOP if top else TokenType.LIMIT):
limit_paren = self._match(TokenType.L_PAREN)
limit_exp = self.expression(exp.Limit, this=this, expression=self._parse_number())
+
if limit_paren:
- self._match(TokenType.R_PAREN)
+ self._match_r_paren()
+
return limit_exp
+
if self._match(TokenType.FETCH):
direction = self._match_set((TokenType.FIRST, TokenType.NEXT))
direction = self._prev.text if direction else "FIRST"
@@ -1612,11 +1602,13 @@ class Parser(metaclass=_Parser):
self._match_set((TokenType.ROW, TokenType.ROWS))
self._match(TokenType.ONLY)
return self.expression(exp.Fetch, direction=direction, count=count)
+
return this
def _parse_offset(self, this=None):
if not self._match_set((TokenType.OFFSET, TokenType.COMMA)):
return this
+
count = self._parse_number()
self._match_set((TokenType.ROW, TokenType.ROWS))
return self.expression(exp.Offset, this=this, expression=count)
@@ -1678,6 +1670,7 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.DISTINCT_FROM):
klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ
return self.expression(klass, this=this, expression=self._parse_expression())
+
this = self.expression(
exp.Is,
this=this,
@@ -1754,11 +1747,7 @@ class Parser(metaclass=_Parser):
def _parse_type(self):
if self._match(TokenType.INTERVAL):
- return self.expression(
- exp.Interval,
- this=self._parse_term(),
- unit=self._parse_var(),
- )
+ return self.expression(exp.Interval, this=self._parse_term(), unit=self._parse_var())
index = self._index
type_token = self._parse_types(check_func=True)
@@ -1824,30 +1813,18 @@ class Parser(metaclass=_Parser):
value = None
if type_token in self.TIMESTAMPS:
if self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ:
- value = exp.DataType(
- this=exp.DataType.Type.TIMESTAMPTZ,
- expressions=expressions,
- )
+ value = exp.DataType(this=exp.DataType.Type.TIMESTAMPTZ, expressions=expressions)
elif (
self._match(TokenType.WITH_LOCAL_TIME_ZONE) or type_token == TokenType.TIMESTAMPLTZ
):
- value = exp.DataType(
- this=exp.DataType.Type.TIMESTAMPLTZ,
- expressions=expressions,
- )
+ value = exp.DataType(this=exp.DataType.Type.TIMESTAMPLTZ, expressions=expressions)
elif self._match(TokenType.WITHOUT_TIME_ZONE):
- value = exp.DataType(
- this=exp.DataType.Type.TIMESTAMP,
- expressions=expressions,
- )
+ value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions)
maybe_func = maybe_func and value is None
if value is None:
- value = exp.DataType(
- this=exp.DataType.Type.TIMESTAMP,
- expressions=expressions,
- )
+ value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions)
if maybe_func and check_func:
index2 = self._index
@@ -1872,6 +1849,7 @@ class Parser(metaclass=_Parser):
this = self._parse_id_var()
self._match(TokenType.COLON)
data_type = self._parse_types()
+
if not data_type:
return None
return self.expression(exp.StructKwarg, this=this, expression=data_type)
@@ -1879,7 +1857,6 @@ class Parser(metaclass=_Parser):
def _parse_at_time_zone(self, this):
if not self._match(TokenType.AT_TIME_ZONE):
return this
-
return self.expression(exp.AtTimeZone, this=this, zone=self._parse_unary())
def _parse_column(self):
@@ -1984,16 +1961,14 @@ class Parser(metaclass=_Parser):
else:
subquery_predicate = self.SUBQUERY_PREDICATES.get(token_type)
- if subquery_predicate and self._curr.token_type in (
- TokenType.SELECT,
- TokenType.WITH,
- ):
+ if subquery_predicate and self._curr.token_type in (TokenType.SELECT, TokenType.WITH):
this = self.expression(subquery_predicate, this=self._parse_select())
self._match_r_paren()
return this
if functions is None:
functions = self.FUNCTIONS
+
function = functions.get(upper)
args = self._parse_csv(self._parse_lambda)
@@ -2014,6 +1989,7 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.L_PAREN):
return this
+
expressions = self._parse_csv(self._parse_udf_kwarg)
self._match_r_paren()
return self.expression(exp.UserDefinedFunction, this=this, expressions=expressions)
@@ -2021,25 +1997,19 @@ class Parser(metaclass=_Parser):
def _parse_introducer(self, token):
literal = self._parse_primary()
if literal:
- return self.expression(
- exp.Introducer,
- this=token.text,
- expression=literal,
- )
+ return self.expression(exp.Introducer, this=token.text, expression=literal)
return self.expression(exp.Identifier, this=token.text)
def _parse_session_parameter(self):
kind = None
this = self._parse_id_var() or self._parse_primary()
+
if self._match(TokenType.DOT):
kind = this.name
this = self._parse_var() or self._parse_primary()
- return self.expression(
- exp.SessionParameter,
- this=this,
- kind=kind,
- )
+
+ return self.expression(exp.SessionParameter, this=this, kind=kind)
def _parse_udf_kwarg(self):
this = self._parse_id_var()
@@ -2106,7 +2076,10 @@ class Parser(metaclass=_Parser):
return self.expression(exp.ColumnDef, this=this, kind=kind, constraints=constraints)
def _parse_column_constraint(self):
- this = None
+ this = self._parse_references()
+
+ if this:
+ return this
if self._match(TokenType.CONSTRAINT):
this = self._parse_id_var()
@@ -2114,13 +2087,12 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.AUTO_INCREMENT):
kind = exp.AutoIncrementColumnConstraint()
elif self._match(TokenType.CHECK):
- self._match_l_paren()
- kind = self.expression(exp.CheckColumnConstraint, this=self._parse_conjunction())
- self._match_r_paren()
+ constraint = self._parse_wrapped(self._parse_conjunction)
+ kind = self.expression(exp.CheckColumnConstraint, this=constraint)
elif self._match(TokenType.COLLATE):
kind = self.expression(exp.CollateColumnConstraint, this=self._parse_var())
elif self._match(TokenType.DEFAULT):
- kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_field())
+ kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_conjunction())
elif self._match_pair(TokenType.NOT, TokenType.NULL):
kind = exp.NotNullColumnConstraint()
elif self._match(TokenType.SCHEMA_COMMENT):
@@ -2137,7 +2109,7 @@ class Parser(metaclass=_Parser):
kind = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=True)
self._match_pair(TokenType.ALIAS, TokenType.IDENTITY)
else:
- return None
+ return this
return self.expression(exp.ColumnConstraint, this=this, kind=kind)
@@ -2159,37 +2131,29 @@ class Parser(metaclass=_Parser):
def _parse_unnamed_constraint(self):
if not self._match_set(self.CONSTRAINT_PARSERS):
return None
-
return self.CONSTRAINT_PARSERS[self._prev.token_type](self)
- def _parse_check(self):
- self._match(TokenType.CHECK)
- self._match_l_paren()
- expression = self._parse_conjunction()
- self._match_r_paren()
-
- return self.expression(exp.Check, this=expression)
-
def _parse_unique(self):
- self._match(TokenType.UNIQUE)
- columns = self._parse_wrapped_id_vars()
-
- return self.expression(exp.Unique, expressions=columns)
+ return self.expression(exp.Unique, expressions=self._parse_wrapped_id_vars())
- def _parse_foreign_key(self):
- self._match(TokenType.FOREIGN_KEY)
-
- expressions = self._parse_wrapped_id_vars()
- reference = self._match(TokenType.REFERENCES) and self.expression(
+ def _parse_references(self):
+ if not self._match(TokenType.REFERENCES):
+ return None
+ return self.expression(
exp.Reference,
this=self._parse_id_var(),
expressions=self._parse_wrapped_id_vars(),
)
+
+ def _parse_foreign_key(self):
+ expressions = self._parse_wrapped_id_vars()
+ reference = self._parse_references()
options = {}
while self._match(TokenType.ON):
if not self._match_set((TokenType.DELETE, TokenType.UPDATE)):
self.raise_error("Expected DELETE or UPDATE")
+
kind = self._prev.text.lower()
if self._match(TokenType.NO_ACTION):
@@ -2200,6 +2164,7 @@ class Parser(metaclass=_Parser):
else:
self._advance()
action = self._prev.text.upper()
+
options[kind] = action
return self.expression(
@@ -2363,20 +2328,14 @@ class Parser(metaclass=_Parser):
def _parse_window(self, this, alias=False):
if self._match(TokenType.FILTER):
- self._match_l_paren()
- this = self.expression(exp.Filter, this=this, expression=self._parse_where())
- self._match_r_paren()
+ where = self._parse_wrapped(self._parse_where)
+ this = self.expression(exp.Filter, this=this, expression=where)
# T-SQL allows the OVER (...) syntax after WITHIN GROUP.
# https://learn.microsoft.com/en-us/sql/t-sql/functions/percentile-disc-transact-sql?view=sql-server-ver16
if self._match(TokenType.WITHIN_GROUP):
- self._match_l_paren()
- this = self.expression(
- exp.WithinGroup,
- this=this,
- expression=self._parse_order(),
- )
- self._match_r_paren()
+ order = self._parse_wrapped(self._parse_order)
+ this = self.expression(exp.WithinGroup, this=this, expression=order)
# SQL spec defines an optional [ { IGNORE | RESPECT } NULLS ] OVER
# Some dialects choose to implement and some do not.
@@ -2404,18 +2363,11 @@ class Parser(metaclass=_Parser):
return this
if not self._match(TokenType.L_PAREN):
- alias = self._parse_id_var(False)
-
- return self.expression(
- exp.Window,
- this=this,
- alias=alias,
- )
-
- partition = None
+ return self.expression(exp.Window, this=this, alias=self._parse_id_var(False))
alias = self._parse_id_var(False)
+ partition = None
if self._match(TokenType.PARTITION_BY):
partition = self._parse_csv(self._parse_conjunction)
@@ -2552,17 +2504,13 @@ class Parser(metaclass=_Parser):
def _parse_replace(self):
if not self._match(TokenType.REPLACE):
return None
+ return self._parse_wrapped_csv(lambda: self._parse_alias(self._parse_expression()))
- self._match_l_paren()
- columns = self._parse_csv(lambda: self._parse_alias(self._parse_expression()))
- self._match_r_paren()
- return columns
-
- def _parse_csv(self, parse_method):
+ def _parse_csv(self, parse_method, sep=TokenType.COMMA):
parse_result = parse_method()
items = [parse_result] if parse_result is not None else []
- while self._match(TokenType.COMMA):
+ while self._match(sep):
if parse_result and self._prev_comment is not None:
parse_result.comment = self._prev_comment
@@ -2583,16 +2531,53 @@ class Parser(metaclass=_Parser):
return this
def _parse_wrapped_id_vars(self):
+ return self._parse_wrapped_csv(self._parse_id_var)
+
+ def _parse_wrapped_csv(self, parse_method, sep=TokenType.COMMA):
+ return self._parse_wrapped(lambda: self._parse_csv(parse_method, sep=sep))
+
+ def _parse_wrapped(self, parse_method):
self._match_l_paren()
- expressions = self._parse_csv(self._parse_id_var)
+ parse_result = parse_method()
self._match_r_paren()
- return expressions
+ return parse_result
def _parse_select_or_expression(self):
return self._parse_select() or self._parse_expression()
- def _parse_use(self):
- return self.expression(exp.Use, this=self._parse_id_var())
+ def _parse_transaction(self):
+ this = None
+ if self._match_texts(self.TRANSACTION_KIND):
+ this = self._prev.text
+
+ self._match_texts({"TRANSACTION", "WORK"})
+
+ modes = []
+ while True:
+ mode = []
+ while self._match(TokenType.VAR):
+ mode.append(self._prev.text)
+
+ if mode:
+ modes.append(" ".join(mode))
+ if not self._match(TokenType.COMMA):
+ break
+
+ return self.expression(exp.Transaction, this=this, modes=modes)
+
+ def _parse_commit_or_rollback(self):
+ savepoint = None
+ is_rollback = self._prev.token_type == TokenType.ROLLBACK
+
+ self._match_texts({"TRANSACTION", "WORK"})
+
+ if self._match_text_seq("TO"):
+ self._match_text_seq("SAVEPOINT")
+ savepoint = self._parse_id_var()
+
+ if is_rollback:
+ return self.expression(exp.Rollback, savepoint=savepoint)
+ return self.expression(exp.Commit)
def _parse_show(self):
parser = self._find_parser(self.SHOW_PARSERS, self._show_trie)
@@ -2675,7 +2660,13 @@ class Parser(metaclass=_Parser):
if expression and self._prev_comment:
expression.comment = self._prev_comment
- def _match_text(self, *texts):
+ def _match_texts(self, texts):
+ if self._curr and self._curr.text.upper() in texts:
+ self._advance()
+ return True
+ return False
+
+ def _match_text_seq(self, *texts):
index = self._index
for text in texts:
if self._curr and self._curr.text.upper() == text:
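
Note: with the new STATEMENT_PARSERS entries above, _match_texts (any one token
from a set) and the renamed _match_text_seq (an exact token sequence) drive the
transaction grammar. A parsing sketch:

    import sqlglot

    sqlglot.parse_one("BEGIN DEFERRED TRANSACTION")  # exp.Transaction(this='DEFERRED')
    sqlglot.parse_one("ROLLBACK TO SAVEPOINT s1")    # exp.Rollback(savepoint=s1)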
diff --git a/sqlglot/planner.py b/sqlglot/planner.py
index cd1de5e..51db2d4 100644
--- a/sqlglot/planner.py
+++ b/sqlglot/planner.py
@@ -1,5 +1,8 @@
+from __future__ import annotations
+
import itertools
import math
+import typing as t
from sqlglot import alias, exp
from sqlglot.errors import UnsupportedError
@@ -7,15 +10,15 @@ from sqlglot.optimizer.eliminate_joins import join_condition
class Plan:
- def __init__(self, expression):
- self.expression = expression
+ def __init__(self, expression: exp.Expression) -> None:
+ self.expression = expression.copy()
self.root = Step.from_expression(self.expression)
- self._dag = {}
+ self._dag: t.Dict[Step, t.Set[Step]] = {}
@property
- def dag(self):
+ def dag(self) -> t.Dict[Step, t.Set[Step]]:
if not self._dag:
- dag = {}
+ dag: t.Dict[Step, t.Set[Step]] = {}
nodes = {self.root}
while nodes:
@@ -29,32 +32,64 @@ class Plan:
return self._dag
@property
- def leaves(self):
+ def leaves(self) -> t.Generator[Step, None, None]:
return (node for node, deps in self.dag.items() if not deps)
+ def __repr__(self) -> str:
+ return f"Plan\n----\n{repr(self.root)}"
+
class Step:
@classmethod
- def from_expression(cls, expression, ctes=None):
+ def from_expression(
+ cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
+ ) -> Step:
"""
- Build a DAG of Steps from a SQL expression.
-
- Giving an expression like:
-
- SELECT x.a, SUM(x.b)
- FROM x
- JOIN y
- ON x.a = y.a
+ Builds a DAG of Steps from a SQL expression so that it's easier to execute in an engine.
+ Note: the expression's tables and subqueries must be aliased for this method to work. For
+ example, given the following expression:
+
+ SELECT
+ x.a,
+ SUM(x.b)
+ FROM x AS x
+ JOIN y AS y
+ ON x.a = y.a
GROUP BY x.a
- Transform it into a DAG of the form:
-
- Aggregate(x.a, SUM(x.b))
- Join(y)
- Scan(x)
- Scan(y)
-
- This can then more easily be executed on by an engine.
+ the following DAG is produced (the expression IDs might differ per execution):
+
+ - Aggregate: x (4347984624)
+ Context:
+ Aggregations:
+ - SUM(x.b)
+ Group:
+ - x.a
+ Projections:
+ - x.a
+ - "x".""
+ Dependencies:
+ - Join: x (4347985296)
+ Context:
+ y:
+ On: x.a = y.a
+ Projections:
+ Dependencies:
+ - Scan: x (4347983136)
+ Context:
+ Source: x AS x
+ Projections:
+ - Scan: y (4343416624)
+ Context:
+ Source: y AS y
+ Projections:
+
+ Args:
+ expression: the expression to build the DAG from.
+ ctes: a dictionary that maps CTEs to their corresponding Step DAG by name.
+
+ Returns:
+ A Step DAG corresponding to `expression`.
"""
ctes = ctes or {}
with_ = expression.args.get("with")
@@ -65,11 +100,11 @@ class Step:
for cte in with_.expressions:
step = Step.from_expression(cte.this, ctes)
step.name = cte.alias
- ctes[step.name] = step
+ ctes[step.name] = step # type: ignore
from_ = expression.args.get("from")
- if from_:
+ if isinstance(expression, exp.Select) and from_:
from_ = from_.expressions
if len(from_) > 1:
raise UnsupportedError(
@@ -77,8 +112,10 @@ class Step:
)
step = Scan.from_expression(from_[0], ctes)
+ elif isinstance(expression, exp.Union):
+ step = SetOperation.from_expression(expression, ctes)
else:
- raise UnsupportedError("Static selects are unsupported.")
+ step = Scan()
joins = expression.args.get("joins")
@@ -115,7 +152,7 @@ class Step:
group = expression.args.get("group")
- if group:
+ if group or aggregations:
aggregate = Aggregate()
aggregate.source = step.name
aggregate.name = step.name
@@ -123,7 +160,15 @@ class Step:
alias(operand, alias_) for operand, alias_ in operands.items()
)
aggregate.aggregations = aggregations
- aggregate.group = group.expressions
+ # give aggregates names and replace projections with references to them
+ aggregate.group = {
+ f"_g{i}": e for i, e in enumerate(group.expressions if group else [])
+ }
+ for projection in projections:
+ for i, e in aggregate.group.items():
+ for child, _, _ in projection.walk():
+ if child == e:
+ child.replace(exp.column(i, step.name))
aggregate.add_dependency(step)
step = aggregate
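A worked reading of the group-naming loop above (the _g names come from the f-string):

    # for: SELECT x.a, SUM(x.b) FROM x AS x GROUP BY x.a
    # aggregate.group == {"_g0": <x.a expression>}
    # and any projection equal to a grouping key is replaced in place with a
    # reference to that name, i.e. x.a becomes exp.column("_g0", step.name)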
@@ -150,22 +195,22 @@ class Step:
return step
- def __init__(self):
- self.name = None
- self.dependencies = set()
- self.dependents = set()
- self.projections = []
- self.limit = math.inf
- self.condition = None
+ def __init__(self) -> None:
+ self.name: t.Optional[str] = None
+ self.dependencies: t.Set[Step] = set()
+ self.dependents: t.Set[Step] = set()
+ self.projections: t.Sequence[exp.Expression] = []
+ self.limit: float = math.inf
+ self.condition: t.Optional[exp.Expression] = None
- def add_dependency(self, dependency):
+ def add_dependency(self, dependency: Step) -> None:
self.dependencies.add(dependency)
dependency.dependents.add(self)
- def __repr__(self):
+ def __repr__(self) -> str:
return self.to_s()
- def to_s(self, level=0):
+ def to_s(self, level: int = 0) -> str:
indent = " " * level
nested = f"{indent} "
@@ -175,7 +220,7 @@ class Step:
context = [f"{nested}Context:"] + context
lines = [
- f"{indent}- {self.__class__.__name__}: {self.name}",
+ f"{indent}- {self.id}",
*context,
f"{nested}Projections:",
]
@@ -193,13 +238,25 @@ class Step:
return "\n".join(lines)
- def _to_s(self, _indent):
+ @property
+ def type_name(self) -> str:
+ return self.__class__.__name__
+
+ @property
+ def id(self) -> str:
+ name = self.name
+ name = f" {name}" if name else ""
+ return f"{self.type_name}:{name} ({id(self)})"
+
+ def _to_s(self, _indent: str) -> t.List[str]:
return []
class Scan(Step):
@classmethod
- def from_expression(cls, expression, ctes=None):
+ def from_expression(
+ cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
+ ) -> Step:
table = expression
alias_ = expression.alias
@@ -217,26 +274,24 @@ class Scan(Step):
step = Scan()
step.name = alias_
step.source = expression
- if table.name in ctes:
+ if ctes and table.name in ctes:
step.add_dependency(ctes[table.name])
return step
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
- self.source = None
-
- def _to_s(self, indent):
- return [f"{indent}Source: {self.source.sql()}"]
+ self.source: t.Optional[exp.Expression] = None
-
-class Write(Step):
- pass
+ def _to_s(self, indent: str) -> t.List[str]:
+ return [f"{indent}Source: {self.source.sql() if self.source else '-static-'}"] # type: ignore
class Join(Step):
@classmethod
- def from_joins(cls, joins, ctes=None):
+ def from_joins(
+ cls, joins: t.Iterable[exp.Join], ctes: t.Optional[t.Dict[str, Step]] = None
+ ) -> Step:
step = Join()
for join in joins:
@@ -252,28 +307,28 @@ class Join(Step):
return step
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
- self.joins = {}
+ self.joins: t.Dict[str, t.Dict[str, t.List[str] | exp.Expression]] = {}
- def _to_s(self, indent):
+ def _to_s(self, indent: str) -> t.List[str]:
lines = []
for name, join in self.joins.items():
lines.append(f"{indent}{name}: {join['side']}")
if join.get("condition"):
- lines.append(f"{indent}On: {join['condition'].sql()}")
+ lines.append(f"{indent}On: {join['condition'].sql()}") # type: ignore
return lines
class Aggregate(Step):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
- self.aggregations = []
- self.operands = []
- self.group = []
- self.source = None
+ self.aggregations: t.List[exp.Expression] = []
+ self.operands: t.Tuple[exp.Expression, ...] = ()
+ self.group: t.Dict[str, exp.Expression] = {}
+ self.source: t.Optional[str] = None
- def _to_s(self, indent):
+ def _to_s(self, indent: str) -> t.List[str]:
lines = [f"{indent}Aggregations:"]
for expression in self.aggregations:
@@ -281,7 +336,7 @@ class Aggregate(Step):
if self.group:
lines.append(f"{indent}Group:")
- for expression in self.group:
+ for expression in self.group.values():
lines.append(f"{indent} - {expression.sql()}")
if self.operands:
lines.append(f"{indent}Operands:")
@@ -292,14 +347,56 @@ class Aggregate(Step):
class Sort(Step):
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self.key = None
- def _to_s(self, indent):
+ def _to_s(self, indent: str) -> t.List[str]:
lines = [f"{indent}Key:"]
- for expression in self.key:
+ for expression in self.key: # type: ignore
lines.append(f"{indent} - {expression.sql()}")
return lines
+
+
+class SetOperation(Step):
+ def __init__(
+ self,
+ op: t.Type[exp.Expression],
+ left: str | None,
+ right: str | None,
+ distinct: bool = False,
+ ) -> None:
+ super().__init__()
+ self.op = op
+ self.left = left
+ self.right = right
+ self.distinct = distinct
+
+ @classmethod
+ def from_expression(
+ cls, expression: exp.Expression, ctes: t.Optional[t.Dict[str, Step]] = None
+ ) -> Step:
+ assert isinstance(expression, exp.Union)
+ left = Step.from_expression(expression.left, ctes)
+ right = Step.from_expression(expression.right, ctes)
+ step = cls(
+ op=expression.__class__,
+ left=left.name,
+ right=right.name,
+ distinct=expression.args.get("distinct"),
+ )
+ step.add_dependency(left)
+ step.add_dependency(right)
+ return step
+
+ def _to_s(self, indent: str) -> t.List[str]:
+ lines = []
+ if self.distinct:
+ lines.append(f"{indent}Distinct: {self.distinct}")
+ return lines
+
+ @property
+ def type_name(self) -> str:
+ return self.op.__name__
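A hedged sketch of the new set-operation planning, under the same aliasing requirement as the docstring above:

    from sqlglot import parse_one
    from sqlglot.planner import Plan

    plan = Plan(parse_one("SELECT a FROM x AS x UNION SELECT a FROM y AS y"))
    # plan.root should be a SetOperation with op=exp.Union and two Scan
    # dependencies; type_name makes its repr header read "Union: ...".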
diff --git a/sqlglot/schema.py b/sqlglot/schema.py
index fcf7291..f6f303b 100644
--- a/sqlglot/schema.py
+++ b/sqlglot/schema.py
@@ -5,7 +5,7 @@ import typing as t
from sqlglot import expressions as exp
from sqlglot.errors import SchemaError
-from sqlglot.helper import csv_reader
+from sqlglot.helper import dict_depth
from sqlglot.trie import in_trie, new_trie
if t.TYPE_CHECKING:
@@ -15,6 +15,8 @@ if t.TYPE_CHECKING:
TABLE_ARGS = ("this", "db", "catalog")
+T = t.TypeVar("T")
+
class Schema(abc.ABC):
"""Abstract base class for database schemas"""
@@ -57,8 +59,81 @@ class Schema(abc.ABC):
The resulting column type.
"""
+ @property
+ def supported_table_args(self) -> t.Tuple[str, ...]:
+ """
+ Table arguments this schema supports, e.g. `("this", "db", "catalog")`
+ """
+ raise NotImplementedError
+
+
+class AbstractMappingSchema(t.Generic[T]):
+ def __init__(
+ self,
+ mapping: dict | None = None,
+ ) -> None:
+ self.mapping = mapping or {}
+ self.mapping_trie = self._build_trie(self.mapping)
+ self._supported_table_args: t.Tuple[str, ...] = tuple()
+
+ def _build_trie(self, schema: t.Dict) -> t.Dict:
+ return new_trie(tuple(reversed(t)) for t in flatten_schema(schema, depth=self._depth()))
+
+ def _depth(self) -> int:
+ return dict_depth(self.mapping)
+
+ @property
+ def supported_table_args(self) -> t.Tuple[str, ...]:
+ if not self._supported_table_args and self.mapping:
+ depth = self._depth()
+
+ if not depth: # empty or unnested mapping
+ self._supported_table_args = tuple()
+ elif 1 <= depth <= 3:
+ self._supported_table_args = TABLE_ARGS[:depth]
+ else:
+ raise SchemaError(f"Invalid mapping shape. Depth: {depth}")
+
+ return self._supported_table_args
+
+ def table_parts(self, table: exp.Table) -> t.List[str]:
+ if isinstance(table.this, exp.ReadCSV):
+ return [table.this.name]
+ return [table.text(part) for part in TABLE_ARGS if table.text(part)]
+
+ def find(
+ self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
+ ) -> t.Optional[T]:
+ parts = self.table_parts(table)[0 : len(self.supported_table_args)]
+ value, trie = in_trie(self.mapping_trie if trie is None else trie, parts)
+
+ if value == 0:
+ if raise_on_missing:
+ raise SchemaError(f"Cannot find mapping for {table}.")
+ else:
+ return None
+ elif value == 1:
+ possibilities = flatten_schema(trie, depth=dict_depth(trie) - 1)
+ if len(possibilities) == 1:
+ parts.extend(possibilities[0])
+ else:
+ message = ", ".join(".".join(parts) for parts in possibilities)
+ if raise_on_missing:
+ raise SchemaError(f"Ambiguous mapping for {table}: {message}.")
+ return None
+ return self._nested_get(parts, raise_on_missing=raise_on_missing)
-class MappingSchema(Schema):
+ def _nested_get(
+ self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
+ ) -> t.Optional[t.Any]:
+ return _nested_get(
+ d or self.mapping,
+ *zip(self.supported_table_args, reversed(parts)),
+ raise_on_missing=raise_on_missing,
+ )
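A hedged example of the generalized trie lookup; MappingSchema, rebased onto this class just below, inherits it unchanged:

    from sqlglot import exp
    from sqlglot.schema import MappingSchema

    schema = MappingSchema({"db": {"t": {"a": "INT"}}})
    schema.supported_table_args        # ("this", "db"): depth 3 minus the column layer
    schema.find(exp.to_table("db.t"))  # {"a": "INT"}
    schema.find(exp.to_table("t"))     # unambiguous, so the trie completes it to db.t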
+
+
+class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema):
"""
Schema based on a nested mapping.
@@ -82,17 +157,17 @@ class MappingSchema(Schema):
visible: t.Optional[t.Dict] = None,
dialect: t.Optional[str] = None,
) -> None:
- self.schema = schema or {}
+ super().__init__(schema)
self.visible = visible or {}
- self.schema_trie = self._build_trie(self.schema)
self.dialect = dialect
- self._type_mapping_cache: t.Dict[str, exp.DataType.Type] = {}
- self._supported_table_args: t.Tuple[str, ...] = tuple()
+ self._type_mapping_cache: t.Dict[str, exp.DataType.Type] = {
+ "STR": exp.DataType.Type.TEXT,
+ }
@classmethod
def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema:
return MappingSchema(
- schema=mapping_schema.schema,
+ schema=mapping_schema.mapping,
visible=mapping_schema.visible,
dialect=mapping_schema.dialect,
)
@@ -100,27 +175,13 @@ class MappingSchema(Schema):
def copy(self, **kwargs) -> MappingSchema:
return MappingSchema(
**{ # type: ignore
- "schema": self.schema.copy(),
+ "schema": self.mapping.copy(),
"visible": self.visible.copy(),
"dialect": self.dialect,
**kwargs,
}
)
- @property
- def supported_table_args(self):
- if not self._supported_table_args and self.schema:
- depth = _dict_depth(self.schema)
-
- if not depth or depth == 1: # {}
- self._supported_table_args = tuple()
- elif 2 <= depth <= 4:
- self._supported_table_args = TABLE_ARGS[: depth - 1]
- else:
- raise SchemaError(f"Invalid schema shape. Depth: {depth}")
-
- return self._supported_table_args
-
def add_table(
self, table: exp.Table | str, column_mapping: t.Optional[ColumnMapping] = None
) -> None:
@@ -133,17 +194,21 @@ class MappingSchema(Schema):
"""
table_ = self._ensure_table(table)
column_mapping = ensure_column_mapping(column_mapping)
- schema = self.find_schema(table_, raise_on_missing=False)
+ schema = self.find(table_, raise_on_missing=False)
if schema and not column_mapping:
return
_nested_set(
- self.schema,
+ self.mapping,
list(reversed(self.table_parts(table_))),
column_mapping,
)
- self.schema_trie = self._build_trie(self.schema)
+ self.mapping_trie = self._build_trie(self.mapping)
+
+ def _depth(self) -> int:
+ # The columns themselves are a mapping, but we don't want to include those
+ return super()._depth() - 1
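A concrete reading of the depth adjustment, following the dict_depth semantics documented further down in this diff:

    # dict_depth({"t": {"a": "INT"}}) == 2, so _depth() == 1 here and
    # supported_table_args == ("this",): bare table names only. Adding a
    # catalog level, {"c": {"db": {"t": {...}}}}, yields ("this", "db", "catalog").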
def _ensure_table(self, table: exp.Table | str) -> exp.Table:
table_ = exp.to_table(table)
@@ -153,16 +218,9 @@ class MappingSchema(Schema):
return table_
- def table_parts(self, table: exp.Table) -> t.List[str]:
- return [table.text(part) for part in TABLE_ARGS if table.text(part)]
-
def column_names(self, table: exp.Table | str, only_visible: bool = False) -> t.List[str]:
table_ = self._ensure_table(table)
-
- if not isinstance(table_.this, exp.Identifier):
- return fs_get(table) # type: ignore
-
- schema = self.find_schema(table_)
+ schema = self.find(table_)
if schema is None:
raise SchemaError(f"Could not find table schema {table}")
@@ -173,36 +231,13 @@ class MappingSchema(Schema):
visible = self._nested_get(self.table_parts(table_), self.visible)
return [col for col in schema if col in visible] # type: ignore
- def find_schema(
- self, table: exp.Table, trie: t.Optional[t.Dict] = None, raise_on_missing: bool = True
- ) -> t.Optional[t.Dict[str, str]]:
- parts = self.table_parts(table)[0 : len(self.supported_table_args)]
- value, trie = in_trie(self.schema_trie if trie is None else trie, parts)
-
- if value == 0:
- if raise_on_missing:
- raise SchemaError(f"Cannot find schema for {table}.")
- else:
- return None
- elif value == 1:
- possibilities = flatten_schema(trie)
- if len(possibilities) == 1:
- parts.extend(possibilities[0])
- else:
- message = ", ".join(".".join(parts) for parts in possibilities)
- if raise_on_missing:
- raise SchemaError(f"Ambiguous schema for {table}: {message}.")
- return None
-
- return self._nested_get(parts, raise_on_missing=raise_on_missing)
-
def get_column_type(
self, table: exp.Table | str, column: exp.Column | str
) -> exp.DataType.Type:
column_name = column if isinstance(column, str) else column.name
table_ = exp.to_table(table)
if table_:
- table_schema = self.find_schema(table_)
+ table_schema = self.find(table_)
schema_type = table_schema.get(column_name).upper() # type: ignore
return self._convert_type(schema_type)
raise SchemaError(f"Could not convert table '{table}'")
@@ -228,18 +263,6 @@ class MappingSchema(Schema):
return self._type_mapping_cache[schema_type]
- def _build_trie(self, schema: t.Dict):
- return new_trie(tuple(reversed(t)) for t in flatten_schema(schema))
-
- def _nested_get(
- self, parts: t.Sequence[str], d: t.Optional[t.Dict] = None, raise_on_missing=True
- ) -> t.Optional[t.Any]:
- return _nested_get(
- d or self.schema,
- *zip(self.supported_table_args, reversed(parts)),
- raise_on_missing=raise_on_missing,
- )
-
def ensure_schema(schema: t.Any) -> Schema:
if isinstance(schema, Schema):
@@ -267,29 +290,20 @@ def ensure_column_mapping(mapping: t.Optional[ColumnMapping]):
raise ValueError(f"Invalid mapping provided: {type(mapping)}")
-def flatten_schema(schema: t.Dict, keys: t.Optional[t.List[str]] = None) -> t.List[t.List[str]]:
+def flatten_schema(
+ schema: t.Dict, depth: int, keys: t.Optional[t.List[str]] = None
+) -> t.List[t.List[str]]:
tables = []
keys = keys or []
- depth = _dict_depth(schema)
for k, v in schema.items():
- if depth >= 3:
- tables.extend(flatten_schema(v, keys + [k]))
- elif depth == 2:
+ if depth >= 2:
+ tables.extend(flatten_schema(v, depth - 1, keys + [k]))
+ elif depth == 1:
tables.append(keys + [k])
return tables
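The nesting depth is now passed in explicitly instead of being re-measured at every level; for instance:

    flatten_schema({"db": {"t": {"a": "INT"}}}, depth=2)  # [["db", "t"]]
    flatten_schema({"t": {"a": "INT"}}, depth=1)          # [["t"]]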
-def fs_get(table: exp.Table) -> t.List[str]:
- name = table.this.name
-
- if name.upper() == "READ_CSV":
- with csv_reader(table) as reader:
- return next(reader)
-
- raise ValueError(f"Cannot read schema for {table}")
-
-
def _nested_get(
d: t.Dict, *path: t.Tuple[str, str], raise_on_missing: bool = True
) -> t.Optional[t.Any]:
@@ -310,7 +324,7 @@ def _nested_get(
if d is None:
if raise_on_missing:
name = "table" if name == "this" else name
- raise ValueError(f"Unknown {name}")
+ raise ValueError(f"Unknown {name}: {key}")
return None
return d
@@ -350,34 +364,3 @@ def _nested_set(d: t.Dict, keys: t.List[str], value: t.Any) -> t.Dict:
subd[keys[-1]] = value
return d
-
-
-def _dict_depth(d: t.Dict) -> int:
- """
- Get the nesting depth of a dictionary.
-
- For example:
- >>> _dict_depth(None)
- 0
- >>> _dict_depth({})
- 1
- >>> _dict_depth({"a": "b"})
- 1
- >>> _dict_depth({"a": {}})
- 2
- >>> _dict_depth({"a": {"b": {}}})
- 3
-
- Args:
- d (dict): dictionary
- Returns:
- int: depth
- """
- try:
- return 1 + _dict_depth(next(iter(d.values())))
- except AttributeError:
- # d doesn't have attribute "values"
- return 0
- except StopIteration:
- # d.values() returns an empty sequence
- return 1
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index 95d84d6..ec8cd91 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -105,12 +105,9 @@ class TokenType(AutoName):
OBJECT = auto()
# keywords
- ADD_FILE = auto()
ALIAS = auto()
ALWAYS = auto()
ALL = auto()
- ALTER = auto()
- ANALYZE = auto()
ANTI = auto()
ANY = auto()
APPLY = auto()
@@ -124,14 +121,14 @@ class TokenType(AutoName):
BUCKET = auto()
BY_DEFAULT = auto()
CACHE = auto()
- CALL = auto()
+ CASCADE = auto()
CASE = auto()
CHARACTER_SET = auto()
CHECK = auto()
CLUSTER_BY = auto()
COLLATE = auto()
+ COMMAND = auto()
COMMENT = auto()
- COMMENT_ON = auto()
COMMIT = auto()
CONSTRAINT = auto()
CREATE = auto()
@@ -149,7 +146,9 @@ class TokenType(AutoName):
DETERMINISTIC = auto()
DISTINCT = auto()
DISTINCT_FROM = auto()
+ DISTKEY = auto()
DISTRIBUTE_BY = auto()
+ DISTSTYLE = auto()
DIV = auto()
DROP = auto()
ELSE = auto()
@@ -159,7 +158,6 @@ class TokenType(AutoName):
EXCEPT = auto()
EXECUTE = auto()
EXISTS = auto()
- EXPLAIN = auto()
FALSE = auto()
FETCH = auto()
FILTER = auto()
@@ -216,7 +214,6 @@ class TokenType(AutoName):
OFFSET = auto()
ON = auto()
ONLY = auto()
- OPTIMIZE = auto()
OPTIONS = auto()
ORDER_BY = auto()
ORDERED = auto()
@@ -258,6 +255,7 @@ class TokenType(AutoName):
SHOW = auto()
SIMILAR_TO = auto()
SOME = auto()
+ SORTKEY = auto()
SORT_BY = auto()
STABLE = auto()
STORED = auto()
@@ -268,9 +266,8 @@ class TokenType(AutoName):
TRANSIENT = auto()
TOP = auto()
THEN = auto()
- TRUE = auto()
TRAILING = auto()
- TRUNCATE = auto()
+ TRUE = auto()
UNBOUNDED = auto()
UNCACHE = auto()
UNION = auto()
@@ -280,7 +277,6 @@ class TokenType(AutoName):
USE = auto()
USING = auto()
VALUES = auto()
- VACUUM = auto()
VIEW = auto()
VOLATILE = auto()
WHEN = auto()
@@ -420,7 +416,6 @@ class Tokenizer(metaclass=_Tokenizer):
KEYWORDS = {
"/*+": TokenType.HINT,
- "*/": TokenType.HINT,
"==": TokenType.EQ,
"::": TokenType.DCOLON,
"||": TokenType.DPIPE,
@@ -435,15 +430,7 @@ class Tokenizer(metaclass=_Tokenizer):
"#>": TokenType.HASH_ARROW,
"#>>": TokenType.DHASH_ARROW,
"<->": TokenType.LR_ARROW,
- "ADD ARCHIVE": TokenType.ADD_FILE,
- "ADD ARCHIVES": TokenType.ADD_FILE,
- "ADD FILE": TokenType.ADD_FILE,
- "ADD FILES": TokenType.ADD_FILE,
- "ADD JAR": TokenType.ADD_FILE,
- "ADD JARS": TokenType.ADD_FILE,
"ALL": TokenType.ALL,
- "ALTER": TokenType.ALTER,
- "ANALYZE": TokenType.ANALYZE,
"AND": TokenType.AND,
"ANTI": TokenType.ANTI,
"ANY": TokenType.ANY,
@@ -455,10 +442,10 @@ class Tokenizer(metaclass=_Tokenizer):
"BETWEEN": TokenType.BETWEEN,
"BOTH": TokenType.BOTH,
"BUCKET": TokenType.BUCKET,
- "CALL": TokenType.CALL,
"CACHE": TokenType.CACHE,
"UNCACHE": TokenType.UNCACHE,
"CASE": TokenType.CASE,
+ "CASCADE": TokenType.CASCADE,
"CHARACTER SET": TokenType.CHARACTER_SET,
"CHECK": TokenType.CHECK,
"CLUSTER BY": TokenType.CLUSTER_BY,
@@ -479,7 +466,9 @@ class Tokenizer(metaclass=_Tokenizer):
"DETERMINISTIC": TokenType.DETERMINISTIC,
"DISTINCT": TokenType.DISTINCT,
"DISTINCT FROM": TokenType.DISTINCT_FROM,
+ "DISTKEY": TokenType.DISTKEY,
"DISTRIBUTE BY": TokenType.DISTRIBUTE_BY,
+ "DISTSTYLE": TokenType.DISTSTYLE,
"DIV": TokenType.DIV,
"DROP": TokenType.DROP,
"ELSE": TokenType.ELSE,
@@ -489,7 +478,6 @@ class Tokenizer(metaclass=_Tokenizer):
"EXCEPT": TokenType.EXCEPT,
"EXECUTE": TokenType.EXECUTE,
"EXISTS": TokenType.EXISTS,
- "EXPLAIN": TokenType.EXPLAIN,
"FALSE": TokenType.FALSE,
"FETCH": TokenType.FETCH,
"FILTER": TokenType.FILTER,
@@ -541,7 +529,6 @@ class Tokenizer(metaclass=_Tokenizer):
"OFFSET": TokenType.OFFSET,
"ON": TokenType.ON,
"ONLY": TokenType.ONLY,
- "OPTIMIZE": TokenType.OPTIMIZE,
"OPTIONS": TokenType.OPTIONS,
"OR": TokenType.OR,
"ORDER BY": TokenType.ORDER_BY,
@@ -579,6 +566,7 @@ class Tokenizer(metaclass=_Tokenizer):
"SET": TokenType.SET,
"SHOW": TokenType.SHOW,
"SOME": TokenType.SOME,
+ "SORTKEY": TokenType.SORTKEY,
"SORT BY": TokenType.SORT_BY,
"STABLE": TokenType.STABLE,
"STORED": TokenType.STORED,
@@ -592,7 +580,6 @@ class Tokenizer(metaclass=_Tokenizer):
"THEN": TokenType.THEN,
"TRUE": TokenType.TRUE,
"TRAILING": TokenType.TRAILING,
- "TRUNCATE": TokenType.TRUNCATE,
"UNBOUNDED": TokenType.UNBOUNDED,
"UNION": TokenType.UNION,
"UNPIVOT": TokenType.UNPIVOT,
@@ -600,7 +587,6 @@ class Tokenizer(metaclass=_Tokenizer):
"UPDATE": TokenType.UPDATE,
"USE": TokenType.USE,
"USING": TokenType.USING,
- "VACUUM": TokenType.VACUUM,
"VALUES": TokenType.VALUES,
"VIEW": TokenType.VIEW,
"VOLATILE": TokenType.VOLATILE,
@@ -659,6 +645,14 @@ class Tokenizer(metaclass=_Tokenizer):
"UNIQUE": TokenType.UNIQUE,
"STRUCT": TokenType.STRUCT,
"VARIANT": TokenType.VARIANT,
+ "ALTER": TokenType.COMMAND,
+ "ANALYZE": TokenType.COMMAND,
+ "CALL": TokenType.COMMAND,
+ "EXPLAIN": TokenType.COMMAND,
+ "OPTIMIZE": TokenType.COMMAND,
+ "PREPARE": TokenType.COMMAND,
+ "TRUNCATE": TokenType.COMMAND,
+ "VACUUM": TokenType.COMMAND,
}
WHITE_SPACE = {
@@ -670,20 +664,11 @@ class Tokenizer(metaclass=_Tokenizer):
}
COMMANDS = {
- TokenType.ALTER,
- TokenType.ADD_FILE,
- TokenType.ANALYZE,
- TokenType.BEGIN,
- TokenType.CALL,
- TokenType.COMMENT_ON,
- TokenType.COMMIT,
- TokenType.EXPLAIN,
- TokenType.OPTIMIZE,
+ TokenType.COMMAND,
+ TokenType.EXECUTE,
+ TokenType.FETCH,
TokenType.SET,
TokenType.SHOW,
- TokenType.TRUNCATE,
- TokenType.VACUUM,
- TokenType.ROLLBACK,
}
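The upshot, as a hedged sketch: statements that previously had dedicated token types (ALTER, ANALYZE, CALL, EXPLAIN, OPTIMIZE, PREPARE, TRUNCATE, VACUUM) now all tokenize to the generic COMMAND type, whose membership in COMMANDS makes the tokenizer swallow the rest of the statement:

    import sqlglot

    sqlglot.parse_one("VACUUM my_table")   # -> exp.Command
    sqlglot.parse_one("EXPLAIN SELECT 1")  # -> exp.Command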
# handle numeric literals like in hive (3L = BIGINT)
@@ -885,6 +870,7 @@ class Tokenizer(metaclass=_Tokenizer):
if comment_start_line == self._prev_token_line:
if self._prev_token_comment is None:
self.tokens[-1].comment = self._comment
+ self._prev_token_comment = self._comment
self._comment = None