author     Daniel Baumann <daniel.baumann@progress-linux.org>  2022-09-15 16:46:17 +0000
committer  Daniel Baumann <daniel.baumann@progress-linux.org>  2022-09-15 16:46:17 +0000
commit     28cc22419e32a65fea2d1678400265b8cabc3aff (patch)
tree       ff9ac1991fd48490b21ef6aa9015a347a165e2d9 /sqlglot
parent     Initial commit. (diff)
Adding upstream version 6.0.4.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
-rw-r--r--  sqlglot/__init__.py                                   96
-rw-r--r--  sqlglot/__main__.py                                   69
-rw-r--r--  sqlglot/dialects/__init__.py                          15
-rw-r--r--  sqlglot/dialects/bigquery.py                         128
-rw-r--r--  sqlglot/dialects/clickhouse.py                        48
-rw-r--r--  sqlglot/dialects/dialect.py                          268
-rw-r--r--  sqlglot/dialects/duckdb.py                           156
-rw-r--r--  sqlglot/dialects/hive.py                             304
-rw-r--r--  sqlglot/dialects/mysql.py                            163
-rw-r--r--  sqlglot/dialects/oracle.py                            63
-rw-r--r--  sqlglot/dialects/postgres.py                         109
-rw-r--r--  sqlglot/dialects/presto.py                           216
-rw-r--r--  sqlglot/dialects/snowflake.py                        145
-rw-r--r--  sqlglot/dialects/spark.py                            106
-rw-r--r--  sqlglot/dialects/sqlite.py                            63
-rw-r--r--  sqlglot/dialects/starrocks.py                         12
-rw-r--r--  sqlglot/dialects/tableau.py                           37
-rw-r--r--  sqlglot/dialects/trino.py                             10
-rw-r--r--  sqlglot/diff.py                                      314
-rw-r--r--  sqlglot/errors.py                                     38
-rw-r--r--  sqlglot/executor/__init__.py                          39
-rw-r--r--  sqlglot/executor/context.py                           68
-rw-r--r--  sqlglot/executor/env.py                               32
-rw-r--r--  sqlglot/executor/python.py                           360
-rw-r--r--  sqlglot/executor/table.py                             81
-rw-r--r--  sqlglot/expressions.py                              2945
-rw-r--r--  sqlglot/generator.py                                1124
-rw-r--r--  sqlglot/helper.py                                    123
-rw-r--r--  sqlglot/optimizer/__init__.py                          2
-rw-r--r--  sqlglot/optimizer/eliminate_subqueries.py             48
-rw-r--r--  sqlglot/optimizer/expand_multi_table_selects.py       16
-rw-r--r--  sqlglot/optimizer/isolate_table_selects.py            31
-rw-r--r--  sqlglot/optimizer/normalize.py                       136
-rw-r--r--  sqlglot/optimizer/optimize_joins.py                   75
-rw-r--r--  sqlglot/optimizer/optimizer.py                        43
-rw-r--r--  sqlglot/optimizer/pushdown_predicates.py             176
-rw-r--r--  sqlglot/optimizer/pushdown_projections.py             85
-rw-r--r--  sqlglot/optimizer/qualify_columns.py                 422
-rw-r--r--  sqlglot/optimizer/qualify_tables.py                   54
-rw-r--r--  sqlglot/optimizer/quote_identities.py                 25
-rw-r--r--  sqlglot/optimizer/schema.py                          129
-rw-r--r--  sqlglot/optimizer/scope.py                           438
-rw-r--r--  sqlglot/optimizer/simplify.py                        383
-rw-r--r--  sqlglot/optimizer/unnest_subqueries.py               220
-rw-r--r--  sqlglot/parser.py                                   2190
-rw-r--r--  sqlglot/planner.py                                   340
-rw-r--r--  sqlglot/time.py                                       45
-rw-r--r--  sqlglot/tokens.py                                    853
-rw-r--r--  sqlglot/transforms.py                                 68
-rw-r--r--  sqlglot/trie.py                                       27
50 files changed, 12938 insertions, 0 deletions
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py
new file mode 100644
index 0000000..0007e34
--- /dev/null
+++ b/sqlglot/__init__.py
@@ -0,0 +1,96 @@
+from sqlglot import expressions as exp
+from sqlglot.dialects import Dialect, Dialects
+from sqlglot.diff import diff
+from sqlglot.errors import ErrorLevel, ParseError, TokenError, UnsupportedError
+from sqlglot.expressions import Expression
+from sqlglot.expressions import alias_ as alias
+from sqlglot.expressions import (
+ and_,
+ column,
+ condition,
+ from_,
+ maybe_parse,
+ not_,
+ or_,
+ select,
+ subquery,
+)
+from sqlglot.expressions import table_ as table
+from sqlglot.generator import Generator
+from sqlglot.parser import Parser
+from sqlglot.tokens import Tokenizer, TokenType
+
+__version__ = "6.0.4"
+
+pretty = False
+
+
+def parse(sql, read=None, **opts):
+ """
+ Parses the given SQL string into a collection of syntax trees, one per
+ parsed SQL statement.
+
+ Args:
+ sql (str): the SQL code string to parse.
+ read (str): the SQL dialect to apply during parsing
+            (e.g. "spark", "hive", "presto", "mysql").
+ **opts: other options.
+
+ Returns:
+ typing.List[Expression]: the list of parsed syntax trees.
+ """
+ dialect = Dialect.get_or_raise(read)()
+ return dialect.parse(sql, **opts)
+
+
+def parse_one(sql, read=None, into=None, **opts):
+ """
+ Parses the given SQL string and returns a syntax tree for the first
+ parsed SQL statement.
+
+ Args:
+ sql (str): the SQL code string to parse.
+ read (str): the SQL dialect to apply during parsing
+            (e.g. "spark", "hive", "presto", "mysql").
+        into (Expression): the SQLGlot Expression to parse into.
+ **opts: other options.
+
+ Returns:
+ Expression: the syntax tree for the first parsed statement.
+ """
+
+ dialect = Dialect.get_or_raise(read)()
+
+ if into:
+ result = dialect.parse_into(into, sql, **opts)
+ else:
+ result = dialect.parse(sql, **opts)
+
+ return result[0] if result else None
+
+
+def transpile(sql, read=None, write=None, identity=True, error_level=None, **opts):
+ """
+ Parses the given SQL string using the source dialect and returns a list of SQL strings
+ transformed to conform to the target dialect. Each string in the returned list represents
+ a single transformed SQL statement.
+
+ Args:
+ sql (str): the SQL code string to transpile.
+ read (str): the source dialect used to parse the input string
+            (e.g. "spark", "hive", "presto", "mysql").
+ write (str): the target dialect into which the input should be transformed
+            (e.g. "spark", "hive", "presto", "mysql").
+ identity (bool): if set to True and if the target dialect is not specified
+            the source dialect will be used as both the source and the target dialect.
+ error_level (ErrorLevel): the desired error level of the parser.
+ **opts: other options.
+
+ Returns:
+ typing.List[str]: the list of transpiled SQL statements / expressions.
+ """
+ write = write or read if identity else write
+ return [
+ Dialect.get_or_raise(write)().generate(expression, **opts)
+ for expression in parse(sql, read, error_level=error_level)
+ ]
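
A quick sketch of the module-level API above (assumes sqlglot 6.0.4 is installed; the printed output is indicative):

    import sqlglot

    # parse_one returns the syntax tree of the first statement
    tree = sqlglot.parse_one("SELECT a FROM b", read="hive")

    # transpile returns one SQL string per input statement
    print(sqlglot.transpile("SELECT a FROM b", read="hive", write="presto"))
    # ['SELECT a FROM b']
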
diff --git a/sqlglot/__main__.py b/sqlglot/__main__.py
new file mode 100644
index 0000000..25200c4
--- /dev/null
+++ b/sqlglot/__main__.py
@@ -0,0 +1,69 @@
+import argparse
+
+import sqlglot
+
+parser = argparse.ArgumentParser(description="Transpile SQL")
+parser.add_argument("sql", metavar="sql", type=str, help="SQL string to transpile")
+parser.add_argument(
+ "--read",
+ dest="read",
+ type=str,
+ default=None,
+    help="Dialect to read from (default: generic)",
+)
+parser.add_argument(
+ "--write",
+ dest="write",
+ type=str,
+ default=None,
+    help="Dialect to write to (default: generic)",
+)
+parser.add_argument(
+ "--no-identify",
+ dest="identify",
+ action="store_false",
+    help="Don't auto-identify fields",
+)
+parser.add_argument(
+ "--no-pretty",
+ dest="pretty",
+ action="store_false",
+    help="Compress SQL",
+)
+parser.add_argument(
+ "--parse",
+ dest="parse",
+ action="store_true",
+ help="Parse and return the expression tree",
+)
+parser.add_argument(
+ "--error-level",
+ dest="error_level",
+ type=str,
+ default="RAISE",
+ help="IGNORE, WARN, RAISE (default)",
+)
+
+
+args = parser.parse_args()
+error_level = sqlglot.ErrorLevel[args.error_level.upper()]
+
+if args.parse:
+ sqls = [
+ repr(expression)
+ for expression in sqlglot.parse(
+ args.sql, read=args.read, error_level=error_level
+ )
+ ]
+else:
+ sqls = sqlglot.transpile(
+ args.sql,
+ read=args.read,
+ write=args.write,
+ identify=args.identify,
+ pretty=args.pretty,
+ error_level=error_level,
+ )
+
+for sql in sqls:
+ print(sql)
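
The flags above map straight onto sqlglot.transpile() arguments, so the CLI can be exercised as below (a sketch; the EPOCH_MS rewrite relies on the duckdb and hive dialects defined later in this patch, and the output shown is indicative):

    $ python -m sqlglot --read duckdb --write hive --no-pretty "SELECT EPOCH_MS(1618088028295)"
    SELECT FROM_UNIXTIME(1618088028295 / 1000)
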
diff --git a/sqlglot/dialects/__init__.py b/sqlglot/dialects/__init__.py
new file mode 100644
index 0000000..5aa7d77
--- /dev/null
+++ b/sqlglot/dialects/__init__.py
@@ -0,0 +1,15 @@
+from sqlglot.dialects.bigquery import BigQuery
+from sqlglot.dialects.clickhouse import ClickHouse
+from sqlglot.dialects.dialect import Dialect, Dialects
+from sqlglot.dialects.duckdb import DuckDB
+from sqlglot.dialects.hive import Hive
+from sqlglot.dialects.mysql import MySQL
+from sqlglot.dialects.oracle import Oracle
+from sqlglot.dialects.postgres import Postgres
+from sqlglot.dialects.presto import Presto
+from sqlglot.dialects.snowflake import Snowflake
+from sqlglot.dialects.spark import Spark
+from sqlglot.dialects.sqlite import SQLite
+from sqlglot.dialects.starrocks import StarRocks
+from sqlglot.dialects.tableau import Tableau
+from sqlglot.dialects.trino import Trino
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
new file mode 100644
index 0000000..f4e87c3
--- /dev/null
+++ b/sqlglot/dialects/bigquery.py
@@ -0,0 +1,128 @@
+from sqlglot import exp
+from sqlglot.dialects.dialect import (
+ Dialect,
+ inline_array_sql,
+ no_ilike_sql,
+ rename_func,
+)
+from sqlglot.generator import Generator
+from sqlglot.helper import list_get
+from sqlglot.parser import Parser
+from sqlglot.tokens import Tokenizer, TokenType
+
+
+def _date_add(expression_class):
+ def func(args):
+ interval = list_get(args, 1)
+ return expression_class(
+ this=list_get(args, 0),
+ expression=interval.this,
+ unit=interval.args.get("unit"),
+ )
+
+ return func
+
+
+def _date_add_sql(data_type, kind):
+ def func(self, expression):
+ this = self.sql(expression, "this")
+ unit = self.sql(expression, "unit") or "'day'"
+ expression = self.sql(expression, "expression")
+ return f"{data_type}_{kind}({this}, INTERVAL {expression} {unit})"
+
+ return func
+
+
+class BigQuery(Dialect):
+ unnest_column_only = True
+
+ class Tokenizer(Tokenizer):
+ QUOTES = [
+ (prefix + quote, quote) if prefix else quote
+ for quote in ["'", '"', '"""', "'''"]
+ for prefix in ["", "r", "R"]
+ ]
+ IDENTIFIERS = ["`"]
+ ESCAPE = "\\"
+
+ KEYWORDS = {
+ **Tokenizer.KEYWORDS,
+ "CURRENT_DATETIME": TokenType.CURRENT_DATETIME,
+ "CURRENT_TIME": TokenType.CURRENT_TIME,
+ "GEOGRAPHY": TokenType.GEOGRAPHY,
+ "INT64": TokenType.BIGINT,
+ "FLOAT64": TokenType.DOUBLE,
+ "QUALIFY": TokenType.QUALIFY,
+ "UNKNOWN": TokenType.NULL,
+ "WINDOW": TokenType.WINDOW,
+ }
+
+ class Parser(Parser):
+ FUNCTIONS = {
+ **Parser.FUNCTIONS,
+ "DATE_ADD": _date_add(exp.DateAdd),
+ "DATETIME_ADD": _date_add(exp.DatetimeAdd),
+ "TIME_ADD": _date_add(exp.TimeAdd),
+ "TIMESTAMP_ADD": _date_add(exp.TimestampAdd),
+ "DATE_SUB": _date_add(exp.DateSub),
+ "DATETIME_SUB": _date_add(exp.DatetimeSub),
+ "TIME_SUB": _date_add(exp.TimeSub),
+ "TIMESTAMP_SUB": _date_add(exp.TimestampSub),
+ }
+
+ NO_PAREN_FUNCTIONS = {
+ **Parser.NO_PAREN_FUNCTIONS,
+ TokenType.CURRENT_DATETIME: exp.CurrentDatetime,
+ TokenType.CURRENT_TIME: exp.CurrentTime,
+ }
+
+ class Generator(Generator):
+ TRANSFORMS = {
+ exp.Array: inline_array_sql,
+ exp.ArraySize: rename_func("ARRAY_LENGTH"),
+ exp.DateAdd: _date_add_sql("DATE", "ADD"),
+ exp.DateSub: _date_add_sql("DATE", "SUB"),
+ exp.DatetimeAdd: _date_add_sql("DATETIME", "ADD"),
+ exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"),
+ exp.ILike: no_ilike_sql,
+ exp.TimeAdd: _date_add_sql("TIME", "ADD"),
+ exp.TimeSub: _date_add_sql("TIME", "SUB"),
+ exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"),
+ exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"),
+ exp.VariancePop: rename_func("VAR_POP"),
+ }
+
+ TYPE_MAPPING = {
+ **Generator.TYPE_MAPPING,
+ exp.DataType.Type.TINYINT: "INT64",
+ exp.DataType.Type.SMALLINT: "INT64",
+ exp.DataType.Type.INT: "INT64",
+ exp.DataType.Type.BIGINT: "INT64",
+ exp.DataType.Type.DECIMAL: "NUMERIC",
+ exp.DataType.Type.FLOAT: "FLOAT64",
+ exp.DataType.Type.DOUBLE: "FLOAT64",
+ exp.DataType.Type.BOOLEAN: "BOOL",
+ exp.DataType.Type.TEXT: "STRING",
+ exp.DataType.Type.VARCHAR: "STRING",
+ exp.DataType.Type.NVARCHAR: "STRING",
+ }
+
+ def in_unnest_op(self, unnest):
+ return self.sql(unnest)
+
+ def union_op(self, expression):
+ return f"UNION{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"
+
+ def except_op(self, expression):
+ if not expression.args.get("distinct", False):
+ self.unsupported("EXCEPT without DISTINCT is not supported in BigQuery")
+ return f"EXCEPT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"
+
+ def intersect_op(self, expression):
+ if not expression.args.get("distinct", False):
+ self.unsupported(
+ "INTERSECT without DISTINCT is not supported in BigQuery"
+ )
+ return (
+ f"INTERSECT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}"
+ )
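
The _date_add/_date_add_sql pair above maps BigQuery date arithmetic onto generic DateAdd-style expressions. A minimal sketch of the parse side:

    import sqlglot
    from sqlglot import exp

    # DATE_ADD(x, INTERVAL 1 DAY) is parsed into exp.DateAdd with unit "DAY"
    tree = sqlglot.parse_one("DATE_ADD(x, INTERVAL 1 DAY)", read="bigquery")
    assert isinstance(tree, exp.DateAdd)
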
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
new file mode 100644
index 0000000..55dad7a
--- /dev/null
+++ b/sqlglot/dialects/clickhouse.py
@@ -0,0 +1,48 @@
+from sqlglot import exp
+from sqlglot.dialects.dialect import Dialect, inline_array_sql
+from sqlglot.generator import Generator
+from sqlglot.parser import Parser
+from sqlglot.tokens import Tokenizer, TokenType
+
+
+class ClickHouse(Dialect):
+ normalize_functions = None
+ null_ordering = "nulls_are_last"
+
+ class Tokenizer(Tokenizer):
+ IDENTIFIERS = ['"', "`"]
+
+ KEYWORDS = {
+ **Tokenizer.KEYWORDS,
+ "NULLABLE": TokenType.NULLABLE,
+ "FINAL": TokenType.FINAL,
+ "INT8": TokenType.TINYINT,
+ "INT16": TokenType.SMALLINT,
+ "INT32": TokenType.INT,
+ "INT64": TokenType.BIGINT,
+ "FLOAT32": TokenType.FLOAT,
+ "FLOAT64": TokenType.DOUBLE,
+ }
+
+ class Parser(Parser):
+ def _parse_table(self, schema=False):
+ this = super()._parse_table(schema)
+
+ if self._match(TokenType.FINAL):
+ this = self.expression(exp.Final, this=this)
+
+ return this
+
+ class Generator(Generator):
+ STRUCT_DELIMITER = ("(", ")")
+
+ TYPE_MAPPING = {
+ **Generator.TYPE_MAPPING,
+ exp.DataType.Type.NULLABLE: "Nullable",
+ }
+
+ TRANSFORMS = {
+ **Generator.TRANSFORMS,
+ exp.Array: inline_array_sql,
+ exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL",
+ }
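
The custom _parse_table above wraps FINAL-qualified tables in an exp.Final node, which the generator renders back verbatim. A sketch (when no write dialect is given, the read dialect is reused):

    import sqlglot

    print(sqlglot.transpile("SELECT * FROM t FINAL", read="clickhouse")[0])
    # SELECT * FROM t FINAL
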
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
new file mode 100644
index 0000000..8045f7a
--- /dev/null
+++ b/sqlglot/dialects/dialect.py
@@ -0,0 +1,268 @@
+from enum import Enum
+
+from sqlglot import exp
+from sqlglot.generator import Generator
+from sqlglot.helper import csv, list_get
+from sqlglot.parser import Parser
+from sqlglot.time import format_time
+from sqlglot.tokens import Tokenizer
+from sqlglot.trie import new_trie
+
+
+class Dialects(str, Enum):
+ DIALECT = ""
+
+ BIGQUERY = "bigquery"
+ CLICKHOUSE = "clickhouse"
+ DUCKDB = "duckdb"
+ HIVE = "hive"
+ MYSQL = "mysql"
+ ORACLE = "oracle"
+ POSTGRES = "postgres"
+ PRESTO = "presto"
+ SNOWFLAKE = "snowflake"
+ SPARK = "spark"
+ SQLITE = "sqlite"
+ STARROCKS = "starrocks"
+ TABLEAU = "tableau"
+ TRINO = "trino"
+
+
+class _Dialect(type):
+ classes = {}
+
+ @classmethod
+ def __getitem__(cls, key):
+ return cls.classes[key]
+
+ @classmethod
+ def get(cls, key, default=None):
+ return cls.classes.get(key, default)
+
+ def __new__(cls, clsname, bases, attrs):
+ klass = super().__new__(cls, clsname, bases, attrs)
+ enum = Dialects.__members__.get(clsname.upper())
+ cls.classes[enum.value if enum is not None else clsname.lower()] = klass
+
+ klass.time_trie = new_trie(klass.time_mapping)
+ klass.inverse_time_mapping = {v: k for k, v in klass.time_mapping.items()}
+ klass.inverse_time_trie = new_trie(klass.inverse_time_mapping)
+
+ klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer)
+ klass.parser_class = getattr(klass, "Parser", Parser)
+ klass.generator_class = getattr(klass, "Generator", Generator)
+
+ klass.tokenizer = klass.tokenizer_class()
+ klass.quote_start, klass.quote_end = list(klass.tokenizer_class.QUOTES.items())[
+ 0
+ ]
+ klass.identifier_start, klass.identifier_end = list(
+ klass.tokenizer_class.IDENTIFIERS.items()
+ )[0]
+
+ return klass
+
+
+class Dialect(metaclass=_Dialect):
+ index_offset = 0
+ unnest_column_only = False
+ alias_post_tablesample = False
+ normalize_functions = "upper"
+ null_ordering = "nulls_are_small"
+
+ date_format = "'%Y-%m-%d'"
+ dateint_format = "'%Y%m%d'"
+ time_format = "'%Y-%m-%d %H:%M:%S'"
+ time_mapping = {}
+
+ # autofilled
+ quote_start = None
+ quote_end = None
+ identifier_start = None
+ identifier_end = None
+
+ time_trie = None
+ inverse_time_mapping = None
+ inverse_time_trie = None
+ tokenizer_class = None
+ parser_class = None
+ generator_class = None
+ tokenizer = None
+
+ @classmethod
+ def get_or_raise(cls, dialect):
+ if not dialect:
+ return cls
+ result = cls.get(dialect)
+ if not result:
+ raise ValueError(f"Unknown dialect '{dialect}'")
+ return result
+
+ @classmethod
+ def format_time(cls, expression):
+ if isinstance(expression, str):
+ return exp.Literal.string(
+ format_time(
+ expression[1:-1], # the time formats are quoted
+ cls.time_mapping,
+ cls.time_trie,
+ )
+ )
+ if expression and expression.is_string:
+ return exp.Literal.string(
+ format_time(
+ expression.this,
+ cls.time_mapping,
+ cls.time_trie,
+ )
+ )
+ return expression
+
+ def parse(self, sql, **opts):
+ return self.parser(**opts).parse(self.tokenizer.tokenize(sql), sql)
+
+ def parse_into(self, expression_type, sql, **opts):
+ return self.parser(**opts).parse_into(
+ expression_type, self.tokenizer.tokenize(sql), sql
+ )
+
+ def generate(self, expression, **opts):
+ return self.generator(**opts).generate(expression)
+
+ def transpile(self, code, **opts):
+ return self.generate(self.parse(code), **opts)
+
+ def parser(self, **opts):
+ return self.parser_class(
+ **{
+ "index_offset": self.index_offset,
+ "unnest_column_only": self.unnest_column_only,
+ "alias_post_tablesample": self.alias_post_tablesample,
+ "null_ordering": self.null_ordering,
+ **opts,
+ },
+ )
+
+ def generator(self, **opts):
+ return self.generator_class(
+ **{
+ "quote_start": self.quote_start,
+ "quote_end": self.quote_end,
+ "identifier_start": self.identifier_start,
+ "identifier_end": self.identifier_end,
+ "escape": self.tokenizer_class.ESCAPE,
+ "index_offset": self.index_offset,
+ "time_mapping": self.inverse_time_mapping,
+ "time_trie": self.inverse_time_trie,
+ "unnest_column_only": self.unnest_column_only,
+ "alias_post_tablesample": self.alias_post_tablesample,
+ "normalize_functions": self.normalize_functions,
+ "null_ordering": self.null_ordering,
+ **opts,
+ }
+ )
+
+
+def rename_func(name):
+ return (
+ lambda self, expression: f"{name}({csv(*[self.sql(e) for e in expression.args.values()])})"
+ )
+
+
+def approx_count_distinct_sql(self, expression):
+ if expression.args.get("accuracy"):
+ self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy")
+ return f"APPROX_COUNT_DISTINCT({self.sql(expression, 'this')})"
+
+
+def if_sql(self, expression):
+ expressions = csv(
+ self.sql(expression, "this"),
+ self.sql(expression, "true"),
+ self.sql(expression, "false"),
+ )
+ return f"IF({expressions})"
+
+
+def arrow_json_extract_sql(self, expression):
+ return f"{self.sql(expression, 'this')}->{self.sql(expression, 'path')}"
+
+
+def arrow_json_extract_scalar_sql(self, expression):
+ return f"{self.sql(expression, 'this')}->>{self.sql(expression, 'path')}"
+
+
+def inline_array_sql(self, expression):
+ return f"[{self.expressions(expression)}]"
+
+
+def no_ilike_sql(self, expression):
+ return self.like_sql(
+ exp.Like(
+ this=exp.Lower(this=expression.this),
+ expression=expression.args["expression"],
+ )
+ )
+
+
+def no_paren_current_date_sql(self, expression):
+ zone = self.sql(expression, "this")
+ return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE"
+
+
+def no_recursive_cte_sql(self, expression):
+ if expression.args.get("recursive"):
+ self.unsupported("Recursive CTEs are unsupported")
+ expression.args["recursive"] = False
+ return self.with_sql(expression)
+
+
+def no_safe_divide_sql(self, expression):
+ n = self.sql(expression, "this")
+ d = self.sql(expression, "expression")
+ return f"IF({d} <> 0, {n} / {d}, NULL)"
+
+
+def no_tablesample_sql(self, expression):
+ self.unsupported("TABLESAMPLE unsupported")
+ return self.sql(expression.this)
+
+
+def no_trycast_sql(self, expression):
+ return self.cast_sql(expression)
+
+
+def str_position_sql(self, expression):
+ this = self.sql(expression, "this")
+ substr = self.sql(expression, "substr")
+ position = self.sql(expression, "position")
+ if position:
+ return f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1"
+ return f"STRPOS({this}, {substr})"
+
+
+def struct_extract_sql(self, expression):
+ this = self.sql(expression, "this")
+ struct_key = self.sql(exp.Identifier(this=expression.expression, quoted=True))
+ return f"{this}.{struct_key}"
+
+
+def format_time_lambda(exp_class, dialect, default=None):
+ """Helper used for time expressions.
+
+    Args:
+        exp_class (Class): the expression class to instantiate
+        dialect (string): the SQL dialect
+        default (Optional[bool | str]): the default format to use; True means the
+            dialect's default time format
+ """
+
+ def _format_time(args):
+ return exp_class(
+ this=list_get(args, 0),
+ format=Dialect[dialect].format_time(
+ list_get(args, 1)
+ or (Dialect[dialect].time_format if default is True else default)
+ ),
+ )
+
+ return _format_time
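
Because _Dialect is a metaclass, merely defining a Dialect subclass registers it in the shared classes mapping under its lowercase name (or its Dialects enum value). A sketch:

    from sqlglot.dialects import Dialect  # importing the package registers every dialect

    assert Dialect["duckdb"] is Dialect.get("duckdb")
    assert Dialect.get_or_raise("duckdb") is Dialect["duckdb"]
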
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
new file mode 100644
index 0000000..d83a620
--- /dev/null
+++ b/sqlglot/dialects/duckdb.py
@@ -0,0 +1,156 @@
+from sqlglot import exp
+from sqlglot.dialects.dialect import (
+ Dialect,
+ approx_count_distinct_sql,
+ arrow_json_extract_scalar_sql,
+ arrow_json_extract_sql,
+ format_time_lambda,
+ no_safe_divide_sql,
+ no_tablesample_sql,
+ rename_func,
+ str_position_sql,
+)
+from sqlglot.generator import Generator
+from sqlglot.helper import list_get
+from sqlglot.parser import Parser
+from sqlglot.tokens import Tokenizer, TokenType
+
+
+def _unix_to_time(self, expression):
+ return f"TO_TIMESTAMP(CAST({self.sql(expression, 'this')} AS BIGINT))"
+
+
+def _str_to_time_sql(self, expression):
+ return f"STRPTIME({self.sql(expression, 'this')}, {self.format_time(expression)})"
+
+
+def _ts_or_ds_add(self, expression):
+ this = self.sql(expression, "this")
+ e = self.sql(expression, "expression")
+ unit = self.sql(expression, "unit").strip("'") or "DAY"
+ return f"CAST({this} AS DATE) + INTERVAL {e} {unit}"
+
+
+def _ts_or_ds_to_date_sql(self, expression):
+ time_format = self.format_time(expression)
+ if time_format and time_format not in (DuckDB.time_format, DuckDB.date_format):
+ return f"CAST({_str_to_time_sql(self, expression)} AS DATE)"
+ return f"CAST({self.sql(expression, 'this')} AS DATE)"
+
+
+def _date_add(self, expression):
+ this = self.sql(expression, "this")
+ e = self.sql(expression, "expression")
+ unit = self.sql(expression, "unit").strip("'") or "DAY"
+ return f"{this} + INTERVAL {e} {unit}"
+
+
+def _array_sort_sql(self, expression):
+ if expression.expression:
+ self.unsupported("DUCKDB ARRAY_SORT does not support a comparator")
+ return f"ARRAY_SORT({self.sql(expression, 'this')})"
+
+
+def _sort_array_sql(self, expression):
+ this = self.sql(expression, "this")
+ if expression.args.get("asc") == exp.FALSE:
+ return f"ARRAY_REVERSE_SORT({this})"
+ return f"ARRAY_SORT({this})"
+
+
+def _sort_array_reverse(args):
+ return exp.SortArray(this=list_get(args, 0), asc=exp.FALSE)
+
+
+def _struct_pack_sql(self, expression):
+ args = [
+ self.binary(e, ":=") if isinstance(e, exp.EQ) else self.sql(e)
+ for e in expression.expressions
+ ]
+ return f"STRUCT_PACK({', '.join(args)})"
+
+
+class DuckDB(Dialect):
+ class Tokenizer(Tokenizer):
+ KEYWORDS = {
+ **Tokenizer.KEYWORDS,
+ ":=": TokenType.EQ,
+ }
+
+ class Parser(Parser):
+ FUNCTIONS = {
+ **Parser.FUNCTIONS,
+ "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
+ "ARRAY_LENGTH": exp.ArraySize.from_arg_list,
+ "ARRAY_SORT": exp.SortArray.from_arg_list,
+ "ARRAY_REVERSE_SORT": _sort_array_reverse,
+ "EPOCH": exp.TimeToUnix.from_arg_list,
+ "EPOCH_MS": lambda args: exp.UnixToTime(
+ this=exp.Div(
+ this=list_get(args, 0),
+ expression=exp.Literal.number(1000),
+ )
+ ),
+ "LIST_SORT": exp.SortArray.from_arg_list,
+ "LIST_REVERSE_SORT": _sort_array_reverse,
+ "LIST_VALUE": exp.Array.from_arg_list,
+ "REGEXP_MATCHES": exp.RegexpLike.from_arg_list,
+ "STRFTIME": format_time_lambda(exp.TimeToStr, "duckdb"),
+ "STRPTIME": format_time_lambda(exp.StrToTime, "duckdb"),
+ "STR_SPLIT": exp.Split.from_arg_list,
+ "STRING_SPLIT": exp.Split.from_arg_list,
+ "STRING_TO_ARRAY": exp.Split.from_arg_list,
+ "STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
+ "STRING_SPLIT_REGEX": exp.RegexpSplit.from_arg_list,
+ "STRUCT_PACK": exp.Struct.from_arg_list,
+ "TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list,
+ "UNNEST": exp.Explode.from_arg_list,
+ }
+
+ class Generator(Generator):
+ TRANSFORMS = {
+ **Generator.TRANSFORMS,
+ exp.ApproxDistinct: approx_count_distinct_sql,
+ exp.Array: lambda self, e: f"LIST_VALUE({self.expressions(e, flat=True)})",
+ exp.ArraySize: rename_func("ARRAY_LENGTH"),
+ exp.ArraySort: _array_sort_sql,
+ exp.ArraySum: rename_func("LIST_SUM"),
+ exp.DateAdd: _date_add,
+ exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""",
+ exp.DateStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
+ exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.dateint_format}) AS INT)",
+ exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.dateint_format}) AS DATE)",
+ exp.Explode: rename_func("UNNEST"),
+ exp.JSONExtract: arrow_json_extract_sql,
+ exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
+ exp.JSONBExtract: arrow_json_extract_sql,
+ exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
+ exp.RegexpLike: rename_func("REGEXP_MATCHES"),
+ exp.RegexpSplit: rename_func("STR_SPLIT_REGEX"),
+ exp.SafeDivide: no_safe_divide_sql,
+ exp.Split: rename_func("STR_SPLIT"),
+ exp.SortArray: _sort_array_sql,
+ exp.StrPosition: str_position_sql,
+ exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)",
+ exp.StrToTime: _str_to_time_sql,
+ exp.StrToUnix: lambda self, e: f"EPOCH(STRPTIME({self.sql(e, 'this')}, {self.format_time(e)}))",
+ exp.Struct: _struct_pack_sql,
+ exp.TableSample: no_tablesample_sql,
+ exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)",
+ exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)",
+ exp.TimeStrToUnix: lambda self, e: f"EPOCH(CAST({self.sql(e, 'this')} AS TIMESTAMP))",
+ exp.TimeToStr: lambda self, e: f"STRFTIME({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.TimeToUnix: rename_func("EPOCH"),
+ exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS TEXT), '-', ''), 1, 8) AS INT)",
+ exp.TsOrDsAdd: _ts_or_ds_add,
+ exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
+ exp.UnixToStr: lambda self, e: f"STRFTIME({_unix_to_time(self, e)}, {self.format_time(e)})",
+ exp.UnixToTime: _unix_to_time,
+ exp.UnixToTimeStr: lambda self, e: f"CAST({_unix_to_time(self, e)} AS TEXT)",
+ }
+
+ TYPE_MAPPING = {
+ **Generator.TYPE_MAPPING,
+ exp.DataType.Type.VARCHAR: "TEXT",
+ exp.DataType.Type.NVARCHAR: "TEXT",
+ }
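
STRFTIME/STRPTIME above go through format_time_lambda, so DuckDB's strftime-style formats are rewritten via the per-dialect time mappings. A sketch (output indicative, derived from the hive time_mapping in this patch):

    import sqlglot

    print(sqlglot.transpile("SELECT STRFTIME(x, '%y-%-m-%S')", read="duckdb", write="hive")[0])
    # SELECT DATE_FORMAT(x, 'yy-M-ss')
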
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
new file mode 100644
index 0000000..e3f3f39
--- /dev/null
+++ b/sqlglot/dialects/hive.py
@@ -0,0 +1,304 @@
+from sqlglot import exp, transforms
+from sqlglot.dialects.dialect import (
+ Dialect,
+ approx_count_distinct_sql,
+ format_time_lambda,
+ if_sql,
+ no_ilike_sql,
+ no_recursive_cte_sql,
+ no_safe_divide_sql,
+ no_trycast_sql,
+ rename_func,
+ struct_extract_sql,
+)
+from sqlglot.generator import Generator
+from sqlglot.helper import csv, list_get
+from sqlglot.parser import Parser
+from sqlglot.tokens import Tokenizer
+
+
+def _parse_map(args):
+ keys = []
+ values = []
+ for i in range(0, len(args), 2):
+ keys.append(args[i])
+ values.append(args[i + 1])
+ return HiveMap(
+ keys=exp.Array(expressions=keys),
+ values=exp.Array(expressions=values),
+ )
+
+
+def _map_sql(self, expression):
+ keys = expression.args["keys"]
+ values = expression.args["values"]
+
+ if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array):
+        self.unsupported("Cannot convert array columns into map; use SparkSQL instead.")
+ return f"MAP({self.sql(keys)}, {self.sql(values)})"
+
+ args = []
+ for key, value in zip(keys.expressions, values.expressions):
+ args.append(self.sql(key))
+ args.append(self.sql(value))
+ return f"MAP({csv(*args)})"
+
+
+def _array_sort(self, expression):
+ if expression.expression:
+ self.unsupported("Hive SORT_ARRAY does not support a comparator")
+ return f"SORT_ARRAY({self.sql(expression, 'this')})"
+
+
+def _property_sql(self, expression):
+ key = expression.name
+ value = self.sql(expression, "value")
+ return f"'{key}' = {value}"
+
+
+def _str_to_unix(self, expression):
+ return f"UNIX_TIMESTAMP({csv(self.sql(expression, 'this'), _time_format(self, expression))})"
+
+
+def _str_to_date(self, expression):
+ this = self.sql(expression, "this")
+ time_format = self.format_time(expression)
+ if time_format not in (Hive.time_format, Hive.date_format):
+ this = f"FROM_UNIXTIME(UNIX_TIMESTAMP({this}, {time_format}))"
+ return f"CAST({this} AS DATE)"
+
+
+def _str_to_time(self, expression):
+ this = self.sql(expression, "this")
+ time_format = self.format_time(expression)
+ if time_format not in (Hive.time_format, Hive.date_format):
+ this = f"FROM_UNIXTIME(UNIX_TIMESTAMP({this}, {time_format}))"
+ return f"CAST({this} AS TIMESTAMP)"
+
+
+def _time_format(self, expression):
+ time_format = self.format_time(expression)
+ if time_format == Hive.time_format:
+ return None
+ return time_format
+
+
+def _time_to_str(self, expression):
+ this = self.sql(expression, "this")
+ time_format = self.format_time(expression)
+ return f"DATE_FORMAT({this}, {time_format})"
+
+
+def _to_date_sql(self, expression):
+ this = self.sql(expression, "this")
+ time_format = self.format_time(expression)
+ if time_format and time_format not in (Hive.time_format, Hive.date_format):
+ return f"TO_DATE({this}, {time_format})"
+ return f"TO_DATE({this})"
+
+
+def _unnest_to_explode_sql(self, expression):
+ unnest = expression.this
+ if isinstance(unnest, exp.Unnest):
+ alias = unnest.args.get("alias")
+ udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode
+ return "".join(
+ self.sql(
+ exp.Lateral(
+ this=udtf(this=expression),
+ alias=exp.TableAlias(this=alias.this, columns=[column]),
+ )
+ )
+ for expression, column in zip(
+ unnest.expressions, alias.columns if alias else []
+ )
+ )
+ return self.join_sql(expression)
+
+
+def _index_sql(self, expression):
+ this = self.sql(expression, "this")
+ table = self.sql(expression, "table")
+ columns = self.sql(expression, "columns")
+ return f"{this} ON TABLE {table} {columns}"
+
+
+class HiveMap(exp.Map):
+ is_var_len_args = True
+
+
+class Hive(Dialect):
+ alias_post_tablesample = True
+
+ time_mapping = {
+ "y": "%Y",
+ "Y": "%Y",
+ "YYYY": "%Y",
+ "yyyy": "%Y",
+ "YY": "%y",
+ "yy": "%y",
+ "MMMM": "%B",
+ "MMM": "%b",
+ "MM": "%m",
+ "M": "%-m",
+ "dd": "%d",
+ "d": "%-d",
+ "HH": "%H",
+ "H": "%-H",
+ "hh": "%I",
+ "h": "%-I",
+ "mm": "%M",
+ "m": "%-M",
+ "ss": "%S",
+ "s": "%-S",
+ "S": "%f",
+ }
+
+ date_format = "'yyyy-MM-dd'"
+ dateint_format = "'yyyyMMdd'"
+ time_format = "'yyyy-MM-dd HH:mm:ss'"
+
+ class Tokenizer(Tokenizer):
+ QUOTES = ["'", '"']
+ IDENTIFIERS = ["`"]
+ ESCAPE = "\\"
+ ENCODE = "utf-8"
+
+ NUMERIC_LITERALS = {
+ "L": "BIGINT",
+ "S": "SMALLINT",
+ "Y": "TINYINT",
+ "D": "DOUBLE",
+ "F": "FLOAT",
+ "BD": "DECIMAL",
+ }
+
+ class Parser(Parser):
+ STRICT_CAST = False
+
+ FUNCTIONS = {
+ **Parser.FUNCTIONS,
+ "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list,
+ "COLLECT_LIST": exp.ArrayAgg.from_arg_list,
+ "DATE_ADD": lambda args: exp.TsOrDsAdd(
+ this=list_get(args, 0),
+ expression=list_get(args, 1),
+ unit=exp.Literal.string("DAY"),
+ ),
+ "DATEDIFF": lambda args: exp.DateDiff(
+ this=exp.TsOrDsToDate(this=list_get(args, 0)),
+ expression=exp.TsOrDsToDate(this=list_get(args, 1)),
+ ),
+ "DATE_SUB": lambda args: exp.TsOrDsAdd(
+ this=list_get(args, 0),
+ expression=exp.Mul(
+ this=list_get(args, 1),
+ expression=exp.Literal.number(-1),
+ ),
+ unit=exp.Literal.string("DAY"),
+ ),
+ "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "hive"),
+ "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=list_get(args, 0))),
+ "FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True),
+ "GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list,
+ "LOCATE": lambda args: exp.StrPosition(
+ this=list_get(args, 1),
+ substr=list_get(args, 0),
+ position=list_get(args, 2),
+ ),
+ "LOG": (
+ lambda args: exp.Log.from_arg_list(args)
+ if len(args) > 1
+ else exp.Ln.from_arg_list(args)
+ ),
+ "MAP": _parse_map,
+ "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)),
+ "PERCENTILE": exp.Quantile.from_arg_list,
+ "COLLECT_SET": exp.SetAgg.from_arg_list,
+ "SIZE": exp.ArraySize.from_arg_list,
+ "SPLIT": exp.RegexpSplit.from_arg_list,
+ "TO_DATE": format_time_lambda(exp.TsOrDsToDate, "hive"),
+ "UNIX_TIMESTAMP": format_time_lambda(exp.StrToUnix, "hive", True),
+ "YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)),
+ }
+
+ class Generator(Generator):
+ ROOT_PROPERTIES = [
+ exp.PartitionedByProperty,
+ exp.FileFormatProperty,
+ exp.SchemaCommentProperty,
+ exp.LocationProperty,
+ exp.TableFormatProperty,
+ ]
+ WITH_PROPERTIES = [exp.AnonymousProperty]
+
+ TYPE_MAPPING = {
+ **Generator.TYPE_MAPPING,
+ exp.DataType.Type.TEXT: "STRING",
+ }
+
+ TRANSFORMS = {
+ **Generator.TRANSFORMS,
+ **transforms.UNALIAS_GROUP,
+ exp.AnonymousProperty: _property_sql,
+ exp.ApproxDistinct: approx_count_distinct_sql,
+ exp.ArrayAgg: rename_func("COLLECT_LIST"),
+ exp.ArraySize: rename_func("SIZE"),
+ exp.ArraySort: _array_sort,
+ exp.With: no_recursive_cte_sql,
+ exp.DateAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
+ exp.DateDiff: lambda self, e: f"DATEDIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
+ exp.DateStrToDate: rename_func("TO_DATE"),
+ exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.dateint_format}) AS INT)",
+ exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.dateint_format})",
+ exp.FileFormatProperty: lambda self, e: f"STORED AS {e.text('value').upper()}",
+ exp.If: if_sql,
+ exp.Index: _index_sql,
+ exp.ILike: no_ilike_sql,
+ exp.Join: _unnest_to_explode_sql,
+ exp.JSONExtract: rename_func("GET_JSON_OBJECT"),
+ exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"),
+ exp.Map: _map_sql,
+ HiveMap: _map_sql,
+ exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e.args['value'])}",
+ exp.Quantile: rename_func("PERCENTILE"),
+ exp.RegexpLike: lambda self, e: self.binary(e, "RLIKE"),
+ exp.RegexpSplit: rename_func("SPLIT"),
+ exp.SafeDivide: no_safe_divide_sql,
+ exp.SchemaCommentProperty: lambda self, e: f"COMMENT {self.sql(e.args['value'])}",
+ exp.SetAgg: rename_func("COLLECT_SET"),
+ exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))",
+ exp.StrPosition: lambda self, e: f"LOCATE({csv(self.sql(e, 'substr'), self.sql(e, 'this'), self.sql(e, 'position'))})",
+ exp.StrToDate: _str_to_date,
+ exp.StrToTime: _str_to_time,
+ exp.StrToUnix: _str_to_unix,
+ exp.StructExtract: struct_extract_sql,
+ exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'value')}",
+ exp.TimeStrToDate: rename_func("TO_DATE"),
+ exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)",
+ exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
+ exp.TimeToStr: _time_to_str,
+ exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"),
+ exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS STRING), '-', ''), 1, 8) AS INT)",
+ exp.TsOrDsAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
+ exp.TsOrDsToDate: _to_date_sql,
+ exp.TryCast: no_trycast_sql,
+ exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({csv(self.sql(e, 'this'), _time_format(self, e))})",
+ exp.UnixToTime: rename_func("FROM_UNIXTIME"),
+ exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"),
+ }
+
+ def with_properties(self, properties):
+ return self.properties(
+ properties,
+ prefix="TBLPROPERTIES",
+ )
+
+ def datatype_sql(self, expression):
+ if (
+ expression.this
+ in (exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR)
+ and not expression.expressions
+ ):
+ expression = exp.DataType.build("text")
+ return super().datatype_sql(expression)
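
_parse_map and _map_sql above are inverses when the keys and values are literal arrays: MAP(k1, v1, ...) is split into parallel arrays on parse and zipped back together on generation. A sketch:

    import sqlglot

    print(sqlglot.transpile("SELECT MAP('a', 1, 'b', 2)", read="hive")[0])
    # SELECT MAP('a', 1, 'b', 2)
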
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
new file mode 100644
index 0000000..93800a6
--- /dev/null
+++ b/sqlglot/dialects/mysql.py
@@ -0,0 +1,163 @@
+from sqlglot import exp
+from sqlglot.dialects.dialect import (
+ Dialect,
+ no_ilike_sql,
+ no_paren_current_date_sql,
+ no_tablesample_sql,
+ no_trycast_sql,
+)
+from sqlglot.generator import Generator
+from sqlglot.helper import list_get
+from sqlglot.parser import Parser
+from sqlglot.tokens import Tokenizer, TokenType
+
+
+def _date_trunc_sql(self, expression):
+ unit = expression.text("unit").lower()
+
+ this = self.sql(expression.this)
+
+ if unit == "day":
+ return f"DATE({this})"
+
+ if unit == "week":
+ concat = f"CONCAT(YEAR({this}), ' ', WEEK({this}, 1), ' 1')"
+ date_format = "%Y %u %w"
+ elif unit == "month":
+ concat = f"CONCAT(YEAR({this}), ' ', MONTH({this}), ' 1')"
+ date_format = "%Y %c %e"
+ elif unit == "quarter":
+ concat = f"CONCAT(YEAR({this}), ' ', QUARTER({this}) * 3 - 2, ' 1')"
+ date_format = "%Y %c %e"
+ elif unit == "year":
+ concat = f"CONCAT(YEAR({this}), ' 1 1')"
+ date_format = "%Y %c %e"
+ else:
+        self.unsupported(f"Unexpected interval unit: {unit}")
+ return f"DATE({this})"
+
+ return f"STR_TO_DATE({concat}, '{date_format}')"
+
+
+def _str_to_date(args):
+ date_format = MySQL.format_time(list_get(args, 1))
+ return exp.StrToDate(this=list_get(args, 0), format=date_format)
+
+
+def _str_to_date_sql(self, expression):
+ date_format = self.format_time(expression)
+ return f"STR_TO_DATE({self.sql(expression.this)}, {date_format})"
+
+
+def _date_add(expression_class):
+ def func(args):
+ interval = list_get(args, 1)
+ return expression_class(
+ this=list_get(args, 0),
+ expression=interval.this,
+ unit=exp.Literal.string(interval.text("unit").lower()),
+ )
+
+ return func
+
+
+def _date_add_sql(kind):
+ def func(self, expression):
+ this = self.sql(expression, "this")
+ unit = expression.text("unit").upper() or "DAY"
+ expression = self.sql(expression, "expression")
+ return f"DATE_{kind}({this}, INTERVAL {expression} {unit})"
+
+ return func
+
+
+class MySQL(Dialect):
+ # https://prestodb.io/docs/current/functions/datetime.html#mysql-date-functions
+ time_mapping = {
+ "%M": "%B",
+ "%c": "%-m",
+ "%e": "%-d",
+ "%h": "%I",
+ "%i": "%M",
+ "%s": "%S",
+ "%S": "%S",
+ "%u": "%W",
+ }
+
+ class Tokenizer(Tokenizer):
+ QUOTES = ["'", '"']
+ COMMENTS = ["--", "#", ("/*", "*/")]
+ IDENTIFIERS = ["`"]
+
+ KEYWORDS = {
+ **Tokenizer.KEYWORDS,
+ "_ARMSCII8": TokenType.INTRODUCER,
+ "_ASCII": TokenType.INTRODUCER,
+ "_BIG5": TokenType.INTRODUCER,
+ "_BINARY": TokenType.INTRODUCER,
+ "_CP1250": TokenType.INTRODUCER,
+ "_CP1251": TokenType.INTRODUCER,
+ "_CP1256": TokenType.INTRODUCER,
+ "_CP1257": TokenType.INTRODUCER,
+ "_CP850": TokenType.INTRODUCER,
+ "_CP852": TokenType.INTRODUCER,
+ "_CP866": TokenType.INTRODUCER,
+ "_CP932": TokenType.INTRODUCER,
+ "_DEC8": TokenType.INTRODUCER,
+ "_EUCJPMS": TokenType.INTRODUCER,
+ "_EUCKR": TokenType.INTRODUCER,
+ "_GB18030": TokenType.INTRODUCER,
+ "_GB2312": TokenType.INTRODUCER,
+ "_GBK": TokenType.INTRODUCER,
+ "_GEOSTD8": TokenType.INTRODUCER,
+ "_GREEK": TokenType.INTRODUCER,
+ "_HEBREW": TokenType.INTRODUCER,
+ "_HP8": TokenType.INTRODUCER,
+ "_KEYBCS2": TokenType.INTRODUCER,
+ "_KOI8R": TokenType.INTRODUCER,
+ "_KOI8U": TokenType.INTRODUCER,
+ "_LATIN1": TokenType.INTRODUCER,
+ "_LATIN2": TokenType.INTRODUCER,
+ "_LATIN5": TokenType.INTRODUCER,
+ "_LATIN7": TokenType.INTRODUCER,
+ "_MACCE": TokenType.INTRODUCER,
+ "_MACROMAN": TokenType.INTRODUCER,
+ "_SJIS": TokenType.INTRODUCER,
+ "_SWE7": TokenType.INTRODUCER,
+ "_TIS620": TokenType.INTRODUCER,
+ "_UCS2": TokenType.INTRODUCER,
+ "_UJIS": TokenType.INTRODUCER,
+ "_UTF8": TokenType.INTRODUCER,
+ "_UTF16": TokenType.INTRODUCER,
+ "_UTF16LE": TokenType.INTRODUCER,
+ "_UTF32": TokenType.INTRODUCER,
+ "_UTF8MB3": TokenType.INTRODUCER,
+ "_UTF8MB4": TokenType.INTRODUCER,
+ }
+
+ class Parser(Parser):
+ STRICT_CAST = False
+
+ FUNCTIONS = {
+ **Parser.FUNCTIONS,
+ "DATE_ADD": _date_add(exp.DateAdd),
+ "DATE_SUB": _date_add(exp.DateSub),
+ "STR_TO_DATE": _str_to_date,
+ }
+
+ class Generator(Generator):
+ NULL_ORDERING_SUPPORTED = False
+
+ TRANSFORMS = {
+ **Generator.TRANSFORMS,
+ exp.CurrentDate: no_paren_current_date_sql,
+ exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
+ exp.ILike: no_ilike_sql,
+ exp.TableSample: no_tablesample_sql,
+ exp.TryCast: no_trycast_sql,
+ exp.DateAdd: _date_add_sql("ADD"),
+ exp.DateSub: _date_add_sql("SUB"),
+ exp.DateTrunc: _date_trunc_sql,
+ exp.StrToDate: _str_to_date_sql,
+ exp.StrToTime: _str_to_date_sql,
+ }
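
STR_TO_DATE shows both halves of the dialect: _str_to_date rewrites the MySQL format into the internal strftime notation on parse, and _str_to_date_sql translates it back on generation. A sketch:

    import sqlglot
    from sqlglot import exp

    tree = sqlglot.parse_one("STR_TO_DATE(x, '%d %M %Y')", read="mysql")
    assert isinstance(tree, exp.StrToDate)  # format is held internally as '%d %B %Y'
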
diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py
new file mode 100644
index 0000000..9c8b6f2
--- /dev/null
+++ b/sqlglot/dialects/oracle.py
@@ -0,0 +1,63 @@
+from sqlglot import exp, transforms
+from sqlglot.dialects.dialect import Dialect, no_ilike_sql
+from sqlglot.generator import Generator
+from sqlglot.helper import csv
+from sqlglot.tokens import Tokenizer, TokenType
+
+
+def _limit_sql(self, expression):
+ return self.fetch_sql(exp.Fetch(direction="FIRST", count=expression.expression))
+
+
+class Oracle(Dialect):
+ class Generator(Generator):
+ TYPE_MAPPING = {
+ **Generator.TYPE_MAPPING,
+ exp.DataType.Type.TINYINT: "NUMBER",
+ exp.DataType.Type.SMALLINT: "NUMBER",
+ exp.DataType.Type.INT: "NUMBER",
+ exp.DataType.Type.BIGINT: "NUMBER",
+ exp.DataType.Type.DECIMAL: "NUMBER",
+ exp.DataType.Type.DOUBLE: "DOUBLE PRECISION",
+ exp.DataType.Type.VARCHAR: "VARCHAR2",
+ exp.DataType.Type.NVARCHAR: "NVARCHAR2",
+ exp.DataType.Type.TEXT: "CLOB",
+ exp.DataType.Type.BINARY: "BLOB",
+ }
+
+ TRANSFORMS = {
+ **Generator.TRANSFORMS,
+ **transforms.UNALIAS_GROUP,
+ exp.ILike: no_ilike_sql,
+ exp.Limit: _limit_sql,
+ }
+
+ def query_modifiers(self, expression, *sqls):
+ return csv(
+ *sqls,
+ *[self.sql(sql) for sql in expression.args.get("laterals", [])],
+ *[self.sql(sql) for sql in expression.args.get("joins", [])],
+ self.sql(expression, "where"),
+ self.sql(expression, "group"),
+ self.sql(expression, "having"),
+ self.sql(expression, "qualify"),
+ self.sql(expression, "window"),
+ self.sql(expression, "distribute"),
+ self.sql(expression, "sort"),
+ self.sql(expression, "cluster"),
+ self.sql(expression, "order"),
+ self.sql(expression, "offset"), # offset before limit in oracle
+ self.sql(expression, "limit"),
+ sep="",
+ )
+
+ def offset_sql(self, expression):
+ return f"{super().offset_sql(expression)} ROWS"
+
+ class Tokenizer(Tokenizer):
+ KEYWORDS = {
+ **Tokenizer.KEYWORDS,
+ "TOP": TokenType.TOP,
+ "VARCHAR2": TokenType.VARCHAR,
+ "NVARCHAR2": TokenType.NVARCHAR,
+ }
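
_limit_sql above rewrites LIMIT as a FETCH clause, since Oracle has no LIMIT keyword. A sketch (output indicative):

    import sqlglot

    print(sqlglot.transpile("SELECT x FROM t LIMIT 10", write="oracle")[0])
    # SELECT x FROM t FETCH FIRST 10 ROWS ONLY
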
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
new file mode 100644
index 0000000..61dff86
--- /dev/null
+++ b/sqlglot/dialects/postgres.py
@@ -0,0 +1,109 @@
+from sqlglot import exp
+from sqlglot.dialects.dialect import (
+ Dialect,
+ arrow_json_extract_scalar_sql,
+ arrow_json_extract_sql,
+ format_time_lambda,
+ no_paren_current_date_sql,
+ no_tablesample_sql,
+ no_trycast_sql,
+)
+from sqlglot.generator import Generator
+from sqlglot.parser import Parser
+from sqlglot.tokens import Tokenizer, TokenType
+
+
+def _date_add_sql(kind):
+ def func(self, expression):
+ from sqlglot.optimizer.simplify import simplify
+
+ this = self.sql(expression, "this")
+ unit = self.sql(expression, "unit")
+ expression = simplify(expression.args["expression"])
+
+ if not isinstance(expression, exp.Literal):
+ self.unsupported("Cannot add non literal")
+
+ expression = expression.copy()
+ expression.args["is_string"] = True
+ expression = self.sql(expression)
+ return f"{this} {kind} INTERVAL {expression} {unit}"
+
+ return func
+
+
+class Postgres(Dialect):
+ null_ordering = "nulls_are_large"
+ time_format = "'YYYY-MM-DD HH24:MI:SS'"
+ time_mapping = {
+ "AM": "%p", # AM or PM
+ "D": "%w", # 1-based day of week
+ "DD": "%d", # day of month
+ "DDD": "%j", # zero padded day of year
+        "FMDD": "%-d",  # "-" drops the leading zero in Python strftime; FM does the same in Postgres
+ "FMDDD": "%-j", # day of year
+ "FMHH12": "%-I", # 9
+ "FMHH24": "%-H", # 9
+ "FMMI": "%-M", # Minute
+ "FMMM": "%-m", # 1
+ "FMSS": "%-S", # Second
+ "HH12": "%I", # 09
+ "HH24": "%H", # 09
+ "MI": "%M", # zero padded minute
+ "MM": "%m", # 01
+ "OF": "%z", # utc offset
+ "SS": "%S", # zero padded second
+ "TMDay": "%A", # TM is locale dependent
+ "TMDy": "%a",
+ "TMMon": "%b", # Sep
+ "TMMonth": "%B", # September
+ "TZ": "%Z", # uppercase timezone name
+ "US": "%f", # zero padded microsecond
+ "WW": "%U", # 1-based week of year
+ "YY": "%y", # 15
+ "YYYY": "%Y", # 2015
+ }
+
+ class Tokenizer(Tokenizer):
+ KEYWORDS = {
+ **Tokenizer.KEYWORDS,
+ "SERIAL": TokenType.AUTO_INCREMENT,
+ "UUID": TokenType.UUID,
+ }
+
+ class Parser(Parser):
+ STRICT_CAST = False
+ FUNCTIONS = {
+ **Parser.FUNCTIONS,
+ "TO_TIMESTAMP": format_time_lambda(exp.StrToTime, "postgres"),
+ "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
+ }
+
+ class Generator(Generator):
+ TYPE_MAPPING = {
+ **Generator.TYPE_MAPPING,
+ exp.DataType.Type.TINYINT: "SMALLINT",
+ exp.DataType.Type.FLOAT: "REAL",
+ exp.DataType.Type.DOUBLE: "DOUBLE PRECISION",
+ exp.DataType.Type.BINARY: "BYTEA",
+ }
+
+ TOKEN_MAPPING = {
+ TokenType.AUTO_INCREMENT: "SERIAL",
+ }
+
+ TRANSFORMS = {
+ **Generator.TRANSFORMS,
+ exp.JSONExtract: arrow_json_extract_sql,
+ exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
+ exp.JSONBExtract: lambda self, e: f"{self.sql(e, 'this')}#>{self.sql(e, 'path')}",
+ exp.JSONBExtractScalar: lambda self, e: f"{self.sql(e, 'this')}#>>{self.sql(e, 'path')}",
+ exp.CurrentDate: no_paren_current_date_sql,
+ exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
+ exp.DateAdd: _date_add_sql("+"),
+ exp.DateSub: _date_add_sql("-"),
+ exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.TableSample: no_tablesample_sql,
+ exp.TryCast: no_trycast_sql,
+ }
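
TO_CHAR above is parsed through format_time_lambda, so its format string travels in the internal strftime notation and can be re-rendered for another dialect. A sketch (output indicative):

    import sqlglot

    print(sqlglot.transpile("SELECT TO_CHAR(x, 'YYYY-MM-DD')", read="postgres", write="duckdb")[0])
    # SELECT STRFTIME(x, '%Y-%m-%d')
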
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
new file mode 100644
index 0000000..ca913e4
--- /dev/null
+++ b/sqlglot/dialects/presto.py
@@ -0,0 +1,216 @@
+from sqlglot import exp, transforms
+from sqlglot.dialects.dialect import (
+ Dialect,
+ format_time_lambda,
+ if_sql,
+ no_ilike_sql,
+ no_safe_divide_sql,
+ rename_func,
+ str_position_sql,
+ struct_extract_sql,
+)
+from sqlglot.dialects.mysql import MySQL
+from sqlglot.generator import Generator
+from sqlglot.helper import csv, list_get
+from sqlglot.parser import Parser
+from sqlglot.tokens import Tokenizer, TokenType
+
+
+def _approx_distinct_sql(self, expression):
+ accuracy = expression.args.get("accuracy")
+ accuracy = ", " + self.sql(accuracy) if accuracy else ""
+ return f"APPROX_DISTINCT({self.sql(expression, 'this')}{accuracy})"
+
+
+def _concat_ws_sql(self, expression):
+ sep, *args = expression.expressions
+ sep = self.sql(sep)
+ if len(args) > 1:
+ return f"ARRAY_JOIN(ARRAY[{csv(*(self.sql(e) for e in args))}], {sep})"
+ return f"ARRAY_JOIN({self.sql(args[0])}, {sep})"
+
+
+def _datatype_sql(self, expression):
+ sql = self.datatype_sql(expression)
+ if expression.this == exp.DataType.Type.TIMESTAMPTZ:
+ sql = f"{sql} WITH TIME ZONE"
+ return sql
+
+
+def _date_parse_sql(self, expression):
+ return f"DATE_PARSE({self.sql(expression, 'this')}, '%Y-%m-%d %H:%i:%s')"
+
+
+def _explode_to_unnest_sql(self, expression):
+ if isinstance(expression.this, (exp.Explode, exp.Posexplode)):
+ return self.sql(
+ exp.Join(
+ this=exp.Unnest(
+ expressions=[expression.this.this],
+ alias=expression.args.get("alias"),
+ ordinality=isinstance(expression.this, exp.Posexplode),
+ ),
+ kind="cross",
+ )
+ )
+ return self.lateral_sql(expression)
+
+
+def _initcap_sql(self, expression):
+ regex = r"(\w)(\w*)"
+ return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))"
+
+
+def _no_sort_array(self, expression):
+ if expression.args.get("asc") == exp.FALSE:
+ comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END"
+ else:
+ comparator = None
+ args = csv(self.sql(expression, "this"), comparator)
+ return f"ARRAY_SORT({args})"
+
+
+def _schema_sql(self, expression):
+ if isinstance(expression.parent, exp.Property):
+ columns = ", ".join(f"'{c.text('this')}'" for c in expression.expressions)
+ return f"ARRAY[{columns}]"
+
+ for schema in expression.parent.find_all(exp.Schema):
+ if isinstance(schema.parent, exp.Property):
+ expression = expression.copy()
+ expression.expressions.extend(schema.expressions)
+
+ return self.schema_sql(expression)
+
+
+def _quantile_sql(self, expression):
+ self.unsupported("Presto does not support exact quantiles")
+ return f"APPROX_PERCENTILE({self.sql(expression, 'this')}, {self.sql(expression, 'quantile')})"
+
+
+def _str_to_time_sql(self, expression):
+ return f"DATE_PARSE({self.sql(expression, 'this')}, {self.format_time(expression)})"
+
+
+def _ts_or_ds_to_date_sql(self, expression):
+ time_format = self.format_time(expression)
+ if time_format and time_format not in (Presto.time_format, Presto.date_format):
+ return f"CAST({_str_to_time_sql(self, expression)} AS DATE)"
+ return (
+ f"CAST(SUBSTR(CAST({self.sql(expression, 'this')} AS VARCHAR), 1, 10) AS DATE)"
+ )
+
+
+def _ts_or_ds_add_sql(self, expression):
+ this = self.sql(expression, "this")
+ e = self.sql(expression, "expression")
+ unit = self.sql(expression, "unit") or "'day'"
+ return f"DATE_ADD({unit}, {e}, DATE_PARSE(SUBSTR({this}, 1, 10), {Presto.date_format}))"
+
+
+class Presto(Dialect):
+ index_offset = 1
+ null_ordering = "nulls_are_last"
+ time_format = "'%Y-%m-%d %H:%i:%S'"
+ time_mapping = MySQL.time_mapping
+
+ class Tokenizer(Tokenizer):
+ KEYWORDS = {
+ **Tokenizer.KEYWORDS,
+ "ROW": TokenType.STRUCT,
+ }
+
+ class Parser(Parser):
+ FUNCTIONS = {
+ **Parser.FUNCTIONS,
+ "APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list,
+ "CARDINALITY": exp.ArraySize.from_arg_list,
+ "CONTAINS": exp.ArrayContains.from_arg_list,
+ "DATE_ADD": lambda args: exp.DateAdd(
+ this=list_get(args, 2),
+ expression=list_get(args, 1),
+ unit=list_get(args, 0),
+ ),
+ "DATE_DIFF": lambda args: exp.DateDiff(
+ this=list_get(args, 2),
+ expression=list_get(args, 1),
+ unit=list_get(args, 0),
+ ),
+ "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"),
+ "DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"),
+ "FROM_UNIXTIME": exp.UnixToTime.from_arg_list,
+ "STRPOS": exp.StrPosition.from_arg_list,
+ "TO_UNIXTIME": exp.TimeToUnix.from_arg_list,
+ }
+
+ class Generator(Generator):
+
+ STRUCT_DELIMITER = ("(", ")")
+
+ WITH_PROPERTIES = [
+ exp.PartitionedByProperty,
+ exp.FileFormatProperty,
+ exp.SchemaCommentProperty,
+ exp.AnonymousProperty,
+ exp.TableFormatProperty,
+ ]
+
+ TYPE_MAPPING = {
+ **Generator.TYPE_MAPPING,
+ exp.DataType.Type.INT: "INTEGER",
+ exp.DataType.Type.FLOAT: "REAL",
+ exp.DataType.Type.BINARY: "VARBINARY",
+ exp.DataType.Type.TEXT: "VARCHAR",
+ exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
+ exp.DataType.Type.STRUCT: "ROW",
+ }
+
+ TRANSFORMS = {
+ **Generator.TRANSFORMS,
+ **transforms.UNALIAS_GROUP,
+ exp.ApproxDistinct: _approx_distinct_sql,
+ exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
+ exp.ArrayContains: rename_func("CONTAINS"),
+ exp.ArraySize: rename_func("CARDINALITY"),
+ exp.BitwiseAnd: lambda self, e: f"BITWISE_AND({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
+ exp.BitwiseLeftShift: lambda self, e: f"BITWISE_ARITHMETIC_SHIFT_LEFT({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
+ exp.BitwiseNot: lambda self, e: f"BITWISE_NOT({self.sql(e, 'this')})",
+ exp.BitwiseOr: lambda self, e: f"BITWISE_OR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
+ exp.BitwiseRightShift: lambda self, e: f"BITWISE_ARITHMETIC_SHIFT_RIGHT({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
+ exp.BitwiseXor: lambda self, e: f"BITWISE_XOR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
+ exp.ConcatWs: _concat_ws_sql,
+ exp.DataType: _datatype_sql,
+ exp.DateAdd: lambda self, e: f"""DATE_ADD({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""",
+ exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""",
+ exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.date_format}) AS DATE)",
+ exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.dateint_format}) AS INT)",
+ exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)",
+ exp.FileFormatProperty: lambda self, e: self.property_sql(e),
+ exp.If: if_sql,
+ exp.ILike: no_ilike_sql,
+ exp.Initcap: _initcap_sql,
+ exp.Lateral: _explode_to_unnest_sql,
+ exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
+ exp.PartitionedByProperty: lambda self, e: f"PARTITIONED_BY = {self.sql(e.args['value'])}",
+ exp.Quantile: _quantile_sql,
+ exp.SafeDivide: no_safe_divide_sql,
+ exp.Schema: _schema_sql,
+ exp.SortArray: _no_sort_array,
+ exp.StrPosition: str_position_sql,
+ exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)",
+ exp.StrToTime: _str_to_time_sql,
+ exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))",
+ exp.StructExtract: struct_extract_sql,
+ exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT = '{e.text('value').upper()}'",
+ exp.TimeStrToDate: _date_parse_sql,
+ exp.TimeStrToTime: _date_parse_sql,
+ exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.time_format}))",
+ exp.TimeToStr: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.TimeToUnix: rename_func("TO_UNIXTIME"),
+ exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
+ exp.TsOrDsAdd: _ts_or_ds_add_sql,
+ exp.TsOrDsToDate: _ts_or_ds_to_date_sql,
+ exp.UnixToStr: lambda self, e: f"DATE_FORMAT(FROM_UNIXTIME({self.sql(e, 'this')}), {self.format_time(e)})",
+ exp.UnixToTime: rename_func("FROM_UNIXTIME"),
+ exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)",
+ }
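
The FUNCTIONS and TRANSFORMS tables meet in expressions like exp.ApproxDistinct: Hive's APPROX_COUNT_DISTINCT parses into it, and _approx_distinct_sql above renders it back out for Presto. A sketch:

    import sqlglot

    print(sqlglot.transpile("SELECT APPROX_COUNT_DISTINCT(a) FROM t", read="hive", write="presto")[0])
    # SELECT APPROX_DISTINCT(a) FROM t
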
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
new file mode 100644
index 0000000..148dfb5
--- /dev/null
+++ b/sqlglot/dialects/snowflake.py
@@ -0,0 +1,145 @@
+from sqlglot import exp
+from sqlglot.dialects.dialect import Dialect, format_time_lambda, rename_func
+from sqlglot.expressions import Literal
+from sqlglot.generator import Generator
+from sqlglot.helper import list_get
+from sqlglot.parser import Parser
+from sqlglot.tokens import Tokenizer, TokenType
+
+
+def _check_int(s):
+ if s[0] in ("-", "+"):
+ return s[1:].isdigit()
+ return s.isdigit()
+
+
+# from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html
+def _snowflake_to_timestamp(args):
+ if len(args) == 2:
+ first_arg, second_arg = args
+ if second_arg.is_string:
+ # case: <string_expr> [ , <format> ]
+ return format_time_lambda(exp.StrToTime, "snowflake")(args)
+
+ # case: <numeric_expr> [ , <scale> ]
+ if second_arg.name not in ["0", "3", "9"]:
+ raise ValueError(
+ f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9"
+ )
+
+ if second_arg.name == "0":
+ timescale = exp.UnixToTime.SECONDS
+ elif second_arg.name == "3":
+ timescale = exp.UnixToTime.MILLIS
+ elif second_arg.name == "9":
+ timescale = exp.UnixToTime.MICROS
+
+ return exp.UnixToTime(this=first_arg, scale=timescale)
+
+ first_arg = list_get(args, 0)
+ if not isinstance(first_arg, Literal):
+ # case: <variant_expr>
+ return format_time_lambda(exp.StrToTime, "snowflake", default=True)(args)
+
+ if first_arg.is_string:
+ if _check_int(first_arg.this):
+ # case: <integer>
+ return exp.UnixToTime.from_arg_list(args)
+
+ # case: <date_expr>
+ return format_time_lambda(exp.StrToTime, "snowflake", default=True)(args)
+
+ # case: <numeric_expr>
+ return exp.UnixToTime.from_arg_list(args)
+
+
+def _unix_to_time(self, expression):
+ scale = expression.args.get("scale")
+ timestamp = self.sql(expression, "this")
+ if scale in [None, exp.UnixToTime.SECONDS]:
+ return f"TO_TIMESTAMP({timestamp})"
+ if scale == exp.UnixToTime.MILLIS:
+ return f"TO_TIMESTAMP({timestamp}, 3)"
+ if scale == exp.UnixToTime.MICROS:
+ return f"TO_TIMESTAMP({timestamp}, 9)"
+
+ raise ValueError("Improper scale for timestamp")
+
+
+class Snowflake(Dialect):
+ null_ordering = "nulls_are_large"
+ time_format = "'yyyy-mm-dd hh24:mi:ss'"
+
+ time_mapping = {
+ "YYYY": "%Y",
+ "yyyy": "%Y",
+ "YY": "%y",
+ "yy": "%y",
+ "MMMM": "%B",
+ "mmmm": "%B",
+ "MON": "%b",
+ "mon": "%b",
+ "MM": "%m",
+ "mm": "%m",
+ "DD": "%d",
+ "dd": "%d",
+ "d": "%-d",
+ "DY": "%w",
+ "dy": "%w",
+ "HH24": "%H",
+ "hh24": "%H",
+ "HH12": "%I",
+ "hh12": "%I",
+ "MI": "%M",
+ "mi": "%M",
+ "SS": "%S",
+ "ss": "%S",
+ "FF": "%f",
+ "ff": "%f",
+ "FF6": "%f",
+ "ff6": "%f",
+ }
+
+ class Parser(Parser):
+ FUNCTIONS = {
+ **Parser.FUNCTIONS,
+ "ARRAYAGG": exp.ArrayAgg.from_arg_list,
+ "IFF": exp.If.from_arg_list,
+ "TO_TIMESTAMP": _snowflake_to_timestamp,
+ }
+
+ COLUMN_OPERATORS = {
+ **Parser.COLUMN_OPERATORS,
+ TokenType.COLON: lambda self, this, path: self.expression(
+ exp.Bracket,
+ this=this,
+ expressions=[path],
+ ),
+ }
+
+ class Tokenizer(Tokenizer):
+ QUOTES = ["'", "$$"]
+ ESCAPE = "\\"
+ KEYWORDS = {
+ **Tokenizer.KEYWORDS,
+ "QUALIFY": TokenType.QUALIFY,
+ "DOUBLE PRECISION": TokenType.DOUBLE,
+ }
+
+ class Generator(Generator):
+ TRANSFORMS = {
+ **Generator.TRANSFORMS,
+ exp.If: rename_func("IFF"),
+ exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.UnixToTime: _unix_to_time,
+ }
+
+ def except_op(self, expression):
+ if not expression.args.get("distinct", False):
+ self.unsupported("EXCEPT with All is not supported in Snowflake")
+ return super().except_op(expression)
+
+ def intersect_op(self, expression):
+ if not expression.args.get("distinct", False):
+ self.unsupported("INTERSECT with All is not supported in Snowflake")
+ return super().intersect_op(expression)
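+
+
+# Illustrative (assumed): parse_one("SELECT IFF(a, 1, 2)", read="snowflake")
+# produces an exp.If node, which this generator renders back as IFF(a, 1, 2).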
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
new file mode 100644
index 0000000..89c7ed5
--- /dev/null
+++ b/sqlglot/dialects/spark.py
@@ -0,0 +1,106 @@
+from sqlglot import exp
+from sqlglot.dialects.dialect import no_ilike_sql, rename_func
+from sqlglot.dialects.hive import Hive, HiveMap
+from sqlglot.helper import list_get
+
+
+def _create_sql(self, e):
+ kind = e.args.get("kind")
+ temporary = e.args.get("temporary")
+
+ if kind.upper() == "TABLE" and temporary is True:
+ return f"CREATE TEMPORARY VIEW {self.sql(e, 'this')} AS {self.sql(e, 'expression')}"
+ return self.create_sql(e)
+
+
+def _map_sql(self, expression):
+ keys = self.sql(expression.args["keys"])
+ values = self.sql(expression.args["values"])
+ return f"MAP_FROM_ARRAYS({keys}, {values})"
+
+
+def _str_to_date(self, expression):
+ this = self.sql(expression, "this")
+ time_format = self.format_time(expression)
+ if time_format == Hive.date_format:
+ return f"TO_DATE({this})"
+ return f"TO_DATE({this}, {time_format})"
+
+
+def _unix_to_time(self, expression):
+ scale = expression.args.get("scale")
+ timestamp = self.sql(expression, "this")
+ if scale is None:
+ return f"FROM_UNIXTIME({timestamp})"
+ if scale == exp.UnixToTime.SECONDS:
+ return f"TIMESTAMP_SECONDS({timestamp})"
+ if scale == exp.UnixToTime.MILLIS:
+ return f"TIMESTAMP_MILLIS({timestamp})"
+ if scale == exp.UnixToTime.MICROS:
+ return f"TIMESTAMP_MICROS({timestamp})"
+
+ raise ValueError("Improper scale for timestamp")
+
+
+class Spark(Hive):
+ class Parser(Hive.Parser):
+ FUNCTIONS = {
+ **Hive.Parser.FUNCTIONS,
+ "MAP_FROM_ARRAYS": exp.Map.from_arg_list,
+ "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
+ "LEFT": lambda args: exp.Substring(
+ this=list_get(args, 0),
+ start=exp.Literal.number(1),
+ length=list_get(args, 1),
+ ),
+ "SHIFTLEFT": lambda args: exp.BitwiseLeftShift(
+ this=list_get(args, 0),
+ expression=list_get(args, 1),
+ ),
+ "SHIFTRIGHT": lambda args: exp.BitwiseRightShift(
+ this=list_get(args, 0),
+ expression=list_get(args, 1),
+ ),
+ "RIGHT": lambda args: exp.Substring(
+ this=list_get(args, 0),
+ start=exp.Sub(
+ this=exp.Length(this=list_get(args, 0)),
+ expression=exp.Add(
+ this=list_get(args, 1), expression=exp.Literal.number(1)
+ ),
+ ),
+ length=list_get(args, 1),
+ ),
+ }
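+
+        # Illustrative (assumed): with the RIGHT rewrite above, RIGHT(s, 3) is
+        # parsed as, conceptually, SUBSTRING(s, LENGTH(s) - 3 + 1, 3).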
+
+ class Generator(Hive.Generator):
+ TYPE_MAPPING = {
+ **Hive.Generator.TYPE_MAPPING,
+ exp.DataType.Type.TINYINT: "BYTE",
+ exp.DataType.Type.SMALLINT: "SHORT",
+ exp.DataType.Type.BIGINT: "LONG",
+ }
+
+ TRANSFORMS = {
+ **{
+ k: v
+ for k, v in Hive.Generator.TRANSFORMS.items()
+ if k not in {exp.ArraySort}
+ },
+ exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
+ exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
+ exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
+ exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
+ exp.ILike: no_ilike_sql,
+ exp.StrToDate: _str_to_date,
+ exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.UnixToTime: _unix_to_time,
+ exp.Create: _create_sql,
+ exp.Map: _map_sql,
+ exp.Reduce: rename_func("AGGREGATE"),
+ exp.StructKwarg: lambda self, e: f"{self.sql(e, 'this')}: {self.sql(e, 'expression')}",
+ HiveMap: _map_sql,
+ }
+
+ def bitstring_sql(self, expression):
+ return f"X'{self.sql(expression, 'this')}'"
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
new file mode 100644
index 0000000..6cf5022
--- /dev/null
+++ b/sqlglot/dialects/sqlite.py
@@ -0,0 +1,63 @@
+from sqlglot import exp
+from sqlglot.dialects.dialect import (
+ Dialect,
+ arrow_json_extract_scalar_sql,
+ arrow_json_extract_sql,
+ no_ilike_sql,
+ no_tablesample_sql,
+ no_trycast_sql,
+ rename_func,
+)
+from sqlglot.generator import Generator
+from sqlglot.parser import Parser
+from sqlglot.tokens import Tokenizer, TokenType
+
+
+class SQLite(Dialect):
+ class Tokenizer(Tokenizer):
+ IDENTIFIERS = ['"', ("[", "]"), "`"]
+
+ KEYWORDS = {
+ **Tokenizer.KEYWORDS,
+ "AUTOINCREMENT": TokenType.AUTO_INCREMENT,
+ }
+
+ class Parser(Parser):
+ FUNCTIONS = {
+ **Parser.FUNCTIONS,
+ "EDITDIST3": exp.Levenshtein.from_arg_list,
+ }
+
+ class Generator(Generator):
+ TYPE_MAPPING = {
+ **Generator.TYPE_MAPPING,
+ exp.DataType.Type.BOOLEAN: "INTEGER",
+ exp.DataType.Type.TINYINT: "INTEGER",
+ exp.DataType.Type.SMALLINT: "INTEGER",
+ exp.DataType.Type.INT: "INTEGER",
+ exp.DataType.Type.BIGINT: "INTEGER",
+ exp.DataType.Type.FLOAT: "REAL",
+ exp.DataType.Type.DOUBLE: "REAL",
+ exp.DataType.Type.DECIMAL: "REAL",
+ exp.DataType.Type.CHAR: "TEXT",
+ exp.DataType.Type.NCHAR: "TEXT",
+ exp.DataType.Type.VARCHAR: "TEXT",
+ exp.DataType.Type.NVARCHAR: "TEXT",
+ exp.DataType.Type.BINARY: "BLOB",
+ }
+
+ TOKEN_MAPPING = {
+ TokenType.AUTO_INCREMENT: "AUTOINCREMENT",
+ }
+
+ TRANSFORMS = {
+ **Generator.TRANSFORMS,
+ exp.ILike: no_ilike_sql,
+ exp.JSONExtract: arrow_json_extract_sql,
+ exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
+ exp.JSONBExtract: arrow_json_extract_sql,
+ exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
+ exp.Levenshtein: rename_func("EDITDIST3"),
+ exp.TableSample: no_tablesample_sql,
+ exp.TryCast: no_trycast_sql,
+ }
diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py
new file mode 100644
index 0000000..b9cd584
--- /dev/null
+++ b/sqlglot/dialects/starrocks.py
@@ -0,0 +1,12 @@
+from sqlglot import exp
+from sqlglot.dialects.mysql import MySQL
+
+
+class StarRocks(MySQL):
+ class Generator(MySQL.Generator):
+ TYPE_MAPPING = {
+ **MySQL.Generator.TYPE_MAPPING,
+ exp.DataType.Type.TEXT: "STRING",
+ exp.DataType.Type.TIMESTAMP: "DATETIME",
+ exp.DataType.Type.TIMESTAMPTZ: "DATETIME",
+ }
diff --git a/sqlglot/dialects/tableau.py b/sqlglot/dialects/tableau.py
new file mode 100644
index 0000000..e571749
--- /dev/null
+++ b/sqlglot/dialects/tableau.py
@@ -0,0 +1,37 @@
+from sqlglot import exp
+from sqlglot.dialects.dialect import Dialect
+from sqlglot.generator import Generator
+from sqlglot.helper import list_get
+from sqlglot.parser import Parser
+
+
+def _if_sql(self, expression):
+ return f"IF {self.sql(expression, 'this')} THEN {self.sql(expression, 'true')} ELSE {self.sql(expression, 'false')} END"
+
+
+def _coalesce_sql(self, expression):
+ return f"IFNULL({self.sql(expression, 'this')}, {self.expressions(expression)})"
+
+
+def _count_sql(self, expression):
+ this = expression.this
+ if isinstance(this, exp.Distinct):
+ return f"COUNTD({self.sql(this, 'this')})"
+ return f"COUNT({self.sql(expression, 'this')})"
+
+
+class Tableau(Dialect):
+ class Generator(Generator):
+ TRANSFORMS = {
+ **Generator.TRANSFORMS,
+ exp.If: _if_sql,
+ exp.Coalesce: _coalesce_sql,
+ exp.Count: _count_sql,
+ }
+
+ class Parser(Parser):
+ FUNCTIONS = {
+ **Parser.FUNCTIONS,
+ "IFNULL": exp.Coalesce.from_arg_list,
+ "COUNTD": lambda args: exp.Count(this=exp.Distinct(this=list_get(args, 0))),
+ }
diff --git a/sqlglot/dialects/trino.py b/sqlglot/dialects/trino.py
new file mode 100644
index 0000000..805106c
--- /dev/null
+++ b/sqlglot/dialects/trino.py
@@ -0,0 +1,10 @@
+from sqlglot import exp
+from sqlglot.dialects.presto import Presto
+
+
+class Trino(Presto):
+ class Generator(Presto.Generator):
+ TRANSFORMS = {
+ **Presto.Generator.TRANSFORMS,
+ exp.ArraySum: lambda self, e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
+ }
diff --git a/sqlglot/diff.py b/sqlglot/diff.py
new file mode 100644
index 0000000..8eeb4e9
--- /dev/null
+++ b/sqlglot/diff.py
@@ -0,0 +1,314 @@
+from collections import defaultdict
+from dataclasses import dataclass
+from heapq import heappop, heappush
+
+from sqlglot import Dialect
+from sqlglot import expressions as exp
+from sqlglot.helper import ensure_list
+
+
+@dataclass(frozen=True)
+class Insert:
+ """Indicates that a new node has been inserted"""
+
+ expression: exp.Expression
+
+
+@dataclass(frozen=True)
+class Remove:
+ """Indicates that an existing node has been removed"""
+
+ expression: exp.Expression
+
+
+@dataclass(frozen=True)
+class Move:
+ """Indicates that an existing node's position within the tree has changed"""
+
+ expression: exp.Expression
+
+
+@dataclass(frozen=True)
+class Update:
+ """Indicates that an existing node has been updated"""
+
+ source: exp.Expression
+ target: exp.Expression
+
+
+@dataclass(frozen=True)
+class Keep:
+ """Indicates that an existing node hasn't been changed"""
+
+ source: exp.Expression
+ target: exp.Expression
+
+
+def diff(source, target):
+ """
+ Returns the list of changes between the source and the target expressions.
+
+ Examples:
+ >>> diff(parse_one("a + b"), parse_one("a + c"))
+ [
+ Remove(expression=(COLUMN this: (IDENTIFIER this: b, quoted: False))),
+ Insert(expression=(COLUMN this: (IDENTIFIER this: c, quoted: False))),
+ Keep(
+ source=(ADD this: ...),
+ target=(ADD this: ...)
+ ),
+ Keep(
+ source=(COLUMN this: (IDENTIFIER this: a, quoted: False)),
+ target=(COLUMN this: (IDENTIFIER this: a, quoted: False))
+ ),
+ ]
+
+ Args:
+ source (sqlglot.Expression): the source expression.
+ target (sqlglot.Expression): the target expression against which the diff should be calculated.
+
+ Returns:
+ the list of Insert, Remove, Move, Update and Keep objects for each node in the source and the target expression trees.
+ This list represents a sequence of steps needed to transform the source expression tree into the target one.
+ """
+ return ChangeDistiller().diff(source.copy(), target.copy())
+
+
+LEAF_EXPRESSION_TYPES = (
+ exp.Boolean,
+ exp.DataType,
+ exp.Identifier,
+ exp.Literal,
+)
+
+
+class ChangeDistiller:
+ """
+ The implementation of the Change Distiller algorithm described by Beat Fluri and Martin Pinzger in
+ their paper https://ieeexplore.ieee.org/document/4339230, which in turn is based on the algorithm by
+ Chawathe et al. described in http://ilpubs.stanford.edu:8090/115/1/1995-46.pdf.
+ """
+
+ def __init__(self, f=0.6, t=0.6):
+ self.f = f
+ self.t = t
+ self._sql_generator = Dialect().generator()
+
+ def diff(self, source, target):
+ self._source = source
+ self._target = target
+ self._source_index = {id(n[0]): n[0] for n in source.bfs()}
+ self._target_index = {id(n[0]): n[0] for n in target.bfs()}
+ self._unmatched_source_nodes = set(self._source_index)
+ self._unmatched_target_nodes = set(self._target_index)
+ self._bigram_histo_cache = {}
+
+ matching_set = self._compute_matching_set()
+ return self._generate_edit_script(matching_set)
+
+ def _generate_edit_script(self, matching_set):
+ edit_script = []
+ for removed_node_id in self._unmatched_source_nodes:
+ edit_script.append(Remove(self._source_index[removed_node_id]))
+ for inserted_node_id in self._unmatched_target_nodes:
+ edit_script.append(Insert(self._target_index[inserted_node_id]))
+ for kept_source_node_id, kept_target_node_id in matching_set:
+ source_node = self._source_index[kept_source_node_id]
+ target_node = self._target_index[kept_target_node_id]
+ if (
+ not isinstance(source_node, LEAF_EXPRESSION_TYPES)
+ or source_node == target_node
+ ):
+ edit_script.extend(
+ self._generate_move_edits(source_node, target_node, matching_set)
+ )
+ edit_script.append(Keep(source_node, target_node))
+ else:
+ edit_script.append(Update(source_node, target_node))
+
+ return edit_script
+
+ def _generate_move_edits(self, source, target, matching_set):
+ source_args = [id(e) for e in _expression_only_args(source)]
+ target_args = [id(e) for e in _expression_only_args(target)]
+
+ args_lcs = set(
+ _lcs(source_args, target_args, lambda l, r: (l, r) in matching_set)
+ )
+
+ move_edits = []
+ for a in source_args:
+ if a not in args_lcs and a not in self._unmatched_source_nodes:
+ move_edits.append(Move(self._source_index[a]))
+
+ return move_edits
+
+ def _compute_matching_set(self):
+ leaves_matching_set = self._compute_leaf_matching_set()
+ matching_set = leaves_matching_set.copy()
+
+ ordered_unmatched_source_nodes = {
+ id(n[0]): None
+ for n in self._source.bfs()
+ if id(n[0]) in self._unmatched_source_nodes
+ }
+ ordered_unmatched_target_nodes = {
+ id(n[0]): None
+ for n in self._target.bfs()
+ if id(n[0]) in self._unmatched_target_nodes
+ }
+
+ for source_node_id in ordered_unmatched_source_nodes:
+ for target_node_id in ordered_unmatched_target_nodes:
+ source_node = self._source_index[source_node_id]
+ target_node = self._target_index[target_node_id]
+ if _is_same_type(source_node, target_node):
+ source_leaf_ids = {id(l) for l in _get_leaves(source_node)}
+ target_leaf_ids = {id(l) for l in _get_leaves(target_node)}
+
+ max_leaves_num = max(len(source_leaf_ids), len(target_leaf_ids))
+ if max_leaves_num:
+ common_leaves_num = sum(
+ 1 if s in source_leaf_ids and t in target_leaf_ids else 0
+ for s, t in leaves_matching_set
+ )
+ leaf_similarity_score = common_leaves_num / max_leaves_num
+ else:
+ leaf_similarity_score = 0.0
+
+ adjusted_t = (
+ self.t
+ if min(len(source_leaf_ids), len(target_leaf_ids)) > 4
+ else 0.4
+ )
+
+ if leaf_similarity_score >= 0.8 or (
+ leaf_similarity_score >= adjusted_t
+ and self._dice_coefficient(source_node, target_node) >= self.f
+ ):
+ matching_set.add((source_node_id, target_node_id))
+ self._unmatched_source_nodes.remove(source_node_id)
+ self._unmatched_target_nodes.remove(target_node_id)
+ ordered_unmatched_target_nodes.pop(target_node_id, None)
+ break
+
+ return matching_set
+
+ def _compute_leaf_matching_set(self):
+ candidate_matchings = []
+ source_leaves = list(_get_leaves(self._source))
+ target_leaves = list(_get_leaves(self._target))
+ for source_leaf in source_leaves:
+ for target_leaf in target_leaves:
+ if _is_same_type(source_leaf, target_leaf):
+ similarity_score = self._dice_coefficient(source_leaf, target_leaf)
+ if similarity_score >= self.f:
+ heappush(
+ candidate_matchings,
+ (
+ -similarity_score,
+ len(candidate_matchings),
+ source_leaf,
+ target_leaf,
+ ),
+ )
+
+ # Pick best matchings based on the highest score
+ matching_set = set()
+ while candidate_matchings:
+ _, _, source_leaf, target_leaf = heappop(candidate_matchings)
+ if (
+ id(source_leaf) in self._unmatched_source_nodes
+ and id(target_leaf) in self._unmatched_target_nodes
+ ):
+ matching_set.add((id(source_leaf), id(target_leaf)))
+ self._unmatched_source_nodes.remove(id(source_leaf))
+ self._unmatched_target_nodes.remove(id(target_leaf))
+
+ return matching_set
+
+ def _dice_coefficient(self, source, target):
+ source_histo = self._bigram_histo(source)
+ target_histo = self._bigram_histo(target)
+
+ total_grams = sum(source_histo.values()) + sum(target_histo.values())
+ if not total_grams:
+ return 1.0 if source == target else 0.0
+
+ overlap_len = 0
+ overlapping_grams = set(source_histo) & set(target_histo)
+ for g in overlapping_grams:
+ overlap_len += min(source_histo[g], target_histo[g])
+
+ return 2 * overlap_len / total_grams
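+
+    # Worked example (illustrative): if two nodes render to the SQL strings
+    # "night" and "nacht", the bigram sets are {ni, ig, gh, ht} and
+    # {na, ac, ch, ht}; only "ht" overlaps, so the score is 2 * 1 / (4 + 4) = 0.25.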
+
+ def _bigram_histo(self, expression):
+ if id(expression) in self._bigram_histo_cache:
+ return self._bigram_histo_cache[id(expression)]
+
+ expression_str = self._sql_generator.generate(expression)
+ count = max(0, len(expression_str) - 1)
+ bigram_histo = defaultdict(int)
+ for i in range(count):
+ bigram_histo[expression_str[i : i + 2]] += 1
+
+ self._bigram_histo_cache[id(expression)] = bigram_histo
+ return bigram_histo
+
+
+def _get_leaves(expression):
+ has_child_exprs = False
+
+ for a in expression.args.values():
+ nodes = ensure_list(a)
+ for node in nodes:
+ if isinstance(node, exp.Expression):
+ has_child_exprs = True
+ yield from _get_leaves(node)
+
+ if not has_child_exprs:
+ yield expression
+
+
+def _is_same_type(source, target):
+ if type(source) is type(target):
+ if isinstance(source, exp.Join):
+ return source.args.get("side") == target.args.get("side")
+
+ if isinstance(source, exp.Anonymous):
+ return source.this == target.this
+
+ return True
+
+ return False
+
+
+def _expression_only_args(expression):
+ args = []
+ if expression:
+ for a in expression.args.values():
+ args.extend(ensure_list(a))
+ return [a for a in args if isinstance(a, exp.Expression)]
+
+
+def _lcs(seq_a, seq_b, equal):
+ """Calculates the longest common subsequence"""
+
+ len_a = len(seq_a)
+ len_b = len(seq_b)
+ lcs_result = [[None] * (len_b + 1) for i in range(len_a + 1)]
+
+ for i in range(len_a + 1):
+ for j in range(len_b + 1):
+ if i == 0 or j == 0:
+ lcs_result[i][j] = []
+ elif equal(seq_a[i - 1], seq_b[j - 1]):
+ lcs_result[i][j] = lcs_result[i - 1][j - 1] + [seq_a[i - 1]]
+ else:
+ lcs_result[i][j] = (
+ lcs_result[i - 1][j]
+ if len(lcs_result[i - 1][j]) > len(lcs_result[i][j - 1])
+ else lcs_result[i][j - 1]
+ )
+
+ return lcs_result[len_a][len_b]
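+
+
+# Illustrative (assumed): _lcs([1, 2, 3, 4], [2, 4, 5], lambda a, b: a == b)
+# returns [2, 4].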
diff --git a/sqlglot/errors.py b/sqlglot/errors.py
new file mode 100644
index 0000000..89aa935
--- /dev/null
+++ b/sqlglot/errors.py
@@ -0,0 +1,38 @@
+from enum import auto
+
+from sqlglot.helper import AutoName
+
+
+class ErrorLevel(AutoName):
+ IGNORE = auto() # Ignore any parser errors
+ WARN = auto() # Log any parser errors with ERROR level
+ RAISE = auto() # Collect all parser errors and raise a single exception
+ IMMEDIATE = auto() # Immediately raise an exception on the first parser error
+
+
+class SqlglotError(Exception):
+ pass
+
+
+class UnsupportedError(SqlglotError):
+ pass
+
+
+class ParseError(SqlglotError):
+ pass
+
+
+class TokenError(SqlglotError):
+ pass
+
+
+class OptimizeError(SqlglotError):
+ pass
+
+
+def concat_errors(errors, maximum):
+ msg = [str(e) for e in errors[:maximum]]
+ remaining = len(errors) - maximum
+ if remaining > 0:
+ msg.append(f"... and {remaining} more")
+ return "\n\n".join(msg)
diff --git a/sqlglot/executor/__init__.py b/sqlglot/executor/__init__.py
new file mode 100644
index 0000000..a437431
--- /dev/null
+++ b/sqlglot/executor/__init__.py
@@ -0,0 +1,39 @@
+import logging
+import time
+
+from sqlglot import parse_one
+from sqlglot.executor.python import PythonExecutor
+from sqlglot.optimizer import optimize
+from sqlglot.planner import Plan
+
+logger = logging.getLogger("sqlglot")
+
+
+def execute(sql, schema, read=None):
+ """
+    Run a SQL query against data.
+
+ Args:
+        sql (str): a SQL statement
+ schema (dict|sqlglot.optimizer.Schema): database schema.
+ This can either be an instance of `sqlglot.optimizer.Schema` or a mapping in one of
+ the following forms:
+ 1. {table: {col: type}}
+ 2. {db: {table: {col: type}}}
+ 3. {catalog: {db: {table: {col: type}}}}
+ read (str): the SQL dialect to apply during parsing
+ (eg. "spark", "hive", "presto", "mysql").
+ Returns:
+ sqlglot.executor.Table: Simple columnar data structure.
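+
+    Example (illustrative; in this version leaf table scans read CSV sources,
+    so the table name is assumed to resolve to a readable CSV file):
+        >>> schema = {"sushi": {"id": "INT", "price": "DOUBLE"}}
+        >>> execute("SELECT price FROM sushi", schema)  # doctest: +SKIP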
+ """
+ expression = parse_one(sql, read=read)
+ now = time.time()
+ expression = optimize(expression, schema)
+ logger.debug("Optimization finished: %f", time.time() - now)
+ logger.debug("Optimized SQL: %s", expression.sql(pretty=True))
+ plan = Plan(expression)
+ logger.debug("Logical Plan: %s", plan)
+ now = time.time()
+ result = PythonExecutor().execute(plan)
+ logger.debug("Query finished: %f", time.time() - now)
+ return result
diff --git a/sqlglot/executor/context.py b/sqlglot/executor/context.py
new file mode 100644
index 0000000..457bea7
--- /dev/null
+++ b/sqlglot/executor/context.py
@@ -0,0 +1,68 @@
+from sqlglot.executor.env import ENV
+
+
+class Context:
+ """
+    Execution context for SQL expressions.
+
+    Context holds the data tables currently in scope, which can then be queried with eval.
+
+    Column references evaluate to either scalars or vectors: after set_row they evaluate
+    to scalars, while after set_range they evaluate to vectors. This allows convenient
+    and efficient evaluation of aggregation functions.
+ """
+
+ def __init__(self, tables, env=None):
+ """
+        Args:
+ tables (dict): table_name -> Table, representing the scope of the current execution context
+ env (Optional[dict]): dictionary of functions within the execution context
+ """
+ self.tables = tables
+ self.range_readers = {
+ name: table.range_reader for name, table in self.tables.items()
+ }
+ self.row_readers = {name: table.reader for name, table in tables.items()}
+ self.env = {**(env or {}), "scope": self.row_readers}
+
+ def eval(self, code):
+ return eval(code, ENV, self.env)
+
+ def eval_tuple(self, codes):
+ return tuple(self.eval(code) for code in codes)
+
+ def __iter__(self):
+ return self.table_iter(list(self.tables)[0])
+
+ def table_iter(self, table):
+ self.env["scope"] = self.row_readers
+
+ for reader in self.tables[table]:
+ yield reader, self
+
+ def sort(self, table, key):
+ table = self.tables[table]
+
+ def sort_key(row):
+ table.reader.row = row
+ return self.eval_tuple(key)
+
+ table.rows.sort(key=sort_key)
+
+ def set_row(self, table, row):
+ self.row_readers[table].row = row
+ self.env["scope"] = self.row_readers
+
+ def set_index(self, table, index):
+ self.row_readers[table].row = self.tables[table].rows[index]
+ self.env["scope"] = self.row_readers
+
+ def set_range(self, table, start, end):
+ self.range_readers[table].range = range(start, end)
+ self.env["scope"] = self.range_readers
+
+ def __getitem__(self, table):
+ return self.env["scope"][table]
+
+ def __contains__(self, table):
+ return table in self.tables
diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py
new file mode 100644
index 0000000..72b0558
--- /dev/null
+++ b/sqlglot/executor/env.py
@@ -0,0 +1,32 @@
+import datetime
+import re
+import statistics
+
+
+class reverse_key:
+ def __init__(self, obj):
+ self.obj = obj
+
+ def __eq__(self, other):
+ return other.obj == self.obj
+
+ def __lt__(self, other):
+ return other.obj < self.obj
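+
+# Illustrative: sorted([2, 1, 3], key=reverse_key) yields [3, 2, 1]; wrapping
+# sort keys in reverse_key gives descending order for ORDER BY ... DESC.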
+
+
+ENV = {
+ "__builtins__": {},
+ "datetime": datetime,
+ "locals": locals,
+ "re": re,
+ "float": float,
+ "int": int,
+ "str": str,
+ "desc": reverse_key,
+ "SUM": sum,
+ "AVG": statistics.fmean if hasattr(statistics, "fmean") else statistics.mean,
+ "COUNT": lambda acc: sum(1 for e in acc if e is not None),
+ "MAX": max,
+ "MIN": min,
+ "POW": pow,
+}
diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py
new file mode 100644
index 0000000..388a419
--- /dev/null
+++ b/sqlglot/executor/python.py
@@ -0,0 +1,360 @@
+import ast
+import collections
+import itertools
+
+from sqlglot import exp, planner
+from sqlglot.dialects.dialect import Dialect, inline_array_sql
+from sqlglot.executor.context import Context
+from sqlglot.executor.env import ENV
+from sqlglot.executor.table import Table
+from sqlglot.generator import Generator
+from sqlglot.helper import csv_reader
+from sqlglot.tokens import Tokenizer
+
+
+class PythonExecutor:
+ def __init__(self, env=None):
+ self.generator = Python().generator(identify=True)
+ self.env = {**ENV, **(env or {})}
+
+ def execute(self, plan):
+ running = set()
+ finished = set()
+ queue = set(plan.leaves)
+ contexts = {}
+
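+        # Topological execution (as implemented below): run each step once all
+        # of its dependencies have finished, and drop intermediate contexts when
+        # no remaining dependent needs them.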
+ while queue:
+ node = queue.pop()
+ context = self.context(
+ {
+ name: table
+ for dep in node.dependencies
+ for name, table in contexts[dep].tables.items()
+ }
+ )
+ running.add(node)
+
+ if isinstance(node, planner.Scan):
+ contexts[node] = self.scan(node, context)
+ elif isinstance(node, planner.Aggregate):
+ contexts[node] = self.aggregate(node, context)
+ elif isinstance(node, planner.Join):
+ contexts[node] = self.join(node, context)
+ elif isinstance(node, planner.Sort):
+ contexts[node] = self.sort(node, context)
+ else:
+ raise NotImplementedError
+
+ running.remove(node)
+ finished.add(node)
+
+ for dep in node.dependents:
+ if dep not in running and all(d in contexts for d in dep.dependencies):
+ queue.add(dep)
+
+ for dep in node.dependencies:
+ if all(d in finished for d in dep.dependents):
+ contexts.pop(dep)
+
+ root = plan.root
+ return contexts[root].tables[root.name]
+
+ def generate(self, expression):
+ """Convert a SQL expression into literal Python code and compile it into bytecode."""
+ if not expression:
+ return None
+
+ sql = self.generator.generate(expression)
+ return compile(sql, sql, "eval", optimize=2)
+
+ def generate_tuple(self, expressions):
+ """Convert an array of SQL expressions into tuple of Python byte code."""
+ if not expressions:
+ return tuple()
+ return tuple(self.generate(expression) for expression in expressions)
+
+ def context(self, tables):
+ return Context(tables, env=self.env)
+
+ def table(self, expressions):
+ return Table(expression.alias_or_name for expression in expressions)
+
+ def scan(self, step, context):
+ if hasattr(step, "source"):
+ source = step.source
+
+ if isinstance(source, exp.Expression):
+ source = source.this.name or source.alias
+ else:
+ source = step.name
+ condition = self.generate(step.condition)
+ projections = self.generate_tuple(step.projections)
+
+ if source in context:
+ if not projections and not condition:
+ return self.context({step.name: context.tables[source]})
+ table_iter = context.table_iter(source)
+ else:
+ table_iter = self.scan_csv(step)
+
+ if projections:
+ sink = self.table(step.projections)
+ elif source in context:
+ sink = Table(context[source].columns)
+ else:
+ sink = None
+
+ for reader, ctx in table_iter:
+ if sink is None:
+ sink = Table(ctx[source].columns)
+
+ if condition and not ctx.eval(condition):
+ continue
+
+ if projections:
+ sink.append(ctx.eval_tuple(projections))
+ else:
+ sink.append(reader.row)
+
+ if len(sink) >= step.limit:
+ break
+
+ return self.context({step.name: sink})
+
+ def scan_csv(self, step):
+ source = step.source
+ alias = source.alias
+
+ with csv_reader(source.this) as reader:
+ columns = next(reader)
+ table = Table(columns)
+ context = self.context({alias: table})
+ types = []
+
+ for row in reader:
+ if not types:
+ for v in row:
+ try:
+ types.append(type(ast.literal_eval(v)))
+ except (ValueError, SyntaxError):
+ types.append(str)
+ context.set_row(alias, tuple(t(v) for t, v in zip(types, row)))
+ yield context[alias], context
+
+ def join(self, step, context):
+ source = step.name
+
+ join_context = self.context({source: context.tables[source]})
+
+ def merge_context(ctx, table):
+ # create a new context where all existing tables are mapped to a new one
+ return self.context({name: table for name in ctx.tables})
+
+ for name, join in step.joins.items():
+ join_context = self.context(
+ {**join_context.tables, name: context.tables[name]}
+ )
+
+ if join.get("source_key"):
+ table = self.hash_join(join, source, name, join_context)
+ else:
+ table = self.nested_loop_join(join, source, name, join_context)
+
+ join_context = merge_context(join_context, table)
+
+ # apply projections or conditions
+ context = self.scan(step, join_context)
+
+ # use the scan context since it returns a single table
+ # otherwise there are no projections so all other tables are still in scope
+ if step.projections:
+ return context
+
+ return merge_context(join_context, context.tables[source])
+
+ def nested_loop_join(self, _join, a, b, context):
+ table = Table(context.tables[a].columns + context.tables[b].columns)
+
+ for reader_a, _ in context.table_iter(a):
+ for reader_b, _ in context.table_iter(b):
+ table.append(reader_a.row + reader_b.row)
+
+ return table
+
+ def hash_join(self, join, a, b, context):
+ a_key = self.generate_tuple(join["source_key"])
+ b_key = self.generate_tuple(join["join_key"])
+
+ results = collections.defaultdict(lambda: ([], []))
+
+ for reader, ctx in context.table_iter(a):
+ results[ctx.eval_tuple(a_key)][0].append(reader.row)
+ for reader, ctx in context.table_iter(b):
+ results[ctx.eval_tuple(b_key)][1].append(reader.row)
+
+ table = Table(context.tables[a].columns + context.tables[b].columns)
+ for a_group, b_group in results.values():
+ for a_row, b_row in itertools.product(a_group, b_group):
+ table.append(a_row + b_row)
+
+ return table
+
+ def sort_merge_join(self, join, a, b, context):
+ a_key = self.generate_tuple(join["source_key"])
+ b_key = self.generate_tuple(join["join_key"])
+
+ context.sort(a, a_key)
+ context.sort(b, b_key)
+
+ a_i = 0
+ b_i = 0
+ a_n = len(context.tables[a])
+ b_n = len(context.tables[b])
+
+ table = Table(context.tables[a].columns + context.tables[b].columns)
+
+ def get_key(source, key, i):
+ context.set_index(source, i)
+ return context.eval_tuple(key)
+
+ while a_i < a_n and b_i < b_n:
+ key = min(get_key(a, a_key, a_i), get_key(b, b_key, b_i))
+
+ a_group = []
+
+ while a_i < a_n and key == get_key(a, a_key, a_i):
+ a_group.append(context[a].row)
+ a_i += 1
+
+ b_group = []
+
+ while b_i < b_n and key == get_key(b, b_key, b_i):
+ b_group.append(context[b].row)
+ b_i += 1
+
+ for a_row, b_row in itertools.product(a_group, b_group):
+ table.append(a_row + b_row)
+
+ return table
+
+ def aggregate(self, step, context):
+ source = step.source
+ group_by = self.generate_tuple(step.group)
+ aggregations = self.generate_tuple(step.aggregations)
+ operands = self.generate_tuple(step.operands)
+
+ context.sort(source, group_by)
+
+ if step.operands:
+ source_table = context.tables[source]
+ operand_table = Table(
+ source_table.columns + self.table(step.operands).columns
+ )
+
+ for reader, ctx in context:
+ operand_table.append(reader.row + ctx.eval_tuple(operands))
+
+ context = self.context({source: operand_table})
+
+ group = None
+ start = 0
+ end = 1
+ length = len(context.tables[source])
+ table = self.table(step.group + step.aggregations)
+
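+        # Walk the sorted rows, emitting one output row per group: when the key
+        # changes (or the final row is reached), evaluate the aggregations over
+        # the accumulated row range.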
+ for i in range(length):
+ context.set_index(source, i)
+ key = context.eval_tuple(group_by)
+ group = key if group is None else group
+ end += 1
+
+ if i == length - 1:
+ context.set_range(source, start, end - 1)
+ elif key != group:
+ context.set_range(source, start, end - 2)
+ else:
+ continue
+
+ table.append(group + context.eval_tuple(aggregations))
+ group = key
+ start = end - 2
+
+ return self.scan(step, self.context({source: table}))
+
+ def sort(self, step, context):
+ table = list(context.tables)[0]
+ key = self.generate_tuple(step.key)
+ context.sort(table, key)
+ return self.scan(step, context)
+
+
+def _cast_py(self, expression):
+ to = expression.args["to"].this
+ this = self.sql(expression, "this")
+
+ if to == exp.DataType.Type.DATE:
+ return f"datetime.date.fromisoformat({this})"
+ if to == exp.DataType.Type.TEXT:
+ return f"str({this})"
+ raise NotImplementedError
+
+
+def _column_py(self, expression):
+ table = self.sql(expression, "table")
+ this = self.sql(expression, "this")
+ return f"scope[{table}][{this}]"
+
+
+def _interval_py(self, expression):
+ this = self.sql(expression, "this")
+ unit = expression.text("unit").upper()
+ if unit == "DAY":
+ return f"datetime.timedelta(days=float({this}))"
+ raise NotImplementedError
+
+
+def _like_py(self, expression):
+ this = self.sql(expression, "this")
+ expression = self.sql(expression, "expression")
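+    # Simplified LIKE semantics (noted): "_"/"%" map to regex "."/".*" without
+    # escaping other regex metacharacters, and re.match anchors only at the start.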
+ return f"""re.match({expression}.replace("_", ".").replace("%", ".*"), {this})"""
+
+
+def _ordered_py(self, expression):
+ this = self.sql(expression, "this")
+ desc = expression.args.get("desc")
+ return f"desc({this})" if desc else this
+
+
+class Python(Dialect):
+ class Tokenizer(Tokenizer):
+ ESCAPE = "\\"
+
+ class Generator(Generator):
+ TRANSFORMS = {
+ exp.Alias: lambda self, e: self.sql(e.this),
+ exp.Array: inline_array_sql,
+ exp.And: lambda self, e: self.binary(e, "and"),
+ exp.Cast: _cast_py,
+ exp.Column: _column_py,
+ exp.EQ: lambda self, e: self.binary(e, "=="),
+ exp.Interval: _interval_py,
+ exp.Is: lambda self, e: self.binary(e, "is"),
+ exp.Like: _like_py,
+ exp.Not: lambda self, e: f"not {self.sql(e.this)}",
+ exp.Null: lambda *_: "None",
+ exp.Or: lambda self, e: self.binary(e, "or"),
+ exp.Ordered: _ordered_py,
+ exp.Star: lambda *_: "1",
+ }
+
+ def case_sql(self, expression):
+ this = self.sql(expression, "this")
+ chain = self.sql(expression, "default") or "None"
+
+ for e in reversed(expression.args["ifs"]):
+ true = self.sql(e, "true")
+ condition = self.sql(e, "this")
+ condition = f"{this} = ({condition})" if this else condition
+ chain = f"{true} if {condition} else ({chain})"
+
+ return chain
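+
+
+# Illustrative (assumed): with case_sql above, CASE WHEN a THEN b ELSE c END
+# compiles to the Python conditional expression "b if a else (c)".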
diff --git a/sqlglot/executor/table.py b/sqlglot/executor/table.py
new file mode 100644
index 0000000..6df49f7
--- /dev/null
+++ b/sqlglot/executor/table.py
@@ -0,0 +1,81 @@
+class Table:
+ def __init__(self, *columns, rows=None):
+ self.columns = tuple(columns if isinstance(columns[0], str) else columns[0])
+ self.rows = rows or []
+ if rows:
+ assert len(rows[0]) == len(self.columns)
+ self.reader = RowReader(self.columns)
+ self.range_reader = RangeReader(self)
+
+ def append(self, row):
+ assert len(row) == len(self.columns)
+ self.rows.append(row)
+
+ def pop(self):
+ self.rows.pop()
+
+ @property
+ def width(self):
+ return len(self.columns)
+
+ def __len__(self):
+ return len(self.rows)
+
+ def __iter__(self):
+ return TableIter(self)
+
+ def __getitem__(self, index):
+ self.reader.row = self.rows[index]
+ return self.reader
+
+ def __repr__(self):
+ widths = {column: len(column) for column in self.columns}
+ lines = [" ".join(column for column in self.columns)]
+
+ for i, row in enumerate(self):
+ if i > 10:
+ break
+
+ lines.append(
+ " ".join(
+ str(row[column]).rjust(widths[column])[0 : widths[column]]
+ for column in self.columns
+ )
+ )
+ return "\n".join(lines)
+
+
+class TableIter:
+ def __init__(self, table):
+ self.table = table
+ self.index = -1
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ self.index += 1
+ if self.index < len(self.table):
+ return self.table[self.index]
+ raise StopIteration
+
+
+class RangeReader:
+ def __init__(self, table):
+ self.table = table
+ self.range = range(0)
+
+ def __len__(self):
+ return len(self.range)
+
+ def __getitem__(self, column):
+ return (self.table[i][column] for i in self.range)
+
+
+class RowReader:
+ def __init__(self, columns):
+ self.columns = {column: i for i, column in enumerate(columns)}
+ self.row = None
+
+ def __getitem__(self, column):
+ return self.row[self.columns[column]]
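+
+
+# Illustrative usage (assumed):
+#   t = Table(["a", "b"], rows=[(1, 2)])
+#   t[0]["b"]  # -> 2, read through a RowReader positioned on row 0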
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
new file mode 100644
index 0000000..7acc63d
--- /dev/null
+++ b/sqlglot/expressions.py
@@ -0,0 +1,2945 @@
+import inspect
+import re
+import sys
+from collections import deque
+from copy import deepcopy
+from enum import auto
+
+from sqlglot.errors import ParseError
+from sqlglot.helper import AutoName, camel_to_snake_case, ensure_list
+
+
+class _Expression(type):
+ def __new__(cls, clsname, bases, attrs):
+ klass = super().__new__(cls, clsname, bases, attrs)
+ klass.key = clsname.lower()
+ return klass
+
+
+class Expression(metaclass=_Expression):
+ """
+ The base class for all expressions in a syntax tree.
+
+ Attributes:
+        arg_types (dict): determines the arguments supported by this expression.
+            Each dictionary key names an argument, under which that argument's
+            value can be retrieved. The value is a boolean flag indicating
+            whether the argument's value is required (True) or optional (False).
+ """
+
+ key = None
+ arg_types = {"this": True}
+ __slots__ = ("args", "parent", "arg_key")
+
+ def __init__(self, **args):
+ self.args = args
+ self.parent = None
+ self.arg_key = None
+
+ for arg_key, value in self.args.items():
+ self._set_parent(arg_key, value)
+
+ def __eq__(self, other):
+ return type(self) is type(other) and _norm_args(self) == _norm_args(other)
+
+ def __hash__(self):
+ return hash(
+ (
+ self.key,
+ tuple(
+ (k, tuple(v) if isinstance(v, list) else v)
+ for k, v in _norm_args(self).items()
+ ),
+ )
+ )
+
+ @property
+ def this(self):
+ return self.args.get("this")
+
+ @property
+ def expression(self):
+ return self.args.get("expression")
+
+ @property
+ def expressions(self):
+ return self.args.get("expressions") or []
+
+ def text(self, key):
+ field = self.args.get(key)
+ if isinstance(field, str):
+ return field
+ if isinstance(field, (Identifier, Literal, Var)):
+ return field.this
+ return ""
+
+ @property
+ def is_string(self):
+ return isinstance(self, Literal) and self.args["is_string"]
+
+ @property
+ def is_number(self):
+ return isinstance(self, Literal) and not self.args["is_string"]
+
+ @property
+ def is_int(self):
+ if self.is_number:
+ try:
+ int(self.name)
+ return True
+ except ValueError:
+ pass
+ return False
+
+ @property
+ def alias(self):
+ if isinstance(self.args.get("alias"), TableAlias):
+ return self.args["alias"].name
+ return self.text("alias")
+
+ @property
+ def name(self):
+ return self.text("this")
+
+ @property
+ def alias_or_name(self):
+ return self.alias or self.name
+
+ def __deepcopy__(self, memo):
+ return self.__class__(**deepcopy(self.args))
+
+ def copy(self):
+ new = deepcopy(self)
+ for item, parent, _ in new.bfs():
+ if isinstance(item, Expression) and parent:
+ item.parent = parent
+ return new
+
+ def set(self, arg_key, value):
+ """
+        Sets `arg_key` to `value`.
+
+ Args:
+ arg_key (str): name of the expression arg
+ value: value to set the arg to.
+ """
+ self.args[arg_key] = value
+ self._set_parent(arg_key, value)
+
+ def _set_parent(self, arg_key, value):
+ if isinstance(value, Expression):
+ value.parent = self
+ value.arg_key = arg_key
+ elif isinstance(value, list):
+ for v in value:
+ if isinstance(v, Expression):
+ v.parent = self
+ v.arg_key = arg_key
+
+ @property
+ def depth(self):
+ """
+        Returns the depth of this node in the tree.
+ """
+ if self.parent:
+ return self.parent.depth + 1
+ return 0
+
+ def find(self, *expression_types, bfs=True):
+ """
+ Returns the first node in this tree which matches at least one of
+ the specified types.
+
+ Args:
+ expression_types (type): the expression type(s) to match.
+
+ Returns:
+ the node which matches the criteria or None if no node matching
+ the criteria was found.
+ """
+ return next(self.find_all(*expression_types, bfs=bfs), None)
+
+ def find_all(self, *expression_types, bfs=True):
+ """
+ Returns a generator object which visits all nodes in this tree and only
+ yields those that match at least one of the specified expression types.
+
+ Args:
+ expression_types (type): the expression type(s) to match.
+
+ Returns:
+ the generator object.
+ """
+ for expression, _, _ in self.walk(bfs=bfs):
+ if isinstance(expression, expression_types):
+ yield expression
+
+ def find_ancestor(self, *expression_types):
+ """
+        Returns the nearest parent matching expression_types.
+
+ Args:
+ expression_types (type): the expression type(s) to match.
+
+ Returns:
+ the parent node
+ """
+ ancestor = self.parent
+ while ancestor and not isinstance(ancestor, expression_types):
+ ancestor = ancestor.parent
+ return ancestor
+
+ @property
+ def parent_select(self):
+ """
+ Returns the parent select statement.
+ """
+ return self.find_ancestor(Select)
+
+ def walk(self, bfs=True):
+ """
+ Returns a generator object which visits all nodes in this tree.
+
+ Args:
+ bfs (bool): if set to True the BFS traversal order will be applied,
+ otherwise the DFS traversal will be used instead.
+
+ Returns:
+ the generator object.
+ """
+ if bfs:
+ yield from self.bfs()
+ else:
+ yield from self.dfs()
+
+ def dfs(self, parent=None, key=None, prune=None):
+ """
+ Returns a generator object which visits all nodes in this tree in
+ the DFS (Depth-first) order.
+
+ Returns:
+ the generator object.
+ """
+ parent = parent or self.parent
+ yield self, parent, key
+ if prune and prune(self, parent, key):
+ return
+
+ for k, v in self.args.items():
+ nodes = ensure_list(v)
+
+ for node in nodes:
+ if isinstance(node, Expression):
+ yield from node.dfs(self, k, prune)
+
+ def bfs(self, prune=None):
+ """
+ Returns a generator object which visits all nodes in this tree in
+ the BFS (Breadth-first) order.
+
+ Returns:
+ the generator object.
+ """
+ queue = deque([(self, self.parent, None)])
+
+ while queue:
+ item, parent, key = queue.popleft()
+
+ yield item, parent, key
+ if prune and prune(item, parent, key):
+ continue
+
+ if isinstance(item, Expression):
+ for k, v in item.args.items():
+ nodes = ensure_list(v)
+
+ for node in nodes:
+ if isinstance(node, Expression):
+ queue.append((node, item, k))
+
+ def unnest(self):
+ """
+        Returns the first non-parenthesis child or self.
+ """
+ expression = self
+ while isinstance(expression, Paren):
+ expression = expression.this
+ return expression
+
+ def unnest_operands(self):
+ """
+ Returns unnested operands as a tuple.
+ """
+ return tuple(arg.unnest() for arg in self.args.values() if arg)
+
+ def flatten(self, unnest=True):
+ """
+        Returns a generator which yields child nodes whose parents are the same class.
+
+ A AND B AND C -> [A, B, C]
+ """
+ for node, _, _ in self.dfs(
+ prune=lambda n, p, *_: p and not isinstance(n, self.__class__)
+ ):
+ if not isinstance(node, self.__class__):
+ yield node.unnest() if unnest else node
+
+ def __str__(self):
+ return self.sql()
+
+ def __repr__(self):
+ return self.to_s()
+
+ def sql(self, dialect=None, **opts):
+ """
+ Returns SQL string representation of this tree.
+
+        Args:
+ dialect (str): the dialect of the output SQL string
+ (eg. "spark", "hive", "presto", "mysql").
+ opts (dict): other :class:`~sqlglot.generator.Generator` options.
+
+        Returns:
+ the SQL string.
+ """
+ from sqlglot.dialects import Dialect
+
+ return Dialect.get_or_raise(dialect)().generate(self, **opts)
+
+ def to_s(self, hide_missing=True, level=0):
+ indent = "" if not level else "\n"
+ indent += "".join([" "] * level)
+ left = f"({self.key.upper()} "
+
+ args = {
+ k: ", ".join(
+ v.to_s(hide_missing=hide_missing, level=level + 1)
+ if hasattr(v, "to_s")
+ else str(v)
+ for v in ensure_list(vs)
+ if v is not None
+ )
+ for k, vs in self.args.items()
+ }
+ args = {k: v for k, v in args.items() if v or not hide_missing}
+
+ right = ", ".join(f"{k}: {v}" for k, v in args.items())
+ right += ")"
+
+ return indent + left + right
+
+ def transform(self, fun, *args, copy=True, **kwargs):
+ """
+ Recursively visits all tree nodes (excluding already transformed ones)
+ and applies the given transformation function to each node.
+
+ Args:
+ fun (function): a function which takes a node as an argument and returns a
+ new transformed node or the same node without modifications.
+ copy (bool): if set to True a new tree instance is constructed, otherwise the tree is
+ modified in place.
+
+ Returns:
+ the transformed tree.
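+
+        Example (illustrative):
+            >>> import sqlglot
+            >>> tree = sqlglot.parse_one("SELECT a FROM x")
+            >>> tree.transform(
+            ...     lambda node: sqlglot.exp.Literal.number(1)
+            ...     if isinstance(node, sqlglot.exp.Column)
+            ...     else node
+            ... ).sql()
+            'SELECT 1 FROM x'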
+ """
+ node = self.copy() if copy else self
+ new_node = fun(node, *args, **kwargs)
+
+ if new_node is None:
+ raise ValueError("A transformed node cannot be None")
+ if not isinstance(new_node, Expression):
+ return new_node
+ if new_node is not node:
+ new_node.parent = node.parent
+ return new_node
+
+ replace_children(
+ new_node, lambda child: child.transform(fun, *args, copy=False, **kwargs)
+ )
+ return new_node
+
+ def replace(self, expression):
+ """
+ Swap out this expression with a new expression.
+
+ For example::
+
+ >>> tree = Select().select("x").from_("tbl")
+ >>> tree.find(Column).replace(Column(this="y"))
+ (COLUMN this: y)
+ >>> tree.sql()
+ 'SELECT y FROM tbl'
+
+ Args:
+ expression (Expression): new node
+
+        Returns:
+ the new expression or expressions
+ """
+ if not self.parent:
+ return expression
+
+ parent = self.parent
+ self.parent = None
+
+ replace_children(parent, lambda child: expression if child is self else child)
+ return expression
+
+ def assert_is(self, type_):
+ """
+ Assert that this `Expression` is an instance of `type_`.
+
+ If it is NOT an instance of `type_`, this raises an assertion error.
+ Otherwise, this returns this expression.
+
+ Examples:
+            This is useful for type safety in chained expressions:
+
+ >>> import sqlglot
+ >>> sqlglot.parse_one("SELECT x from y").assert_is(Select).select("z").sql()
+ 'SELECT x, z FROM y'
+ """
+ assert isinstance(self, type_)
+ return self
+
+
+class Condition(Expression):
+ def and_(self, *expressions, dialect=None, **opts):
+ """
+ AND this condition with one or multiple expressions.
+
+ Example:
+ >>> condition("x=1").and_("y=1").sql()
+ 'x = 1 AND y = 1'
+
+ Args:
+ *expressions (str or Expression): the SQL code strings to parse.
+ If an `Expression` instance is passed, it will be used as-is.
+ dialect (str): the dialect used to parse the input expression.
+ opts (kwargs): other options to use to parse the input expressions.
+
+ Returns:
+ And: the new condition.
+ """
+ return and_(self, *expressions, dialect=dialect, **opts)
+
+ def or_(self, *expressions, dialect=None, **opts):
+ """
+ OR this condition with one or multiple expressions.
+
+ Example:
+ >>> condition("x=1").or_("y=1").sql()
+ 'x = 1 OR y = 1'
+
+ Args:
+ *expressions (str or Expression): the SQL code strings to parse.
+ If an `Expression` instance is passed, it will be used as-is.
+ dialect (str): the dialect used to parse the input expression.
+ opts (kwargs): other options to use to parse the input expressions.
+
+ Returns:
+ Or: the new condition.
+ """
+ return or_(self, *expressions, dialect=dialect, **opts)
+
+ def not_(self):
+ """
+ Wrap this condition with NOT.
+
+ Example:
+ >>> condition("x=1").not_().sql()
+ 'NOT x = 1'
+
+ Returns:
+ Not: the new condition.
+ """
+ return not_(self)
+
+
+class Predicate(Condition):
+ """Relationships like x = y, x > 1, x >= y."""
+
+
+class DerivedTable(Expression):
+ @property
+ def alias_column_names(self):
+ table_alias = self.args.get("alias")
+ if not table_alias:
+ return []
+ column_list = table_alias.assert_is(TableAlias).args.get("columns") or []
+ return [c.name for c in column_list]
+
+ @property
+ def selects(self):
+ alias = self.args.get("alias")
+
+ if alias:
+ return alias.columns
+ return []
+
+ @property
+ def named_selects(self):
+ return [select.alias_or_name for select in self.selects]
+
+
+class Annotation(Expression):
+ arg_types = {
+ "this": True,
+ "expression": True,
+ }
+
+
+class Cache(Expression):
+ arg_types = {
+ "with": False,
+ "this": True,
+ "lazy": False,
+ "options": False,
+ "expression": False,
+ }
+
+
+class Uncache(Expression):
+ arg_types = {"this": True, "exists": False}
+
+
+class Create(Expression):
+ arg_types = {
+ "with": False,
+ "this": True,
+ "kind": True,
+ "expression": False,
+ "exists": False,
+ "properties": False,
+ "temporary": False,
+ "replace": False,
+ "unique": False,
+ }
+
+
+class CharacterSet(Expression):
+ arg_types = {"this": True, "default": False}
+
+
+class With(Expression):
+ arg_types = {"expressions": True, "recursive": False}
+
+
+class WithinGroup(Expression):
+ arg_types = {"this": True, "expression": False}
+
+
+class CTE(DerivedTable):
+ arg_types = {"this": True, "alias": True}
+
+
+class TableAlias(Expression):
+ arg_types = {"this": False, "columns": False}
+
+ @property
+ def columns(self):
+ return self.args.get("columns") or []
+
+
+class BitString(Condition):
+ pass
+
+
+class Column(Condition):
+ arg_types = {"this": True, "table": False}
+
+ @property
+ def table(self):
+ return self.text("table")
+
+
+class ColumnDef(Expression):
+ arg_types = {
+ "this": True,
+ "kind": True,
+ "constraints": False,
+ }
+
+
+class ColumnConstraint(Expression):
+ arg_types = {"this": False, "kind": True}
+
+
+class AutoIncrementColumnConstraint(Expression):
+ pass
+
+
+class CheckColumnConstraint(Expression):
+ pass
+
+
+class CollateColumnConstraint(Expression):
+ pass
+
+
+class CommentColumnConstraint(Expression):
+ pass
+
+
+class DefaultColumnConstraint(Expression):
+ pass
+
+
+class NotNullColumnConstraint(Expression):
+ pass
+
+
+class PrimaryKeyColumnConstraint(Expression):
+ pass
+
+
+class UniqueColumnConstraint(Expression):
+ pass
+
+
+class Constraint(Expression):
+ arg_types = {"this": True, "expressions": True}
+
+
+class Delete(Expression):
+ arg_types = {"with": False, "this": True, "where": False}
+
+
+class Drop(Expression):
+ arg_types = {"this": False, "kind": False, "exists": False}
+
+
+class Filter(Expression):
+ arg_types = {"this": True, "expression": True}
+
+
+class Check(Expression):
+ pass
+
+
+class ForeignKey(Expression):
+ arg_types = {
+ "expressions": True,
+ "reference": False,
+ "delete": False,
+ "update": False,
+ }
+
+
+class Unique(Expression):
+ arg_types = {"expressions": True}
+
+
+class From(Expression):
+ arg_types = {"expressions": True}
+
+
+class Having(Expression):
+ pass
+
+
+class Hint(Expression):
+ arg_types = {"expressions": True}
+
+
+class Identifier(Expression):
+ arg_types = {"this": True, "quoted": False}
+
+ @property
+ def quoted(self):
+ return bool(self.args.get("quoted"))
+
+ def __eq__(self, other):
+ return isinstance(other, self.__class__) and _norm_arg(self.this) == _norm_arg(
+ other.this
+ )
+
+ def __hash__(self):
+ return hash((self.key, self.this.lower()))
+
+
+class Index(Expression):
+ arg_types = {"this": False, "table": False, "where": False, "columns": False}
+
+
+class Insert(Expression):
+ arg_types = {
+ "with": False,
+ "this": True,
+ "expression": True,
+ "overwrite": False,
+ "exists": False,
+ "partition": False,
+ }
+
+
+# https://dev.mysql.com/doc/refman/8.0/en/charset-introducer.html
+class Introducer(Expression):
+ arg_types = {"this": True, "expression": True}
+
+
+class Partition(Expression):
+ pass
+
+
+class Fetch(Expression):
+ arg_types = {"direction": False, "count": True}
+
+
+class Group(Expression):
+ arg_types = {
+ "expressions": False,
+ "grouping_sets": False,
+ "cube": False,
+ "rollup": False,
+ }
+
+
+class Lambda(Expression):
+ arg_types = {"this": True, "expressions": True}
+
+
+class Limit(Expression):
+ arg_types = {"this": False, "expression": True}
+
+
+class Literal(Condition):
+ arg_types = {"this": True, "is_string": True}
+
+ def __eq__(self, other):
+ return (
+ isinstance(other, Literal)
+ and self.this == other.this
+ and self.args["is_string"] == other.args["is_string"]
+ )
+
+ def __hash__(self):
+ return hash((self.key, self.this, self.args["is_string"]))
+
+ @classmethod
+ def number(cls, number):
+ return cls(this=str(number), is_string=False)
+
+ @classmethod
+ def string(cls, string):
+ return cls(this=str(string), is_string=True)
+
+
+class Join(Expression):
+ arg_types = {
+ "this": True,
+ "on": False,
+ "side": False,
+ "kind": False,
+ "using": False,
+ }
+
+ @property
+ def kind(self):
+ return self.text("kind").upper()
+
+ @property
+ def side(self):
+ return self.text("side").upper()
+
+ def on(self, *expressions, append=True, dialect=None, copy=True, **opts):
+ """
+ Append to or set the ON expressions.
+
+ Example:
+ >>> import sqlglot
+ >>> sqlglot.parse_one("JOIN x", into=Join).on("y = 1").sql()
+ 'JOIN x ON y = 1'
+
+ Args:
+ *expressions (str or Expression): the SQL code strings to parse.
+ If an `Expression` instance is passed, it will be used as-is.
+ Multiple expressions are combined with an AND operator.
+ append (bool): if `True`, AND the new expressions to any existing expression.
+ Otherwise, this resets the expression.
+ dialect (str): the dialect used to parse the input expressions.
+ copy (bool): if `False`, modify this expression instance in-place.
+ opts (kwargs): other options to use to parse the input expressions.
+
+ Returns:
+ Join: the modified join expression.
+ """
+ join = _apply_conjunction_builder(
+ *expressions,
+ instance=self,
+ arg="on",
+ append=append,
+ dialect=dialect,
+ copy=copy,
+ **opts,
+ )
+
+ if join.kind == "CROSS":
+ join.set("kind", None)
+
+ return join
+
+
+class Lateral(DerivedTable):
+ arg_types = {"this": True, "outer": False, "alias": False}
+
+
+# Clickhouse FROM FINAL modifier
+# https://clickhouse.com/docs/en/sql-reference/statements/select/from/#final-modifier
+class Final(Expression):
+ pass
+
+
+class Offset(Expression):
+ arg_types = {"this": False, "expression": True}
+
+
+class Order(Expression):
+ arg_types = {"this": False, "expressions": True}
+
+
+# Hive-specific sorts
+# https://cwiki.apache.org/confluence/display/Hive/LanguageManual+SortBy
+class Cluster(Order):
+ pass
+
+
+class Distribute(Order):
+ pass
+
+
+class Sort(Order):
+ pass
+
+
+class Ordered(Expression):
+ arg_types = {"this": True, "desc": True, "nulls_first": True}
+
+
+class Properties(Expression):
+ arg_types = {"expressions": True}
+
+
+class Property(Expression):
+ arg_types = {"this": True, "value": True}
+
+
+class TableFormatProperty(Property):
+ pass
+
+
+class PartitionedByProperty(Property):
+ pass
+
+
+class FileFormatProperty(Property):
+ pass
+
+
+class LocationProperty(Property):
+ pass
+
+
+class EngineProperty(Property):
+ pass
+
+
+class AutoIncrementProperty(Property):
+ pass
+
+
+class CharacterSetProperty(Property):
+ arg_types = {"this": True, "value": True, "default": True}
+
+
+class CollateProperty(Property):
+ pass
+
+
+class SchemaCommentProperty(Property):
+ pass
+
+
+class AnonymousProperty(Property):
+ pass
+
+
+class Qualify(Expression):
+ pass
+
+
+class Reference(Expression):
+ arg_types = {"this": True, "expressions": True}
+
+
+class Table(Expression):
+ arg_types = {"this": True, "db": False, "catalog": False}
+
+
+class Tuple(Expression):
+ arg_types = {"expressions": False}
+
+
+class Subqueryable:
+ def subquery(self, alias=None, copy=True):
+ """
+ Convert this expression to an aliased expression that can be used as a Subquery.
+
+ Example:
+ >>> subquery = Select().select("x").from_("tbl").subquery()
+ >>> Select().select("x").from_(subquery).sql()
+ 'SELECT x FROM (SELECT x FROM tbl)'
+
+ Args:
+ alias (str or Identifier): an optional alias for the subquery
+ copy (bool): if `False`, modify this expression instance in-place.
+
+ Returns:
+ Alias: the subquery
+ """
+ instance = _maybe_copy(self, copy)
+ return Subquery(
+ this=instance,
+ alias=TableAlias(this=to_identifier(alias)),
+ )
+
+ @property
+ def ctes(self):
+ with_ = self.args.get("with")
+ if not with_:
+ return []
+ return with_.expressions
+
+ def with_(
+ self,
+ alias,
+ as_,
+ recursive=None,
+ append=True,
+ dialect=None,
+ copy=True,
+ **opts,
+ ):
+ """
+ Append to or set the common table expressions.
+
+ Example:
+ >>> Select().with_("tbl2", as_="SELECT * FROM tbl").select("x").from_("tbl2").sql()
+ 'WITH tbl2 AS (SELECT * FROM tbl) SELECT x FROM tbl2'
+
+ Args:
+ alias (str or Expression): the SQL code string to parse as the table name.
+ If an `Expression` instance is passed, this is used as-is.
+ as_ (str or Expression): the SQL code string to parse as the table expression.
+ If an `Expression` instance is passed, it will be used as-is.
+ recursive (bool): set the RECURSIVE part of the expression. Defaults to `False`.
+ append (bool): if `True`, add to any existing expressions.
+ Otherwise, this resets the expressions.
+ dialect (str): the dialect used to parse the input expression.
+ copy (bool): if `False`, modify this expression instance in-place.
+ opts (kwargs): other options to use to parse the input expressions.
+
+ Returns:
+ Select: the modified expression.
+ """
+ alias_expression = maybe_parse(
+ alias,
+ dialect=dialect,
+ into=TableAlias,
+ **opts,
+ )
+ as_expression = maybe_parse(
+ as_,
+ dialect=dialect,
+ **opts,
+ )
+ cte = CTE(
+ this=as_expression,
+ alias=alias_expression,
+ )
+ return _apply_child_list_builder(
+ cte,
+ instance=self,
+ arg="with",
+ append=append,
+ copy=copy,
+ into=With,
+ properties={"recursive": recursive or False},
+ )
+
+
+QUERY_MODIFIERS = {
+ "laterals": False,
+ "joins": False,
+ "where": False,
+ "group": False,
+ "having": False,
+ "qualify": False,
+ "window": False,
+ "distribute": False,
+ "sort": False,
+ "cluster": False,
+ "order": False,
+ "limit": False,
+ "offset": False,
+}
+
+
+class Union(Subqueryable, Expression):
+ arg_types = {
+ "with": False,
+ "this": True,
+ "expression": True,
+ "distinct": False,
+ **QUERY_MODIFIERS,
+ }
+
+ @property
+ def named_selects(self):
+ return self.args["this"].unnest().named_selects
+
+ @property
+ def left(self):
+ return self.this
+
+ @property
+ def right(self):
+ return self.expression
+
+
+class Except(Union):
+ pass
+
+
+class Intersect(Union):
+ pass
+
+
+class Unnest(DerivedTable):
+ arg_types = {
+ "expressions": True,
+ "ordinality": False,
+ "alias": False,
+ }
+
+
+class Update(Expression):
+ arg_types = {
+ "with": False,
+ "this": True,
+ "expressions": True,
+ "from": False,
+ "where": False,
+ }
+
+
+class Values(Expression):
+ arg_types = {"expressions": True}
+
+
+class Var(Expression):
+ pass
+
+
+class Schema(Expression):
+ arg_types = {"this": False, "expressions": True}
+
+
+class Select(Subqueryable, Expression):
+ arg_types = {
+ "with": False,
+ "expressions": False,
+ "hint": False,
+ "distinct": False,
+ "from": False,
+ **QUERY_MODIFIERS,
+ }
+
+ def from_(self, *expressions, append=True, dialect=None, copy=True, **opts):
+ """
+ Set the FROM expression.
+
+ Example:
+ >>> Select().from_("tbl").select("x").sql()
+ 'SELECT x FROM tbl'
+
+ Args:
+ *expressions (str or Expression): the SQL code strings to parse.
+ If a `From` instance is passed, this is used as-is.
+ If another `Expression` instance is passed, it will be wrapped in a `From`.
+ append (bool): if `True`, add to any existing expressions.
+                Otherwise, this flattens all the `From` expressions into a single expression.
+ dialect (str): the dialect used to parse the input expression.
+ copy (bool): if `False`, modify this expression instance in-place.
+ opts (kwargs): other options to use to parse the input expressions.
+
+ Returns:
+ Select: the modified expression.
+ """
+ return _apply_child_list_builder(
+ *expressions,
+ instance=self,
+ arg="from",
+ append=append,
+ copy=copy,
+ prefix="FROM",
+ into=From,
+ dialect=dialect,
+ **opts,
+ )
+
+ def group_by(self, *expressions, append=True, dialect=None, copy=True, **opts):
+ """
+ Set the GROUP BY expression.
+
+ Example:
+ >>> Select().from_("tbl").select("x", "COUNT(1)").group_by("x").sql()
+ 'SELECT x, COUNT(1) FROM tbl GROUP BY x'
+
+ Args:
+ *expressions (str or Expression): the SQL code strings to parse.
+ If a `Group` instance is passed, this is used as-is.
+ If another `Expression` instance is passed, it will be wrapped in a `Group`.
+ append (bool): if `True`, add to any existing expressions.
+                Otherwise, this flattens all the `Group` expressions into a single expression.
+ dialect (str): the dialect used to parse the input expression.
+ copy (bool): if `False`, modify this expression instance in-place.
+ opts (kwargs): other options to use to parse the input expressions.
+
+ Returns:
+ Select: the modified expression.
+ """
+ return _apply_child_list_builder(
+ *expressions,
+ instance=self,
+ arg="group",
+ append=append,
+ copy=copy,
+ prefix="GROUP BY",
+ into=Group,
+ dialect=dialect,
+ **opts,
+ )
+
+ def order_by(self, *expressions, append=True, dialect=None, copy=True, **opts):
+ """
+ Set the ORDER BY expression.
+
+ Example:
+ >>> Select().from_("tbl").select("x").order_by("x DESC").sql()
+ 'SELECT x FROM tbl ORDER BY x DESC'
+
+ Args:
+ *expressions (str or Expression): the SQL code strings to parse.
+                If an `Order` instance is passed, this is used as-is.
+                If another `Expression` instance is passed, it will be wrapped in an `Order`.
+            append (bool): if `True`, add to any existing expressions.
+                Otherwise, this flattens all the `Order` expressions into a single expression.
+ dialect (str): the dialect used to parse the input expression.
+ copy (bool): if `False`, modify this expression instance in-place.
+ opts (kwargs): other options to use to parse the input expressions.
+
+ Returns:
+ Select: the modified expression.
+ """
+ return _apply_child_list_builder(
+ *expressions,
+ instance=self,
+ arg="order",
+ append=append,
+ copy=copy,
+ prefix="ORDER BY",
+ into=Order,
+ dialect=dialect,
+ **opts,
+ )
+
+ def sort_by(self, *expressions, append=True, dialect=None, copy=True, **opts):
+ """
+ Set the SORT BY expression.
+
+ Example:
+ >>> Select().from_("tbl").select("x").sort_by("x DESC").sql()
+ 'SELECT x FROM tbl SORT BY x DESC'
+
+ Args:
+ *expressions (str or Expression): the SQL code strings to parse.
+                If a `Sort` instance is passed, this is used as-is.
+                If another `Expression` instance is passed, it will be wrapped in a `Sort`.
+            append (bool): if `True`, add to any existing expressions.
+                Otherwise, this flattens all the `Sort` expressions into a single expression.
+ dialect (str): the dialect used to parse the input expression.
+ copy (bool): if `False`, modify this expression instance in-place.
+ opts (kwargs): other options to use to parse the input expressions.
+
+ Returns:
+ Select: the modified expression.
+ """
+ return _apply_child_list_builder(
+ *expressions,
+ instance=self,
+ arg="sort",
+ append=append,
+ copy=copy,
+ prefix="SORT BY",
+ into=Sort,
+ dialect=dialect,
+ **opts,
+ )
+
+ def cluster_by(self, *expressions, append=True, dialect=None, copy=True, **opts):
+ """
+ Set the CLUSTER BY expression.
+
+ Example:
+ >>> Select().from_("tbl").select("x").cluster_by("x DESC").sql()
+ 'SELECT x FROM tbl CLUSTER BY x DESC'
+
+ Args:
+ *expressions (str or Expression): the SQL code strings to parse.
+                If a `Cluster` instance is passed, this is used as-is.
+                If another `Expression` instance is passed, it will be wrapped in a `Cluster`.
+            append (bool): if `True`, add to any existing expressions.
+                Otherwise, this flattens all the `Cluster` expressions into a single expression.
+ dialect (str): the dialect used to parse the input expression.
+ copy (bool): if `False`, modify this expression instance in-place.
+ opts (kwargs): other options to use to parse the input expressions.
+
+ Returns:
+ Select: the modified expression.
+ """
+ return _apply_child_list_builder(
+ *expressions,
+ instance=self,
+ arg="cluster",
+ append=append,
+ copy=copy,
+ prefix="CLUSTER BY",
+ into=Cluster,
+ dialect=dialect,
+ **opts,
+ )
+
+ def limit(self, expression, dialect=None, copy=True, **opts):
+ """
+ Set the LIMIT expression.
+
+ Example:
+ >>> Select().from_("tbl").select("x").limit(10).sql()
+ 'SELECT x FROM tbl LIMIT 10'
+
+ Args:
+ expression (str or int or Expression): the SQL code string to parse.
+ This can also be an integer.
+ If a `Limit` instance is passed, this is used as-is.
+ If another `Expression` instance is passed, it will be wrapped in a `Limit`.
+ dialect (str): the dialect used to parse the input expression.
+ copy (bool): if `False`, modify this expression instance in-place.
+ opts (kwargs): other options to use to parse the input expressions.
+
+ Returns:
+ Select: the modified expression.
+ """
+ return _apply_builder(
+ expression=expression,
+ instance=self,
+ arg="limit",
+ into=Limit,
+ prefix="LIMIT",
+ dialect=dialect,
+ copy=copy,
+ **opts,
+ )
+
+ def offset(self, expression, dialect=None, copy=True, **opts):
+ """
+ Set the OFFSET expression.
+
+ Example:
+ >>> Select().from_("tbl").select("x").offset(10).sql()
+ 'SELECT x FROM tbl OFFSET 10'
+
+ Args:
+ expression (str or int or Expression): the SQL code string to parse.
+ This can also be an integer.
+                If an `Offset` instance is passed, this is used as-is.
+                If another `Expression` instance is passed, it will be wrapped in an `Offset`.
+ dialect (str): the dialect used to parse the input expression.
+ copy (bool): if `False`, modify this expression instance in-place.
+ opts (kwargs): other options to use to parse the input expressions.
+
+ Returns:
+ Select: the modified expression.
+ """
+ return _apply_builder(
+ expression=expression,
+ instance=self,
+ arg="offset",
+ into=Offset,
+ prefix="OFFSET",
+ dialect=dialect,
+ copy=copy,
+ **opts,
+ )
+
+ def select(self, *expressions, append=True, dialect=None, copy=True, **opts):
+ """
+ Append to or set the SELECT expressions.
+
+ Example:
+ >>> Select().select("x", "y").sql()
+ 'SELECT x, y'
+
+ Args:
+ *expressions (str or Expression): the SQL code strings to parse.
+ If an `Expression` instance is passed, it will be used as-is.
+ append (bool): if `True`, add to any existing expressions.
+ Otherwise, this resets the expressions.
+ dialect (str): the dialect used to parse the input expressions.
+ copy (bool): if `False`, modify this expression instance in-place.
+ opts (kwargs): other options to use to parse the input expressions.
+
+ Returns:
+ Select: the modified expression.
+ """
+ return _apply_list_builder(
+ *expressions,
+ instance=self,
+ arg="expressions",
+ append=append,
+ dialect=dialect,
+ copy=copy,
+ **opts,
+ )
+
+ def lateral(self, *expressions, append=True, dialect=None, copy=True, **opts):
+ """
+ Append to or set the LATERAL expressions.
+
+ Example:
+ >>> Select().select("x").lateral("OUTER explode(y) tbl2 AS z").from_("tbl").sql()
+ 'SELECT x FROM tbl LATERAL VIEW OUTER EXPLODE(y) tbl2 AS z'
+
+ Args:
+ *expressions (str or Expression): the SQL code strings to parse.
+ If an `Expression` instance is passed, it will be used as-is.
+ append (bool): if `True`, add to any existing expressions.
+ Otherwise, this resets the expressions.
+ dialect (str): the dialect used to parse the input expressions.
+ copy (bool): if `False`, modify this expression instance in-place.
+ opts (kwargs): other options to use to parse the input expressions.
+
+ Returns:
+ Select: the modified expression.
+ """
+ return _apply_list_builder(
+ *expressions,
+ instance=self,
+ arg="laterals",
+ append=append,
+ into=Lateral,
+ prefix="LATERAL VIEW",
+ dialect=dialect,
+ copy=copy,
+ **opts,
+ )
+
+ def join(
+ self,
+ expression,
+ on=None,
+ append=True,
+ join_type=None,
+ join_alias=None,
+ dialect=None,
+ copy=True,
+ **opts,
+ ):
+ """
+ Append to or set the JOIN expressions.
+
+ Example:
+ >>> Select().select("*").from_("tbl").join("tbl2", on="tbl1.y = tbl2.y").sql()
+ 'SELECT * FROM tbl JOIN tbl2 ON tbl1.y = tbl2.y'
+
+ Use `join_type` to change the type of join:
+
+ >>> Select().select("*").from_("tbl").join("tbl2", on="tbl1.y = tbl2.y", join_type="left outer").sql()
+ 'SELECT * FROM tbl LEFT OUTER JOIN tbl2 ON tbl1.y = tbl2.y'
+
+ Args:
+ expression (str or Expression): the SQL code string to parse.
+ If an `Expression` instance is passed, it will be used as-is.
+ on (str or Expression): optionally specify the join criteria as a SQL string.
+ If an `Expression` instance is passed, it will be used as-is.
+ append (bool): if `True`, add to any existing expressions.
+ Otherwise, this resets the expressions.
+            join_type (str): if set, alter the parsed join type.
+            join_alias (str or Identifier): an optional alias for the joined source.
+ dialect (str): the dialect used to parse the input expressions.
+ copy (bool): if `False`, modify this expression instance in-place.
+ opts (kwargs): other options to use to parse the input expressions.
+
+ Returns:
+ Select: the modified expression.
+ """
+ parse_args = {"dialect": dialect, **opts}
+
+ try:
+ expression = maybe_parse(expression, into=Join, prefix="JOIN", **parse_args)
+ except ParseError:
+ expression = maybe_parse(expression, into=(Join, Expression), **parse_args)
+
+ join = expression if isinstance(expression, Join) else Join(this=expression)
+
+ if isinstance(join.this, Select):
+ join.this.replace(join.this.subquery())
+
+ if join_type:
+ side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args)
+ if side:
+ join.set("side", side.text)
+ if kind:
+ join.set("kind", kind.text)
+
+ if on:
+ on = and_(*ensure_list(on), dialect=dialect, **opts)
+ join.set("on", on)
+
+ if join_alias:
+ join.set("this", alias_(join.args["this"], join_alias, table=True))
+ return _apply_list_builder(
+ join,
+ instance=self,
+ arg="joins",
+ append=append,
+ copy=copy,
+ **opts,
+ )
+
+ def where(self, *expressions, append=True, dialect=None, copy=True, **opts):
+ """
+ Append to or set the WHERE expressions.
+
+ Example:
+ >>> Select().select("x").from_("tbl").where("x = 'a' OR x < 'b'").sql()
+ "SELECT x FROM tbl WHERE x = 'a' OR x < 'b'"
+
+ Args:
+ *expressions (str or Expression): the SQL code strings to parse.
+ If an `Expression` instance is passed, it will be used as-is.
+ Multiple expressions are combined with an AND operator.
+ append (bool): if `True`, AND the new expressions to any existing expression.
+ Otherwise, this resets the expression.
+ dialect (str): the dialect used to parse the input expressions.
+ copy (bool): if `False`, modify this expression instance in-place.
+ opts (kwargs): other options to use to parse the input expressions.
+
+ Returns:
+ Select: the modified expression.
+ """
+ return _apply_conjunction_builder(
+ *expressions,
+ instance=self,
+ arg="where",
+ append=append,
+ into=Where,
+ dialect=dialect,
+ copy=copy,
+ **opts,
+ )
+
+ def having(self, *expressions, append=True, dialect=None, copy=True, **opts):
+ """
+ Append to or set the HAVING expressions.
+
+ Example:
+ >>> Select().select("x", "COUNT(y)").from_("tbl").group_by("x").having("COUNT(y) > 3").sql()
+ 'SELECT x, COUNT(y) FROM tbl GROUP BY x HAVING COUNT(y) > 3'
+
+ Args:
+ *expressions (str or Expression): the SQL code strings to parse.
+ If an `Expression` instance is passed, it will be used as-is.
+ Multiple expressions are combined with an AND operator.
+ append (bool): if `True`, AND the new expressions to any existing expression.
+ Otherwise, this resets the expression.
+ dialect (str): the dialect used to parse the input expressions.
+ copy (bool): if `False`, modify this expression instance in-place.
+ opts (kwargs): other options to use to parse the input expressions.
+
+ Returns:
+ Select: the modified expression.
+ """
+ return _apply_conjunction_builder(
+ *expressions,
+ instance=self,
+ arg="having",
+ append=append,
+ into=Having,
+ dialect=dialect,
+ copy=copy,
+ **opts,
+ )
+
+ def distinct(self, distinct=True, copy=True):
+ """
+        Set the DISTINCT flag.
+
+ Example:
+ >>> Select().from_("tbl").select("x").distinct().sql()
+ 'SELECT DISTINCT x FROM tbl'
+
+ Args:
+ distinct (bool): whether the Select should be distinct
+ copy (bool): if `False`, modify this expression instance in-place.
+
+ Returns:
+ Select: the modified expression.
+ """
+ instance = _maybe_copy(self, copy)
+ instance.set("distinct", Distinct() if distinct else None)
+ return instance
+
+ def ctas(self, table, properties=None, dialect=None, copy=True, **opts):
+ """
+ Convert this expression to a CREATE TABLE AS statement.
+
+ Example:
+ >>> Select().select("*").from_("tbl").ctas("x").sql()
+ 'CREATE TABLE x AS SELECT * FROM tbl'
+
+ Args:
+ table (str or Expression): the SQL code string to parse as the table name.
+ If another `Expression` instance is passed, it will be used as-is.
+ properties (dict): an optional mapping of table properties
+ dialect (str): the dialect used to parse the input table.
+ copy (bool): if `False`, modify this expression instance in-place.
+ opts (kwargs): other options to use to parse the input table.
+
+ Returns:
+ Create: the CREATE TABLE AS expression
+ """
+ instance = _maybe_copy(self, copy)
+ table_expression = maybe_parse(
+ table,
+ into=Table,
+ dialect=dialect,
+ **opts,
+ )
+ properties_expression = None
+ if properties:
+ properties_str = " ".join(
+ [
+ f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}"
+ for k, v in properties.items()
+ ]
+ )
+ properties_expression = maybe_parse(
+ properties_str,
+ into=Properties,
+ dialect=dialect,
+ **opts,
+ )
+
+ return Create(
+ this=table_expression,
+ kind="table",
+ expression=instance,
+ properties=properties_expression,
+ )
+
+ @property
+ def named_selects(self):
+ return [e.alias_or_name for e in self.expressions if e.alias_or_name]
+
+ @property
+ def selects(self):
+ return self.expressions
+
+
+class Subquery(DerivedTable):
+ arg_types = {
+ "this": True,
+ "alias": False,
+ **QUERY_MODIFIERS,
+ }
+
+ def unnest(self):
+ """
+        Returns the first non-Subquery expression.
+ """
+ expression = self
+ while isinstance(expression, Subquery):
+ expression = expression.this
+ return expression
+
+
+class TableSample(Expression):
+ arg_types = {
+ "this": False,
+ "method": False,
+ "bucket_numerator": False,
+ "bucket_denominator": False,
+ "bucket_field": False,
+ "percent": False,
+ "rows": False,
+ "size": False,
+ }
+
+
+class Window(Expression):
+ arg_types = {
+ "this": True,
+ "partition_by": False,
+ "order": False,
+ "spec": False,
+ "alias": False,
+ }
+
+
+class WindowSpec(Expression):
+ arg_types = {
+ "kind": False,
+ "start": False,
+ "start_side": False,
+ "end": False,
+ "end_side": False,
+ }
+
+
+class Where(Expression):
+ pass
+
+
+class Star(Expression):
+ arg_types = {"except": False, "replace": False}
+
+ @property
+ def name(self):
+ return "*"
+
+
+class Placeholder(Expression):
+ arg_types = {}
+
+
+class Null(Condition):
+ arg_types = {}
+
+
+class Boolean(Condition):
+ pass
+
+
+class DataType(Expression):
+ arg_types = {
+ "this": True,
+ "expressions": False,
+ "nested": False,
+ }
+
+ class Type(AutoName):
+ CHAR = auto()
+ NCHAR = auto()
+ VARCHAR = auto()
+ NVARCHAR = auto()
+ TEXT = auto()
+ BINARY = auto()
+ INT = auto()
+ TINYINT = auto()
+ SMALLINT = auto()
+ BIGINT = auto()
+ FLOAT = auto()
+ DOUBLE = auto()
+ DECIMAL = auto()
+ BOOLEAN = auto()
+ JSON = auto()
+ TIMESTAMP = auto()
+ TIMESTAMPTZ = auto()
+ DATE = auto()
+ DATETIME = auto()
+ ARRAY = auto()
+ MAP = auto()
+ UUID = auto()
+ GEOGRAPHY = auto()
+ STRUCT = auto()
+ NULLABLE = auto()
+
+ @classmethod
+ def build(cls, dtype, **kwargs):
+ return DataType(
+ this=dtype
+ if isinstance(dtype, DataType.Type)
+ else DataType.Type[dtype.upper()],
+ **kwargs,
+ )
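+
+    # For example, DataType.build("varchar") resolves the string to
+    # DataType.Type.VARCHAR, while DataType.build(DataType.Type.ARRAY,
+    # nested=True) passes the enum member through unchanged.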
+
+
+class StructKwarg(Expression):
+ arg_types = {"this": True, "expression": True}
+
+
+# WHERE x <OP> EXISTS|ALL|ANY|SOME(SELECT ...)
+class SubqueryPredicate(Predicate):
+ pass
+
+
+class All(SubqueryPredicate):
+ pass
+
+
+class Any(SubqueryPredicate):
+ pass
+
+
+class Exists(SubqueryPredicate):
+ pass
+
+
+# Commands to interact with the databases or engines.
+# These expressions aren't truly parsed; everything that follows, up to the
+# end of the statement or a semicolon, is consumed as a raw string.
+class Command(Expression):
+ arg_types = {"this": True, "expression": False}
+
+
+# Binary Expressions
+# (ADD a b)
+# (FROM table selects)
+class Binary(Expression):
+ arg_types = {"this": True, "expression": True}
+
+ @property
+ def left(self):
+ return self.this
+
+ @property
+ def right(self):
+ return self.expression
+
+
+class Add(Binary):
+ pass
+
+
+class Connector(Binary, Condition):
+ pass
+
+
+class And(Connector):
+ pass
+
+
+class Or(Connector):
+ pass
+
+
+class BitwiseAnd(Binary):
+ pass
+
+
+class BitwiseLeftShift(Binary):
+ pass
+
+
+class BitwiseOr(Binary):
+ pass
+
+
+class BitwiseRightShift(Binary):
+ pass
+
+
+class BitwiseXor(Binary):
+ pass
+
+
+class Div(Binary):
+ pass
+
+
+class Dot(Binary):
+ pass
+
+
+class DPipe(Binary):
+ pass
+
+
+class EQ(Binary, Predicate):
+ pass
+
+
+class Escape(Binary):
+ pass
+
+
+class GT(Binary, Predicate):
+ pass
+
+
+class GTE(Binary, Predicate):
+ pass
+
+
+class ILike(Binary, Predicate):
+ pass
+
+
+class IntDiv(Binary):
+ pass
+
+
+class Is(Binary, Predicate):
+ pass
+
+
+class Like(Binary, Predicate):
+ pass
+
+
+class LT(Binary, Predicate):
+ pass
+
+
+class LTE(Binary, Predicate):
+ pass
+
+
+class Mod(Binary):
+ pass
+
+
+class Mul(Binary):
+ pass
+
+
+class NEQ(Binary, Predicate):
+ pass
+
+
+class Sub(Binary):
+ pass
+
+
+# Unary Expressions
+# (NOT a)
+class Unary(Expression):
+ pass
+
+
+class BitwiseNot(Unary):
+ pass
+
+
+class Not(Unary, Condition):
+ pass
+
+
+class Paren(Unary, Condition):
+ pass
+
+
+class Neg(Unary):
+ pass
+
+
+# Special Functions
+class Alias(Expression):
+ arg_types = {"this": True, "alias": False}
+
+
+class Aliases(Expression):
+ arg_types = {"this": True, "expressions": True}
+
+ @property
+ def aliases(self):
+ return self.expressions
+
+
+class AtTimeZone(Expression):
+ arg_types = {"this": True, "zone": True}
+
+
+class Between(Predicate):
+ arg_types = {"this": True, "low": True, "high": True}
+
+
+class Bracket(Condition):
+ arg_types = {"this": True, "expressions": True}
+
+
+class Distinct(Expression):
+ arg_types = {"this": False, "on": False}
+
+
+class In(Predicate):
+ arg_types = {"this": True, "expressions": False, "query": False, "unnest": False}
+
+
+class TimeUnit(Expression):
+ """Automatically converts unit arg into a var."""
+
+ arg_types = {"unit": False}
+
+ def __init__(self, **args):
+ unit = args.get("unit")
+ if isinstance(unit, Column):
+ args["unit"] = Var(this=unit.name)
+ elif isinstance(unit, Week):
+ unit.set("this", Var(this=unit.this.name))
+ super().__init__(**args)
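+
+    # For example, a parsed unit like Column(this=Identifier(this="day")) is
+    # rewritten to Var(this="day"), so units are never treated as column
+    # references.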
+
+
+class Interval(TimeUnit):
+ arg_types = {"this": True, "unit": False}
+
+
+class IgnoreNulls(Expression):
+ pass
+
+
+# Functions
+class Func(Condition):
+ """
+ The base class for all function expressions.
+
+    Attributes:
+ is_var_len_args (bool): if set to True the last argument defined in
+ arg_types will be treated as a variable length argument and the
+ argument's value will be stored as a list.
+ _sql_names (list): determines the SQL name (1st item in the list) and
+ aliases (subsequent items) for this function expression. These
+ values are used to map this node to a name during parsing as well
+ as to provide the function's name during SQL string generation. By
+ default the SQL name is set to the expression's class name transformed
+ to snake case.
+ """
+
+ is_var_len_args = False
+
+ @classmethod
+ def from_arg_list(cls, args):
+ args_num = len(args)
+
+ all_arg_keys = list(cls.arg_types)
+        # If this function supports variable length arguments, treat the last argument as such.
+ non_var_len_arg_keys = (
+ all_arg_keys[:-1] if cls.is_var_len_args else all_arg_keys
+ )
+
+ args_dict = {}
+ arg_idx = 0
+ for arg_key in non_var_len_arg_keys:
+ if arg_idx >= args_num:
+ break
+ if args[arg_idx] is not None:
+ args_dict[arg_key] = args[arg_idx]
+ arg_idx += 1
+
+ if arg_idx < args_num and cls.is_var_len_args:
+ args_dict[all_arg_keys[-1]] = args[arg_idx:]
+ return cls(**args_dict)
+
+ @classmethod
+ def sql_names(cls):
+ if cls is Func:
+ raise NotImplementedError(
+ "SQL name is only supported by concrete function implementations"
+ )
+ if not hasattr(cls, "_sql_names"):
+ cls._sql_names = [camel_to_snake_case(cls.__name__)]
+ return cls._sql_names
+
+ @classmethod
+ def sql_name(cls):
+ return cls.sql_names()[0]
+
+ @classmethod
+ def default_parser_mappings(cls):
+ return {name: cls.from_arg_list for name in cls.sql_names()}
+
+
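+# A minimal sketch of a concrete Func subclass (hypothetical class, not part
+# of this module) to illustrate the machinery above: with is_var_len_args set,
+# from_arg_list([a, b, c]) binds a to "this" and stores [b, c] under
+# "expressions", and sql_name() falls back to the snake-cased class name
+# (e.g. "MY_FUNC") unless _sql_names overrides it.
+#
+# class MyFunc(Func):
+#     arg_types = {"this": True, "expressions": False}
+#     is_var_len_args = True
+
+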
+class AggFunc(Func):
+ pass
+
+
+class Abs(Func):
+ pass
+
+
+class Anonymous(Func):
+ arg_types = {"this": True, "expressions": False}
+ is_var_len_args = True
+
+
+class ApproxDistinct(AggFunc):
+ arg_types = {"this": True, "accuracy": False}
+
+
+class Array(Func):
+ arg_types = {"expressions": False}
+ is_var_len_args = True
+
+
+class ArrayAgg(AggFunc):
+ pass
+
+
+class ArrayAll(Func):
+ arg_types = {"this": True, "expression": True}
+
+
+class ArrayAny(Func):
+ arg_types = {"this": True, "expression": True}
+
+
+class ArrayContains(Func):
+ arg_types = {"this": True, "expression": True}
+
+
+class ArrayFilter(Func):
+ arg_types = {"this": True, "expression": True}
+ _sql_names = ["FILTER", "ARRAY_FILTER"]
+
+
+class ArraySize(Func):
+ pass
+
+
+class ArraySort(Func):
+ arg_types = {"this": True, "expression": False}
+
+
+class ArraySum(Func):
+ pass
+
+
+class ArrayUnionAgg(AggFunc):
+ pass
+
+
+class Avg(AggFunc):
+ pass
+
+
+class AnyValue(AggFunc):
+ pass
+
+
+class Case(Func):
+ arg_types = {"this": False, "ifs": True, "default": False}
+
+
+class Cast(Func):
+ arg_types = {"this": True, "to": True}
+
+
+class TryCast(Cast):
+ pass
+
+
+class Ceil(Func):
+ _sql_names = ["CEIL", "CEILING"]
+
+
+class Coalesce(Func):
+ arg_types = {"this": True, "expressions": False}
+ is_var_len_args = True
+
+
+class ConcatWs(Func):
+ arg_types = {"expressions": False}
+ is_var_len_args = True
+
+
+class Count(AggFunc):
+ pass
+
+
+class CurrentDate(Func):
+ arg_types = {"this": False}
+
+
+class CurrentDatetime(Func):
+ arg_types = {"this": False}
+
+
+class CurrentTime(Func):
+ arg_types = {"this": False}
+
+
+class CurrentTimestamp(Func):
+ arg_types = {"this": False}
+
+
+class DateAdd(Func, TimeUnit):
+ arg_types = {"this": True, "expression": True, "unit": False}
+
+
+class DateSub(Func, TimeUnit):
+ arg_types = {"this": True, "expression": True, "unit": False}
+
+
+class DateDiff(Func, TimeUnit):
+ arg_types = {"this": True, "expression": True, "unit": False}
+
+
+class DateTrunc(Func, TimeUnit):
+ arg_types = {"this": True, "unit": True, "zone": False}
+
+
+class DatetimeAdd(Func, TimeUnit):
+ arg_types = {"this": True, "expression": True, "unit": False}
+
+
+class DatetimeSub(Func, TimeUnit):
+ arg_types = {"this": True, "expression": True, "unit": False}
+
+
+class DatetimeDiff(Func, TimeUnit):
+ arg_types = {"this": True, "expression": True, "unit": False}
+
+
+class DatetimeTrunc(Func, TimeUnit):
+ arg_types = {"this": True, "unit": True, "zone": False}
+
+
+class Extract(Func):
+ arg_types = {"this": True, "expression": True}
+
+
+class TimestampAdd(Func, TimeUnit):
+ arg_types = {"this": True, "expression": True, "unit": False}
+
+
+class TimestampSub(Func, TimeUnit):
+ arg_types = {"this": True, "expression": True, "unit": False}
+
+
+class TimestampDiff(Func, TimeUnit):
+ arg_types = {"this": True, "expression": True, "unit": False}
+
+
+class TimestampTrunc(Func, TimeUnit):
+ arg_types = {"this": True, "unit": True, "zone": False}
+
+
+class TimeAdd(Func, TimeUnit):
+ arg_types = {"this": True, "expression": True, "unit": False}
+
+
+class TimeSub(Func, TimeUnit):
+ arg_types = {"this": True, "expression": True, "unit": False}
+
+
+class TimeDiff(Func, TimeUnit):
+ arg_types = {"this": True, "expression": True, "unit": False}
+
+
+class TimeTrunc(Func, TimeUnit):
+ arg_types = {"this": True, "unit": True, "zone": False}
+
+
+class DateStrToDate(Func):
+ pass
+
+
+class DateToDateStr(Func):
+ pass
+
+
+class DateToDi(Func):
+ pass
+
+
+class Day(Func):
+ pass
+
+
+class DiToDate(Func):
+ pass
+
+
+class Exp(Func):
+ pass
+
+
+class Explode(Func):
+ pass
+
+
+class Floor(Func):
+ pass
+
+
+class Greatest(Func):
+ arg_types = {"this": True, "expressions": True}
+ is_var_len_args = True
+
+
+class If(Func):
+ arg_types = {"this": True, "true": True, "false": False}
+
+
+class IfNull(Func):
+ arg_types = {"this": True, "expression": False}
+ _sql_names = ["IFNULL", "NVL"]
+
+
+class Initcap(Func):
+ pass
+
+
+class JSONExtract(Func):
+ arg_types = {"this": True, "path": True}
+ _sql_names = ["JSON_EXTRACT"]
+
+
+class JSONExtractScalar(JSONExtract):
+ _sql_names = ["JSON_EXTRACT_SCALAR"]
+
+
+class JSONBExtract(JSONExtract):
+ _sql_names = ["JSONB_EXTRACT"]
+
+
+class JSONBExtractScalar(JSONExtract):
+ _sql_names = ["JSONB_EXTRACT_SCALAR"]
+
+
+class Least(Func):
+ arg_types = {"this": True, "expressions": True}
+ is_var_len_args = True
+
+
+class Length(Func):
+ pass
+
+
+class Levenshtein(Func):
+ arg_types = {"this": True, "expression": False}
+
+
+class Ln(Func):
+ pass
+
+
+class Log(Func):
+ arg_types = {"this": True, "expression": False}
+
+
+class Log2(Func):
+ pass
+
+
+class Log10(Func):
+ pass
+
+
+class Lower(Func):
+ pass
+
+
+class Map(Func):
+ arg_types = {"keys": True, "values": True}
+
+
+class Max(AggFunc):
+ pass
+
+
+class Min(AggFunc):
+ pass
+
+
+class Month(Func):
+ pass
+
+
+class Nvl2(Func):
+ arg_types = {"this": True, "true": True, "false": False}
+
+
+class Posexplode(Func):
+ pass
+
+
+class Pow(Func):
+ arg_types = {"this": True, "power": True}
+ _sql_names = ["POWER", "POW"]
+
+
+class Quantile(AggFunc):
+ arg_types = {"this": True, "quantile": True}
+
+
+class Reduce(Func):
+ arg_types = {"this": True, "initial": True, "merge": True, "finish": True}
+
+
+class RegexpLike(Func):
+ arg_types = {"this": True, "expression": True}
+
+
+class RegexpSplit(Func):
+ arg_types = {"this": True, "expression": True}
+
+
+class Round(Func):
+ arg_types = {"this": True, "decimals": False}
+
+
+class SafeDivide(Func):
+ arg_types = {"this": True, "expression": True}
+
+
+class SetAgg(AggFunc):
+ pass
+
+
+class SortArray(Func):
+ arg_types = {"this": True, "asc": False}
+
+
+class Split(Func):
+ arg_types = {"this": True, "expression": True}
+
+
+class Substring(Func):
+ arg_types = {"this": True, "start": True, "length": False}
+
+
+class StrPosition(Func):
+ arg_types = {"this": True, "substr": True, "position": False}
+
+
+class StrToDate(Func):
+ arg_types = {"this": True, "format": True}
+
+
+class StrToTime(Func):
+ arg_types = {"this": True, "format": True}
+
+
+class StrToUnix(Func):
+ arg_types = {"this": True, "format": True}
+
+
+class Struct(Func):
+ arg_types = {"expressions": True}
+ is_var_len_args = True
+
+
+class StructExtract(Func):
+ arg_types = {"this": True, "expression": True}
+
+
+class Sum(AggFunc):
+ pass
+
+
+class Sqrt(Func):
+ pass
+
+
+class Stddev(AggFunc):
+ pass
+
+
+class StddevPop(AggFunc):
+ pass
+
+
+class StddevSamp(AggFunc):
+ pass
+
+
+class TimeToStr(Func):
+ arg_types = {"this": True, "format": True}
+
+
+class TimeToTimeStr(Func):
+ pass
+
+
+class TimeToUnix(Func):
+ pass
+
+
+class TimeStrToDate(Func):
+ pass
+
+
+class TimeStrToTime(Func):
+ pass
+
+
+class TimeStrToUnix(Func):
+ pass
+
+
+class TsOrDsAdd(Func, TimeUnit):
+ arg_types = {"this": True, "expression": True, "unit": False}
+
+
+class TsOrDsToDateStr(Func):
+ pass
+
+
+class TsOrDsToDate(Func):
+ arg_types = {"this": True, "format": False}
+
+
+class TsOrDiToDi(Func):
+ pass
+
+
+class UnixToStr(Func):
+ arg_types = {"this": True, "format": True}
+
+
+class UnixToTime(Func):
+ arg_types = {"this": True, "scale": False}
+
+ SECONDS = Literal.string("seconds")
+ MILLIS = Literal.string("millis")
+ MICROS = Literal.string("micros")
+
+
+class UnixToTimeStr(Func):
+ pass
+
+
+class Upper(Func):
+ pass
+
+
+class Variance(AggFunc):
+ _sql_names = ["VARIANCE", "VARIANCE_SAMP", "VAR_SAMP"]
+
+
+class VariancePop(AggFunc):
+ _sql_names = ["VARIANCE_POP", "VAR_POP"]
+
+
+class Week(Func):
+ arg_types = {"this": True, "mode": False}
+
+
+class Year(Func):
+ pass
+
+
+def _norm_args(expression):
+ args = {}
+
+ for k, arg in expression.args.items():
+ if isinstance(arg, list):
+ arg = [_norm_arg(a) for a in arg]
+ else:
+ arg = _norm_arg(arg)
+
+ if arg is not None:
+ args[k] = arg
+
+ return args
+
+
+def _norm_arg(arg):
+ return arg.lower() if isinstance(arg, str) else arg
+
+
+def _all_functions():
+ return [
+ obj
+ for _, obj in inspect.getmembers(
+ sys.modules[__name__],
+ lambda obj: inspect.isclass(obj)
+ and issubclass(obj, Func)
+ and obj not in (AggFunc, Anonymous, Func),
+ )
+ ]
+
+
+ALL_FUNCTIONS = _all_functions()
+
+
+def maybe_parse(
+ sql_or_expression,
+ *,
+ into=None,
+ dialect=None,
+ prefix=None,
+ **opts,
+):
+ """Gracefully handle a possible string or expression.
+
+ Example:
+ >>> maybe_parse("1")
+ (LITERAL this: 1, is_string: False)
+ >>> maybe_parse(to_identifier("x"))
+ (IDENTIFIER this: x, quoted: False)
+
+ Args:
+ sql_or_expression (str or Expression): the SQL code string or an expression
+ into (Expression): the SQLGlot Expression to parse into
+ dialect (str): the dialect used to parse the input expressions (in the case that an
+ input expression is a SQL string).
+ prefix (str): a string to prefix the sql with before it gets parsed
+ (automatically includes a space)
+ **opts: other options to use to parse the input expressions (again, in the case
+ that an input expression is a SQL string).
+
+ Returns:
+ Expression: the parsed or given expression.
+ """
+ if isinstance(sql_or_expression, Expression):
+ return sql_or_expression
+
+ import sqlglot
+
+ sql = str(sql_or_expression)
+ if prefix:
+ sql = f"{prefix} {sql}"
+ return sqlglot.parse_one(sql, read=dialect, into=into, **opts)
+
+
+def _maybe_copy(instance, copy=True):
+ return instance.copy() if copy else instance
+
+
+def _is_wrong_expression(expression, into):
+ return isinstance(expression, Expression) and not isinstance(expression, into)
+
+
+def _apply_builder(
+ expression,
+ instance,
+ arg,
+ copy=True,
+ prefix=None,
+ into=None,
+ dialect=None,
+ **opts,
+):
+ if _is_wrong_expression(expression, into):
+ expression = into(this=expression)
+ instance = _maybe_copy(instance, copy)
+ expression = maybe_parse(
+ sql_or_expression=expression,
+ prefix=prefix,
+ into=into,
+ dialect=dialect,
+ **opts,
+ )
+ instance.set(arg, expression)
+ return instance
+
+
+def _apply_child_list_builder(
+ *expressions,
+ instance,
+ arg,
+ append=True,
+ copy=True,
+ prefix=None,
+ into=None,
+ dialect=None,
+ properties=None,
+ **opts,
+):
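+    # Parse each input into the `into` container (wrapping stray expressions),
+    # merge the children with any existing ones when appending, and store the
+    # rebuilt container under `arg`.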
+ instance = _maybe_copy(instance, copy)
+ parsed = []
+ for expression in expressions:
+ if _is_wrong_expression(expression, into):
+ expression = into(expressions=[expression])
+ expression = maybe_parse(
+ expression,
+ into=into,
+ dialect=dialect,
+ prefix=prefix,
+ **opts,
+ )
+ parsed.extend(expression.expressions)
+
+ existing = instance.args.get(arg)
+ if append and existing:
+ parsed = existing.expressions + parsed
+
+ child = into(expressions=parsed)
+ for k, v in (properties or {}).items():
+ child.set(k, v)
+ instance.set(arg, child)
+ return instance
+
+
+def _apply_list_builder(
+ *expressions,
+ instance,
+ arg,
+ append=True,
+ copy=True,
+ prefix=None,
+ into=None,
+ dialect=None,
+ **opts,
+):
+ inst = _maybe_copy(instance, copy)
+
+ expressions = [
+ maybe_parse(
+ sql_or_expression=expression,
+ into=into,
+ prefix=prefix,
+ dialect=dialect,
+ **opts,
+ )
+ for expression in expressions
+ ]
+
+ existing_expressions = inst.args.get(arg)
+ if append and existing_expressions:
+ expressions = existing_expressions + expressions
+
+ inst.set(arg, expressions)
+ return inst
+
+
+def _apply_conjunction_builder(
+ *expressions,
+ instance,
+ arg,
+ into=None,
+ append=True,
+ copy=True,
+ dialect=None,
+ **opts,
+):
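+    # AND the new expressions together (and the existing value when appending),
+    # optionally wrapping the result in `into` (e.g. Where or Having).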
+ expressions = [exp for exp in expressions if exp is not None and exp != ""]
+ if not expressions:
+ return instance
+
+ inst = _maybe_copy(instance, copy)
+
+ existing = inst.args.get(arg)
+ if append and existing is not None:
+ expressions = [existing.this if into else existing] + list(expressions)
+
+ node = and_(*expressions, dialect=dialect, **opts)
+
+ inst.set(arg, into(this=node) if into else node)
+ return inst
+
+
+def _combine(expressions, operator, dialect=None, **opts):
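+    # Coerce each input into a Condition, then fold them into a left-deep tree
+    # of `operator` nodes, parenthesizing nested connectors to preserve
+    # precedence.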
+ expressions = [
+ condition(expression, dialect=dialect, **opts) for expression in expressions
+ ]
+ this = expressions[0]
+ if expressions[1:]:
+ this = _wrap_operator(this)
+ for expression in expressions[1:]:
+ this = operator(this=this, expression=_wrap_operator(expression))
+ return this
+
+
+def _wrap_operator(expression):
+ if isinstance(expression, (And, Or, Not)):
+ expression = Paren(this=expression)
+ return expression
+
+
+def select(*expressions, dialect=None, **opts):
+ """
+ Initializes a syntax tree from one or multiple SELECT expressions.
+
+ Example:
+ >>> select("col1", "col2").from_("tbl").sql()
+ 'SELECT col1, col2 FROM tbl'
+
+ Args:
+ *expressions (str or Expression): the SQL code string to parse as the expressions of a
+ SELECT statement. If an Expression instance is passed, this is used as-is.
+ dialect (str): the dialect used to parse the input expressions (in the case that an
+ input expression is a SQL string).
+ **opts: other options to use to parse the input expressions (again, in the case
+ that an input expression is a SQL string).
+
+ Returns:
+ Select: the syntax tree for the SELECT statement.
+ """
+ return Select().select(*expressions, dialect=dialect, **opts)
+
+
+def from_(*expressions, dialect=None, **opts):
+ """
+ Initializes a syntax tree from a FROM expression.
+
+ Example:
+ >>> from_("tbl").select("col1", "col2").sql()
+ 'SELECT col1, col2 FROM tbl'
+
+ Args:
+ *expressions (str or Expression): the SQL code string to parse as the FROM expressions of a
+ SELECT statement. If an Expression instance is passed, this is used as-is.
+ dialect (str): the dialect used to parse the input expression (in the case that the
+ input expression is a SQL string).
+ **opts: other options to use to parse the input expressions (again, in the case
+ that the input expression is a SQL string).
+
+ Returns:
+ Select: the syntax tree for the SELECT statement.
+ """
+ return Select().from_(*expressions, dialect=dialect, **opts)
+
+
+def condition(expression, dialect=None, **opts):
+ """
+ Initialize a logical condition expression.
+
+ Example:
+ >>> condition("x=1").sql()
+ 'x = 1'
+
+ This is helpful for composing larger logical syntax trees:
+ >>> where = condition("x=1")
+ >>> where = where.and_("y=1")
+ >>> Select().from_("tbl").select("*").where(where).sql()
+ 'SELECT * FROM tbl WHERE x = 1 AND y = 1'
+
+ Args:
+        expression (str or Expression): the SQL code string to parse.
+ If an Expression instance is passed, this is used as-is.
+ dialect (str): the dialect used to parse the input expression (in the case that the
+ input expression is a SQL string).
+ **opts: other options to use to parse the input expressions (again, in the case
+ that the input expression is a SQL string).
+
+ Returns:
+ Condition: the expression
+ """
+ return maybe_parse(
+ expression,
+ into=Condition,
+ dialect=dialect,
+ **opts,
+ )
+
+
+def and_(*expressions, dialect=None, **opts):
+ """
+ Combine multiple conditions with an AND logical operator.
+
+ Example:
+ >>> and_("x=1", and_("y=1", "z=1")).sql()
+ 'x = 1 AND (y = 1 AND z = 1)'
+
+ Args:
+ *expressions (str or Expression): the SQL code strings to parse.
+ If an Expression instance is passed, this is used as-is.
+ dialect (str): the dialect used to parse the input expression.
+ **opts: other options to use to parse the input expressions.
+
+ Returns:
+ And: the new condition
+ """
+ return _combine(expressions, And, dialect, **opts)
+
+
+def or_(*expressions, dialect=None, **opts):
+ """
+ Combine multiple conditions with an OR logical operator.
+
+ Example:
+ >>> or_("x=1", or_("y=1", "z=1")).sql()
+ 'x = 1 OR (y = 1 OR z = 1)'
+
+ Args:
+ *expressions (str or Expression): the SQL code strings to parse.
+ If an Expression instance is passed, this is used as-is.
+ dialect (str): the dialect used to parse the input expression.
+ **opts: other options to use to parse the input expressions.
+
+ Returns:
+ Or: the new condition
+ """
+ return _combine(expressions, Or, dialect, **opts)
+
+
+def not_(expression, dialect=None, **opts):
+ """
+ Wrap a condition with a NOT operator.
+
+ Example:
+ >>> not_("this_suit='black'").sql()
+ "NOT this_suit = 'black'"
+
+ Args:
+        expression (str or Expression): the SQL code string to parse.
+ If an Expression instance is passed, this is used as-is.
+ dialect (str): the dialect used to parse the input expression.
+ **opts: other options to use to parse the input expressions.
+
+ Returns:
+ Not: the new condition
+ """
+ this = condition(
+ expression,
+ dialect=dialect,
+ **opts,
+ )
+ return Not(this=_wrap_operator(this))
+
+
+def paren(expression):
+ return Paren(this=expression)
+
+
+SAFE_IDENTIFIER_RE = re.compile(r"^[a-zA-Z][\w]*$")
+
+
+def to_identifier(alias, quoted=None):
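+    """
+    Build an Identifier from a string, or pass an existing Identifier through.
+
+    When `quoted` is None, the name is quoted only if it is not a safe, bare
+    identifier (i.e. it does not match SAFE_IDENTIFIER_RE).
+    """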
+ if alias is None:
+ return None
+ if isinstance(alias, Identifier):
+ identifier = alias
+ elif isinstance(alias, str):
+ if quoted is None:
+ quoted = not re.match(SAFE_IDENTIFIER_RE, alias)
+ identifier = Identifier(this=alias, quoted=quoted)
+ else:
+ raise ValueError(
+ f"Alias needs to be a string or an Identifier, got: {alias.__class__}"
+ )
+ return identifier
+
+
+def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts):
+ """
+ Create an Alias expression.
+
+    Example:
+ >>> alias_('foo', 'bar').sql()
+ 'foo AS bar'
+
+ Args:
+ expression (str or Expression): the SQL code strings to parse.
+ If an Expression instance is passed, this is used as-is.
+ alias (str or Identifier): the alias name to use. If the name has
+ special characters it is quoted.
+        table (bool): create a table alias. Defaults to `False`.
+        quoted (bool): whether to quote the alias. By default, names that are
+            not safe identifiers are quoted.
+ dialect (str): the dialect used to parse the input expression.
+ **opts: other options to use to parse the input expressions.
+
+ Returns:
+ Alias: the aliased expression
+ """
+ exp = maybe_parse(expression, dialect=dialect, **opts)
+ alias = to_identifier(alias, quoted=quoted)
+ alias = TableAlias(this=alias) if table else alias
+
+ if "alias" in exp.arg_types:
+ exp = exp.copy()
+ exp.set("alias", alias)
+ return exp
+ return Alias(this=exp, alias=alias)
+
+
+def subquery(expression, alias=None, dialect=None, **opts):
+ """
+ Build a subquery expression.
+
+    Example:
+ >>> subquery('select x from tbl', 'bar').select('x').sql()
+ 'SELECT x FROM (SELECT x FROM tbl) AS bar'
+
+ Args:
+ expression (str or Expression): the SQL code strings to parse.
+ If an Expression instance is passed, this is used as-is.
+ alias (str or Expression): the alias name to use.
+ dialect (str): the dialect used to parse the input expression.
+ **opts: other options to use to parse the input expressions.
+
+ Returns:
+ Select: a new select with the subquery expression included
+ """
+
+ expression = maybe_parse(expression, dialect=dialect, **opts).subquery(alias)
+ return Select().from_(expression, dialect=dialect, **opts)
+
+
+def column(col, table=None, quoted=None):
+ """
+ Build a Column.
+ Args:
+ col (str or Expression): column name
+        table (str or Expression): table name
+        quoted (bool): whether to quote the column and table names
+ Returns:
+ Column: column instance
+ """
+ return Column(
+ this=to_identifier(col, quoted=quoted),
+ table=to_identifier(table, quoted=quoted),
+ )
+
+
+def table_(table, db=None, catalog=None, quoted=None):
+ """
+ Build a Table.
+ Args:
+        table (str or Expression): table name
+        db (str or Expression): db name
+        catalog (str or Expression): catalog name
+        quoted (bool): whether to quote the table, db, and catalog names
+ Returns:
+ Table: table instance
+ """
+ return Table(
+ this=to_identifier(table, quoted=quoted),
+ db=to_identifier(db, quoted=quoted),
+ catalog=to_identifier(catalog, quoted=quoted),
+ )
+
+
+def replace_children(expression, fun):
+ """
+ Replace children of an expression with the result of a lambda fun(child) -> exp.
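+
+    If fun returns a list, the child is replaced by every element of that list,
+    and each new node's parent and arg_key are updated to point at this
+    expression.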
+ """
+ for k, v in expression.args.items():
+ is_list_arg = isinstance(v, list)
+
+ child_nodes = v if is_list_arg else [v]
+ new_child_nodes = []
+
+ for cn in child_nodes:
+ if isinstance(cn, Expression):
+ cns = ensure_list(fun(cn))
+ for child_node in cns:
+ new_child_nodes.append(child_node)
+ child_node.parent = expression
+ child_node.arg_key = k
+ else:
+ new_child_nodes.append(cn)
+
+ expression.args[k] = new_child_nodes if is_list_arg else new_child_nodes[0]
+
+
+def column_table_names(expression):
+ """
+ Return all table names referenced through columns in an expression.
+
+ Example:
+ >>> import sqlglot
+ >>> column_table_names(sqlglot.parse_one("a.b AND c.d AND c.e"))
+ ['c', 'a']
+
+ Args:
+ expression (sqlglot.Expression): expression to find table names
+
+ Returns:
+ list: A list of unique names
+ """
+ return list(dict.fromkeys(column.table for column in expression.find_all(Column)))
+
+
+TRUE = Boolean(this=True)
+FALSE = Boolean(this=False)
+NULL = Null()
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
new file mode 100644
index 0000000..793cff0
--- /dev/null
+++ b/sqlglot/generator.py
@@ -0,0 +1,1124 @@
+import logging
+
+from sqlglot import exp
+from sqlglot.errors import ErrorLevel, UnsupportedError, concat_errors
+from sqlglot.helper import apply_index_offset, csv, ensure_list
+from sqlglot.time import format_time
+from sqlglot.tokens import TokenType
+
+logger = logging.getLogger("sqlglot")
+
+
+class Generator:
+ """
+ Generator interprets the given syntax tree and produces a SQL string as an output.
+
+    Args:
+        time_mapping (dict): the dictionary of custom time mappings in which the key
+            represents a python time format and the value the target time format
+ time_trie (trie): a trie of the time_mapping keys
+ pretty (bool): if set to True the returned string will be formatted. Default: False.
+ quote_start (str): specifies which starting character to use to delimit quotes. Default: '.
+ quote_end (str): specifies which ending character to use to delimit quotes. Default: '.
+ identifier_start (str): specifies which starting character to use to delimit identifiers. Default: ".
+ identifier_end (str): specifies which ending character to use to delimit identifiers. Default: ".
+ identify (bool): if set to True all identifiers will be delimited by the corresponding
+ character.
+        normalize (bool): if set to True all identifiers will be lowercased
+ escape (str): specifies an escape character. Default: '.
+ pad (int): determines padding in a formatted string. Default: 2.
+ indent (int): determines the size of indentation in a formatted string. Default: 4.
+        unnest_column_only (bool): if set to True, aliases on UNNEST are treated as
+            column aliases rather than table aliases
+ normalize_functions (str): normalize function names, "upper", "lower", or None
+ Default: "upper"
+ alias_post_tablesample (bool): if the table alias comes after tablesample
+ Default: False
+ unsupported_level (ErrorLevel): determines the generator's behavior when it encounters
+ unsupported expressions. Default ErrorLevel.WARN.
+ null_ordering (str): Indicates the default null ordering method to use if not explicitly set.
+ Options are "nulls_are_small", "nulls_are_large", "nulls_are_last".
+ Default: "nulls_are_small"
+ max_unsupported (int): Maximum number of unsupported messages to include in a raised UnsupportedError.
+ This is only relevant if unsupported_level is ErrorLevel.RAISE.
+ Default: 3
+ """
+
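+    # Expression types rendered by a callable (or a constant string) here take
+    # precedence over the corresponding *_sql methods; see Generator.sql for
+    # the dispatch order.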
+ TRANSFORMS = {
+ exp.AnonymousProperty: lambda self, e: self.property_sql(e),
+ exp.AutoIncrementProperty: lambda self, e: f"AUTO_INCREMENT={self.sql(e, 'value')}",
+ exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'value')}",
+ exp.CollateProperty: lambda self, e: f"COLLATE={self.sql(e, 'value')}",
+ exp.DateAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})",
+ exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
+ exp.EngineProperty: lambda self, e: f"ENGINE={self.sql(e, 'value')}",
+ exp.FileFormatProperty: lambda self, e: f"FORMAT={self.sql(e, 'value')}",
+ exp.LocationProperty: lambda self, e: f"LOCATION {self.sql(e, 'value')}",
+ exp.PartitionedByProperty: lambda self, e: f"PARTITIONED_BY={self.sql(e.args['value'])}",
+ exp.SchemaCommentProperty: lambda self, e: f"COMMENT={self.sql(e, 'value')}",
+ exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT={self.sql(e, 'value')}",
+ exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})",
+ }
+
+ NULL_ORDERING_SUPPORTED = True
+
+ TYPE_MAPPING = {
+ exp.DataType.Type.NCHAR: "CHAR",
+ exp.DataType.Type.NVARCHAR: "VARCHAR",
+ }
+
+ TOKEN_MAPPING = {}
+
+ STRUCT_DELIMITER = ("<", ">")
+
+ ROOT_PROPERTIES = [
+ exp.AutoIncrementProperty,
+ exp.CharacterSetProperty,
+ exp.CollateProperty,
+ exp.EngineProperty,
+ exp.SchemaCommentProperty,
+ ]
+ WITH_PROPERTIES = [
+ exp.AnonymousProperty,
+ exp.FileFormatProperty,
+ exp.PartitionedByProperty,
+ exp.TableFormatProperty,
+ ]
+
+ __slots__ = (
+ "time_mapping",
+ "time_trie",
+ "pretty",
+ "configured_pretty",
+ "quote_start",
+ "quote_end",
+ "identifier_start",
+ "identifier_end",
+ "identify",
+ "normalize",
+ "escape",
+ "pad",
+ "index_offset",
+ "unnest_column_only",
+ "alias_post_tablesample",
+ "normalize_functions",
+ "unsupported_level",
+ "unsupported_messages",
+ "null_ordering",
+ "max_unsupported",
+ "_indent",
+ "_replace_backslash",
+ "_escaped_quote_end",
+ )
+
+ def __init__(
+ self,
+ time_mapping=None,
+ time_trie=None,
+ pretty=None,
+ quote_start=None,
+ quote_end=None,
+ identifier_start=None,
+ identifier_end=None,
+ identify=False,
+ normalize=False,
+ escape=None,
+ pad=2,
+ indent=2,
+ index_offset=0,
+ unnest_column_only=False,
+ alias_post_tablesample=False,
+ normalize_functions="upper",
+ unsupported_level=ErrorLevel.WARN,
+ null_ordering=None,
+ max_unsupported=3,
+ ):
+ import sqlglot
+
+ self.time_mapping = time_mapping or {}
+ self.time_trie = time_trie
+ self.pretty = pretty if pretty is not None else sqlglot.pretty
+ self.configured_pretty = self.pretty
+ self.quote_start = quote_start or "'"
+ self.quote_end = quote_end or "'"
+ self.identifier_start = identifier_start or '"'
+ self.identifier_end = identifier_end or '"'
+ self.identify = identify
+ self.normalize = normalize
+ self.escape = escape or "'"
+ self.pad = pad
+ self.index_offset = index_offset
+ self.unnest_column_only = unnest_column_only
+ self.alias_post_tablesample = alias_post_tablesample
+ self.normalize_functions = normalize_functions
+ self.unsupported_level = unsupported_level
+ self.unsupported_messages = []
+ self.max_unsupported = max_unsupported
+ self.null_ordering = null_ordering
+ self._indent = indent
+ self._replace_backslash = self.escape == "\\"
+ self._escaped_quote_end = self.escape + self.quote_end
+
+ def generate(self, expression):
+ """
+ Generates a SQL string by interpreting the given syntax tree.
+
+        Args:
+            expression (Expression): the syntax tree.
+
+        Returns:
+            the SQL string.
+ """
+ self.unsupported_messages = []
+ sql = self.sql(expression).strip()
+
+ if self.unsupported_level == ErrorLevel.IGNORE:
+ return sql
+
+ if self.unsupported_level == ErrorLevel.WARN:
+ for msg in self.unsupported_messages:
+ logger.warning(msg)
+ elif self.unsupported_level == ErrorLevel.RAISE and self.unsupported_messages:
+ raise UnsupportedError(
+ concat_errors(self.unsupported_messages, self.max_unsupported)
+ )
+
+ return sql
+
+ def unsupported(self, message):
+ if self.unsupported_level == ErrorLevel.IMMEDIATE:
+ raise UnsupportedError(message)
+ self.unsupported_messages.append(message)
+
+ def sep(self, sep=" "):
+ return f"{sep.strip()}\n" if self.pretty else sep
+
+ def seg(self, sql, sep=" "):
+ return f"{self.sep(sep)}{sql}"
+
+ def wrap(self, expression):
+ this_sql = self.indent(
+ self.sql(expression)
+ if isinstance(expression, (exp.Select, exp.Union))
+ else self.sql(expression, "this"),
+ level=1,
+ pad=0,
+ )
+ return f"({self.sep('')}{this_sql}{self.seg(')', sep='')}"
+
+ def no_identify(self, func):
+ original = self.identify
+ self.identify = False
+ result = func()
+ self.identify = original
+ return result
+
+ def normalize_func(self, name):
+ if self.normalize_functions == "upper":
+ return name.upper()
+ if self.normalize_functions == "lower":
+ return name.lower()
+ return name
+
+ def indent(self, sql, level=0, pad=None, skip_first=False, skip_last=False):
+ if not self.pretty:
+ return sql
+
+ pad = self.pad if pad is None else pad
+ lines = sql.split("\n")
+
+ return "\n".join(
+ line
+ if (skip_first and i == 0) or (skip_last and i == len(lines) - 1)
+ else f"{' ' * (level * self._indent + pad)}{line}"
+ for i, line in enumerate(lines)
+ )
+
+ def sql(self, expression, key=None):
+ if not expression:
+ return ""
+
+ if isinstance(expression, str):
+ return expression
+
+ if key:
+ return self.sql(expression.args.get(key))
+
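+        # Dispatch order: class-level TRANSFORMS first, then a `{key}_sql`
+        # method on this generator, and finally the generic function fallback
+        # for Func nodes.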
+ transform = self.TRANSFORMS.get(expression.__class__)
+
+ if callable(transform):
+ return transform(self, expression)
+ if transform:
+ return transform
+
+ if not isinstance(expression, exp.Expression):
+ raise ValueError(
+ f"Expected an Expression. Received {type(expression)}: {expression}"
+ )
+
+ exp_handler_name = f"{expression.key}_sql"
+ if hasattr(self, exp_handler_name):
+ return getattr(self, exp_handler_name)(expression)
+
+ if isinstance(expression, exp.Func):
+ return self.function_fallback_sql(expression)
+
+ raise ValueError(f"Unsupported expression type {expression.__class__.__name__}")
+
+ def annotation_sql(self, expression):
+ return self.sql(expression, "expression")
+
+ def uncache_sql(self, expression):
+ table = self.sql(expression, "this")
+ exists_sql = " IF EXISTS" if expression.args.get("exists") else ""
+ return f"UNCACHE TABLE{exists_sql} {table}"
+
+ def cache_sql(self, expression):
+ lazy = " LAZY" if expression.args.get("lazy") else ""
+ table = self.sql(expression, "this")
+ options = expression.args.get("options")
+ options = (
+ f" OPTIONS({self.sql(options[0])} = {self.sql(options[1])})"
+ if options
+ else ""
+ )
+ sql = self.sql(expression, "expression")
+ sql = f" AS{self.sep()}{sql}" if sql else ""
+ sql = f"CACHE{lazy} TABLE {table}{options}{sql}"
+ return self.prepend_ctes(expression, sql)
+
+ def characterset_sql(self, expression):
+ if isinstance(expression.parent, exp.Cast):
+ return f"CHAR CHARACTER SET {self.sql(expression, 'this')}"
+ default = "DEFAULT " if expression.args.get("default") else ""
+ return f"{default}CHARACTER SET={self.sql(expression, 'this')}"
+
+ def column_sql(self, expression):
+ return ".".join(
+ part
+ for part in [
+ self.sql(expression, "db"),
+ self.sql(expression, "table"),
+ self.sql(expression, "this"),
+ ]
+ if part
+ )
+
+ def columndef_sql(self, expression):
+ column = self.sql(expression, "this")
+ kind = self.sql(expression, "kind")
+ constraints = self.expressions(
+ expression, key="constraints", sep=" ", flat=True
+ )
+
+ if not constraints:
+ return f"{column} {kind}"
+ return f"{column} {kind} {constraints}"
+
+ def columnconstraint_sql(self, expression):
+ this = self.sql(expression, "this")
+ kind_sql = self.sql(expression, "kind")
+ return f"CONSTRAINT {this} {kind_sql}" if this else kind_sql
+
+ def autoincrementcolumnconstraint_sql(self, _):
+ return self.token_sql(TokenType.AUTO_INCREMENT)
+
+ def checkcolumnconstraint_sql(self, expression):
+ this = self.sql(expression, "this")
+ return f"CHECK ({this})"
+
+ def commentcolumnconstraint_sql(self, expression):
+ comment = self.sql(expression, "this")
+ return f"COMMENT {comment}"
+
+ def collatecolumnconstraint_sql(self, expression):
+ collate = self.sql(expression, "this")
+ return f"COLLATE {collate}"
+
+ def defaultcolumnconstraint_sql(self, expression):
+ default = self.sql(expression, "this")
+ return f"DEFAULT {default}"
+
+ def notnullcolumnconstraint_sql(self, _):
+ return "NOT NULL"
+
+ def primarykeycolumnconstraint_sql(self, _):
+ return "PRIMARY KEY"
+
+ def uniquecolumnconstraint_sql(self, _):
+ return "UNIQUE"
+
+ def create_sql(self, expression):
+ this = self.sql(expression, "this")
+ kind = self.sql(expression, "kind").upper()
+ expression_sql = self.sql(expression, "expression")
+ expression_sql = f"AS{self.sep()}{expression_sql}" if expression_sql else ""
+ temporary = " TEMPORARY" if expression.args.get("temporary") else ""
+ replace = " OR REPLACE" if expression.args.get("replace") else ""
+ exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else ""
+ unique = " UNIQUE" if expression.args.get("unique") else ""
+ properties = self.sql(expression, "properties")
+
+ expression_sql = f"CREATE{replace}{temporary}{unique} {kind}{exists_sql} {this}{properties} {expression_sql}"
+ return self.prepend_ctes(expression, expression_sql)
+
+ def prepend_ctes(self, expression, sql):
+ with_ = self.sql(expression, "with")
+ if with_:
+ sql = f"{with_}{self.sep()}{sql}"
+ return sql
+
+ def with_sql(self, expression):
+ sql = self.expressions(expression, flat=True)
+ recursive = "RECURSIVE " if expression.args.get("recursive") else ""
+
+ return f"WITH {recursive}{sql}"
+
+ def cte_sql(self, expression):
+ alias = self.sql(expression, "alias")
+ return f"{alias} AS {self.wrap(expression)}"
+
+ def tablealias_sql(self, expression):
+ alias = self.sql(expression, "this")
+ columns = self.expressions(expression, key="columns", flat=True)
+ columns = f"({columns})" if columns else ""
+ return f"{alias}{columns}"
+
+ def bitstring_sql(self, expression):
+ return f"b'{self.sql(expression, 'this')}'"
+
+ def datatype_sql(self, expression):
+ type_value = expression.this
+ type_sql = self.TYPE_MAPPING.get(type_value, type_value.value)
+ nested = ""
+ interior = self.expressions(expression, flat=True)
+ if interior:
+ nested = (
+ f"{self.STRUCT_DELIMITER[0]}{interior}{self.STRUCT_DELIMITER[1]}"
+ if expression.args.get("nested")
+ else f"({interior})"
+ )
+ return f"{type_sql}{nested}"
+
+ def delete_sql(self, expression):
+ this = self.sql(expression, "this")
+ where_sql = self.sql(expression, "where")
+ sql = f"DELETE FROM {this}{where_sql}"
+ return self.prepend_ctes(expression, sql)
+
+ def drop_sql(self, expression):
+ this = self.sql(expression, "this")
+ kind = expression.args["kind"]
+ exists_sql = " IF EXISTS " if expression.args.get("exists") else " "
+ return f"DROP {kind}{exists_sql}{this}"
+
+ def except_sql(self, expression):
+ return self.prepend_ctes(
+ expression,
+ self.set_operation(expression, self.except_op(expression)),
+ )
+
+ def except_op(self, expression):
+ return f"EXCEPT{'' if expression.args.get('distinct') else ' ALL'}"
+
+ def fetch_sql(self, expression):
+ direction = expression.args.get("direction")
+ direction = f" {direction.upper()}" if direction else ""
+ count = expression.args.get("count")
+ count = f" {count}" if count else ""
+ return f"{self.seg('FETCH')}{direction}{count} ROWS ONLY"
+
+ def filter_sql(self, expression):
+ this = self.sql(expression, "this")
+ where = self.sql(expression, "expression")[1:] # where has a leading space
+ return f"{this} FILTER({where})"
+
+ def hint_sql(self, expression):
+ if self.sql(expression, "this"):
+ self.unsupported("Hints are not supported")
+ return ""
+
+ def index_sql(self, expression):
+ this = self.sql(expression, "this")
+ table = self.sql(expression, "table")
+ columns = self.sql(expression, "columns")
+ return f"{this} ON {table} {columns}"
+
+ def identifier_sql(self, expression):
+ value = expression.name
+ value = value.lower() if self.normalize else value
+ if expression.args.get("quoted") or self.identify:
+ return f"{self.identifier_start}{value}{self.identifier_end}"
+ return value
+
+ def partition_sql(self, expression):
+ keys = csv(
+ *[
+ f"{k.args['this']}='{v.args['this']}'" if v else k.args["this"]
+ for k, v in expression.args.get("this")
+ ]
+ )
+ return f"PARTITION({keys})"
+
+ def properties_sql(self, expression):
+ root_properties = []
+ with_properties = []
+
+ for p in expression.expressions:
+ p_class = p.__class__
+ if p_class in self.ROOT_PROPERTIES:
+ root_properties.append(p)
+ elif p_class in self.WITH_PROPERTIES:
+ with_properties.append(p)
+
+ return self.root_properties(
+ exp.Properties(expressions=root_properties)
+ ) + self.with_properties(exp.Properties(expressions=with_properties))
+
+ def root_properties(self, properties):
+ if properties.expressions:
+ return self.sep() + self.expressions(
+ properties,
+ indent=False,
+ sep=" ",
+ )
+ return ""
+
+ def properties(self, properties, prefix="", sep=", "):
+ if properties.expressions:
+ expressions = self.expressions(
+ properties,
+ sep=sep,
+ indent=False,
+ )
+ return f"{self.seg(prefix)}{' ' if prefix else ''}{self.wrap(expressions)}"
+ return ""
+
+ def with_properties(self, properties):
+ return self.properties(
+ properties,
+ prefix="WITH",
+ )
+
+ def property_sql(self, expression):
+ key = expression.name
+ value = self.sql(expression, "value")
+ return f"{key} = {value}"
+
+ def insert_sql(self, expression):
+ kind = "OVERWRITE TABLE" if expression.args.get("overwrite") else "INTO"
+ this = self.sql(expression, "this")
+ exists = " IF EXISTS " if expression.args.get("exists") else " "
+ partition_sql = (
+ self.sql(expression, "partition")
+ if expression.args.get("partition")
+ else ""
+ )
+ expression_sql = self.sql(expression, "expression")
+ sep = self.sep() if partition_sql else ""
+ sql = f"INSERT {kind} {this}{exists}{partition_sql}{sep}{expression_sql}"
+ return self.prepend_ctes(expression, sql)
+
+ def intersect_sql(self, expression):
+ return self.prepend_ctes(
+ expression,
+ self.set_operation(expression, self.intersect_op(expression)),
+ )
+
+ def intersect_op(self, expression):
+ return f"INTERSECT{'' if expression.args.get('distinct') else ' ALL'}"
+
+ def introducer_sql(self, expression):
+ return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}"
+
+ def table_sql(self, expression):
+ return ".".join(
+ part
+ for part in [
+ self.sql(expression, "catalog"),
+ self.sql(expression, "db"),
+ self.sql(expression, "this"),
+ ]
+ if part
+ )
+
+ def tablesample_sql(self, expression):
+ if self.alias_post_tablesample and isinstance(expression.this, exp.Alias):
+ this = self.sql(expression.this, "this")
+ alias = f" AS {self.sql(expression.this, 'alias')}"
+ else:
+ this = self.sql(expression, "this")
+ alias = ""
+ method = self.sql(expression, "method")
+ method = f" {method.upper()} " if method else ""
+ numerator = self.sql(expression, "bucket_numerator")
+ denominator = self.sql(expression, "bucket_denominator")
+ field = self.sql(expression, "bucket_field")
+ field = f" ON {field}" if field else ""
+ bucket = f"BUCKET {numerator} OUT OF {denominator}{field}" if numerator else ""
+ percent = self.sql(expression, "percent")
+ percent = f"{percent} PERCENT" if percent else ""
+ rows = self.sql(expression, "rows")
+ rows = f"{rows} ROWS" if rows else ""
+ size = self.sql(expression, "size")
+ return f"{this} TABLESAMPLE{method}({bucket}{percent}{rows}{size}){alias}"
+
+ def tuple_sql(self, expression):
+ return f"({self.expressions(expression, flat=True)})"
+
+ def update_sql(self, expression):
+ this = self.sql(expression, "this")
+ set_sql = self.expressions(expression, flat=True)
+ from_sql = self.sql(expression, "from")
+ where_sql = self.sql(expression, "where")
+ sql = f"UPDATE {this} SET {set_sql}{from_sql}{where_sql}"
+ return self.prepend_ctes(expression, sql)
+
+ def values_sql(self, expression):
+ return f"VALUES{self.seg('')}{self.expressions(expression)}"
+
+ def var_sql(self, expression):
+ return self.sql(expression, "this")
+
+ def from_sql(self, expression):
+ expressions = self.expressions(expression, flat=True)
+ return f"{self.seg('FROM')} {expressions}"
+
+ def group_sql(self, expression):
+ group_by = self.op_expressions("GROUP BY", expression)
+ grouping_sets = self.expressions(expression, key="grouping_sets", indent=False)
+ grouping_sets = (
+ f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}"
+ if grouping_sets
+ else ""
+ )
+ cube = self.expressions(expression, key="cube", indent=False)
+ cube = f"{self.seg('CUBE')} {self.wrap(cube)}" if cube else ""
+ rollup = self.expressions(expression, key="rollup", indent=False)
+ rollup = f"{self.seg('ROLLUP')} {self.wrap(rollup)}" if rollup else ""
+ return f"{group_by}{grouping_sets}{cube}{rollup}"
+
+ def having_sql(self, expression):
+ this = self.indent(self.sql(expression, "this"))
+ return f"{self.seg('HAVING')}{self.sep()}{this}"
+
+ def join_sql(self, expression):
+ op_sql = self.seg(
+ " ".join(op for op in (expression.side, expression.kind, "JOIN") if op)
+ )
+ on_sql = self.sql(expression, "on")
+ using = expression.args.get("using")
+
+ if not on_sql and using:
+ on_sql = csv(*(self.sql(column) for column in using))
+
+ if on_sql:
+ on_sql = self.indent(on_sql, skip_first=True)
+ space = self.seg(" " * self.pad) if self.pretty else " "
+ if using:
+ on_sql = f"{space}USING ({on_sql})"
+ else:
+ on_sql = f"{space}ON {on_sql}"
+
+ expression_sql = self.sql(expression, "expression")
+ this_sql = self.sql(expression, "this")
+ return f"{expression_sql}{op_sql} {this_sql}{on_sql}"
+
+ def lambda_sql(self, expression):
+ args = self.expressions(expression, flat=True)
+ args = f"({args})" if len(args.split(",")) > 1 else args
+ return self.no_identify(lambda: f"{args} -> {self.sql(expression, 'this')}")
+
+ def lateral_sql(self, expression):
+ this = self.sql(expression, "this")
+ op_sql = self.seg(
+ f"LATERAL VIEW{' OUTER' if expression.args.get('outer') else ''}"
+ )
+ alias = expression.args["alias"]
+ table = alias.name
+ table = f" {table}" if table else table
+ columns = self.expressions(alias, key="columns", flat=True)
+ columns = f" AS {columns}" if columns else ""
+ return f"{op_sql}{self.sep()}{this}{table}{columns}"
+
+ def limit_sql(self, expression):
+ this = self.sql(expression, "this")
+ return f"{this}{self.seg('LIMIT')} {self.sql(expression, 'expression')}"
+
+ def offset_sql(self, expression):
+ this = self.sql(expression, "this")
+ return f"{this}{self.seg('OFFSET')} {self.sql(expression, 'expression')}"
+
+ def literal_sql(self, expression):
+ text = expression.this or ""
+ if expression.is_string:
+ if self._replace_backslash:
+ text = text.replace("\\", "\\\\")
+ text = text.replace(self.quote_end, self._escaped_quote_end)
+ return f"{self.quote_start}{text}{self.quote_end}"
+ return text
+
+ def null_sql(self, *_):
+ return "NULL"
+
+ def boolean_sql(self, expression):
+ return "TRUE" if expression.this else "FALSE"
+
+ def order_sql(self, expression, flat=False):
+ this = self.sql(expression, "this")
+ this = f"{this} " if this else this
+ return self.op_expressions(f"{this}ORDER BY", expression, flat=this or flat)
+
+ def cluster_sql(self, expression):
+ return self.op_expressions("CLUSTER BY", expression)
+
+ def distribute_sql(self, expression):
+ return self.op_expressions("DISTRIBUTE BY", expression)
+
+ def sort_sql(self, expression):
+ return self.op_expressions("SORT BY", expression)
+
+ def ordered_sql(self, expression):
+ desc = expression.args.get("desc")
+ asc = not desc
+ nulls_first = expression.args.get("nulls_first")
+ nulls_last = not nulls_first
+ nulls_are_large = self.null_ordering == "nulls_are_large"
+ nulls_are_small = self.null_ordering == "nulls_are_small"
+ nulls_are_last = self.null_ordering == "nulls_are_last"
+
+ sort_order = " DESC" if desc else ""
+ nulls_sort_change = ""
+ if nulls_first and (
+ (asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last
+ ):
+ nulls_sort_change = " NULLS FIRST"
+ elif (
+ nulls_last
+ and ((asc and nulls_are_small) or (desc and nulls_are_large))
+ and not nulls_are_last
+ ):
+ nulls_sort_change = " NULLS LAST"
+
+ if nulls_sort_change and not self.NULL_ORDERING_SUPPORTED:
+ self.unsupported(
+ "Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect"
+ )
+ nulls_sort_change = ""
+
+ return f"{self.sql(expression, 'this')}{sort_order}{nulls_sort_change}"
+
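The branching in ordered_sql drops null-ordering clauses that are redundant for the target dialect. Assuming the default null_ordering of "nulls_are_small" (an ascending sort already places NULLs first), a sketch:

>>> import sqlglot
>>> sqlglot.parse_one("SELECT x FROM t ORDER BY x NULLS FIRST").sql()
'SELECT x FROM t ORDER BY x'
>>> sqlglot.parse_one("SELECT x FROM t ORDER BY x DESC NULLS FIRST").sql()
'SELECT x FROM t ORDER BY x DESC NULLS FIRST'
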
+ def query_modifiers(self, expression, *sqls):
+ return csv(
+ *sqls,
+ *[self.sql(sql) for sql in expression.args.get("laterals", [])],
+ *[self.sql(sql) for sql in expression.args.get("joins", [])],
+ self.sql(expression, "where"),
+ self.sql(expression, "group"),
+ self.sql(expression, "having"),
+ self.sql(expression, "qualify"),
+ self.sql(expression, "window"),
+ self.sql(expression, "distribute"),
+ self.sql(expression, "sort"),
+ self.sql(expression, "cluster"),
+ self.sql(expression, "order"),
+ self.sql(expression, "limit"),
+ self.sql(expression, "offset"),
+ sep="",
+ )
+
+ def select_sql(self, expression):
+ hint = self.sql(expression, "hint")
+ distinct = self.sql(expression, "distinct")
+ distinct = f" {distinct}" if distinct else ""
+ expressions = self.expressions(expression)
+ expressions = f"{self.sep()}{expressions}" if expressions else expressions
+ sql = self.query_modifiers(
+ expression,
+ f"SELECT{hint}{distinct}{expressions}",
+ self.sql(expression, "from"),
+ )
+ return self.prepend_ctes(expression, sql)
+
+ def schema_sql(self, expression):
+ this = self.sql(expression, "this")
+ this = f"{this} " if this else ""
+ sql = f"({self.sep('')}{self.expressions(expression)}{self.seg(')', sep='')}"
+ return f"{this}{sql}"
+
+ def star_sql(self, expression):
+ except_ = self.expressions(expression, key="except", flat=True)
+ except_ = f"{self.seg('EXCEPT')} ({except_})" if except_ else ""
+ replace = self.expressions(expression, key="replace", flat=True)
+ replace = f"{self.seg('REPLACE')} ({replace})" if replace else ""
+ return f"*{except_}{replace}"
+
+ def structkwarg_sql(self, expression):
+ return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}"
+
+ def placeholder_sql(self, *_):
+ return "?"
+
+ def subquery_sql(self, expression):
+ alias = self.sql(expression, "alias")
+
+ return self.query_modifiers(
+ expression,
+ self.wrap(expression),
+ f" AS {alias}" if alias else "",
+ )
+
+ def qualify_sql(self, expression):
+ this = self.indent(self.sql(expression, "this"))
+ return f"{self.seg('QUALIFY')}{self.sep()}{this}"
+
+ def union_sql(self, expression):
+ return self.prepend_ctes(
+ expression,
+ self.set_operation(expression, self.union_op(expression)),
+ )
+
+ def union_op(self, expression):
+ return f"UNION{'' if expression.args.get('distinct') else ' ALL'}"
+
+ def unnest_sql(self, expression):
+ args = self.expressions(expression, flat=True)
+ alias = expression.args.get("alias")
+ if alias and self.unnest_column_only:
+ columns = alias.columns
+ alias = self.sql(columns[0]) if columns else ""
+ else:
+ alias = self.sql(expression, "alias")
+ alias = f" AS {alias}" if alias else alias
+ ordinality = " WITH ORDINALITY" if expression.args.get("ordinality") else ""
+ return f"UNNEST({args}){ordinality}{alias}"
+
+ def where_sql(self, expression):
+ this = self.indent(self.sql(expression, "this"))
+ return f"{self.seg('WHERE')}{self.sep()}{this}"
+
+ def window_sql(self, expression):
+ this = self.sql(expression, "this")
+ partition = self.expressions(expression, key="partition_by", flat=True)
+ partition = f"PARTITION BY {partition}" if partition else ""
+ order = expression.args.get("order")
+ order_sql = self.order_sql(order, flat=True) if order else ""
+ partition_sql = partition + " " if partition and order else partition
+ spec = expression.args.get("spec")
+ spec_sql = " " + self.window_spec_sql(spec) if spec else ""
+ alias = self.sql(expression, "alias")
+ if expression.arg_key == "window":
+ this = f"{self.seg('WINDOW')} {this} AS"
+ else:
+ this = f"{this} OVER"
+
+ if not partition and not order and not spec and alias:
+ return f"{this} {alias}"
+
+ return f"{this} ({alias}{partition_sql}{order_sql}{spec_sql})"
+
+ def window_spec_sql(self, expression):
+ kind = self.sql(expression, "kind")
+ start = csv(
+ self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" "
+ )
+ end = (
+ csv(self.sql(expression, "end"), self.sql(expression, "end_side"), sep=" ")
+ or "CURRENT ROW"
+ )
+ return f"{kind} BETWEEN {start} AND {end}"
+
+ def withingroup_sql(self, expression):
+ this = self.sql(expression, "this")
+ expression = self.sql(expression, "expression")[1:] # order has a leading space
+ return f"{this} WITHIN GROUP ({expression})"
+
+ def between_sql(self, expression):
+ this = self.sql(expression, "this")
+ low = self.sql(expression, "low")
+ high = self.sql(expression, "high")
+ return f"{this} BETWEEN {low} AND {high}"
+
+ def bracket_sql(self, expression):
+ expressions = apply_index_offset(expression.expressions, self.index_offset)
+ expressions = ", ".join(self.sql(e) for e in expressions)
+
+ return f"{self.sql(expression, 'this')}[{expressions}]"
+
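bracket_sql leans on apply_index_offset (defined in sqlglot/helper.py below) to shift literal array indices for 1-based dialects. A transpilation sketch, assuming Presto's index offset of 1 in this version:

>>> import sqlglot
>>> sqlglot.transpile("SELECT x[0] FROM t", write="presto")
['SELECT x[1] FROM t']
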
+ def all_sql(self, expression):
+ return f"ALL {self.wrap(expression)}"
+
+ def any_sql(self, expression):
+ return f"ANY {self.wrap(expression)}"
+
+ def exists_sql(self, expression):
+ return f"EXISTS{self.wrap(expression)}"
+
+ def case_sql(self, expression):
+ this = self.indent(self.sql(expression, "this"), skip_first=True)
+ this = f" {this}" if this else ""
+ ifs = []
+
+ for e in expression.args["ifs"]:
+ ifs.append(self.indent(f"WHEN {self.sql(e, 'this')}"))
+ ifs.append(self.indent(f"THEN {self.sql(e, 'true')}"))
+
+ if expression.args.get("default") is not None:
+ ifs.append(self.indent(f"ELSE {self.sql(expression, 'default')}"))
+
+ ifs = "".join(self.seg(self.indent(e, skip_first=True)) for e in ifs)
+ statement = f"CASE{this}{ifs}{self.seg('END')}"
+ return statement
+
+ def constraint_sql(self, expression):
+ this = self.sql(expression, "this")
+ expressions = self.expressions(expression, flat=True)
+ return f"CONSTRAINT {this} {expressions}"
+
+ def extract_sql(self, expression):
+ this = self.sql(expression, "this")
+ expression_sql = self.sql(expression, "expression")
+ return f"EXTRACT({this} FROM {expression_sql})"
+
+ def check_sql(self, expression):
+ this = self.sql(expression, key="this")
+ return f"CHECK ({this})"
+
+ def foreignkey_sql(self, expression):
+ expressions = self.expressions(expression, flat=True)
+ reference = self.sql(expression, "reference")
+ reference = f" {reference}" if reference else ""
+ delete = self.sql(expression, "delete")
+ delete = f" ON DELETE {delete}" if delete else ""
+ update = self.sql(expression, "update")
+ update = f" ON UPDATE {update}" if update else ""
+ return f"FOREIGN KEY ({expressions}){reference}{delete}{update}"
+
+ def unique_sql(self, expression):
+ columns = self.expressions(expression, key="expressions")
+ return f"UNIQUE ({columns})"
+
+ def if_sql(self, expression):
+ return self.case_sql(
+ exp.Case(ifs=[expression], default=expression.args.get("false"))
+ )
+
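Because if_sql delegates to case_sql, IF expressions come out as searched CASEs in the base generator. A sketch:

>>> import sqlglot
>>> sqlglot.parse_one("SELECT IF(a, 1, 2)").sql()
'SELECT CASE WHEN a THEN 1 ELSE 2 END'
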
+ def in_sql(self, expression):
+ query = expression.args.get("query")
+ unnest = expression.args.get("unnest")
+ if query:
+ in_sql = self.wrap(query)
+ elif unnest:
+ in_sql = self.in_unnest_op(unnest)
+ else:
+ in_sql = f"({self.expressions(expression, flat=True)})"
+ return f"{self.sql(expression, 'this')} IN {in_sql}"
+
+ def in_unnest_op(self, unnest):
+ return f"(SELECT {self.sql(unnest)})"
+
+ def interval_sql(self, expression):
+ return f"INTERVAL {self.sql(expression, 'this')} {self.sql(expression, 'unit')}"
+
+ def reference_sql(self, expression):
+ this = self.sql(expression, "this")
+ expressions = self.expressions(expression, flat=True)
+ return f"REFERENCES {this}({expressions})"
+
+ def anonymous_sql(self, expression):
+ args = self.indent(
+ self.expressions(expression, flat=True), skip_first=True, skip_last=True
+ )
+ return f"{self.normalize_func(self.sql(expression, 'this'))}({args})"
+
+ def paren_sql(self, expression):
+ if isinstance(expression.unnest(), exp.Select):
+ return self.wrap(expression)
+ sql = self.seg(self.indent(self.sql(expression, "this")), sep="")
+ return f"({sql}{self.seg(')', sep='')}"
+
+ def neg_sql(self, expression):
+ return f"-{self.sql(expression, 'this')}"
+
+ def not_sql(self, expression):
+ return f"NOT {self.sql(expression, 'this')}"
+
+ def alias_sql(self, expression):
+ to_sql = self.sql(expression, "alias")
+ to_sql = f" AS {to_sql}" if to_sql else ""
+ return f"{self.sql(expression, 'this')}{to_sql}"
+
+ def aliases_sql(self, expression):
+ return f"{self.sql(expression, 'this')} AS ({self.expressions(expression, flat=True)})"
+
+ def attimezone_sql(self, expression):
+ this = self.sql(expression, "this")
+ zone = self.sql(expression, "zone")
+ return f"{this} AT TIME ZONE {zone}"
+
+ def add_sql(self, expression):
+ return self.binary(expression, "+")
+
+ def and_sql(self, expression):
+ return self.connector_sql(expression, "AND")
+
+ def connector_sql(self, expression, op):
+ if not self.pretty:
+ return self.binary(expression, op)
+
+ return f"\n{op} ".join(self.sql(e) for e in expression.flatten(unnest=False))
+
+ def bitwiseand_sql(self, expression):
+ return self.binary(expression, "&")
+
+ def bitwiseleftshift_sql(self, expression):
+ return self.binary(expression, "<<")
+
+ def bitwisenot_sql(self, expression):
+ return f"~{self.sql(expression, 'this')}"
+
+ def bitwiseor_sql(self, expression):
+ return self.binary(expression, "|")
+
+ def bitwiserightshift_sql(self, expression):
+ return self.binary(expression, ">>")
+
+ def bitwisexor_sql(self, expression):
+ return self.binary(expression, "^")
+
+ def cast_sql(self, expression):
+ return f"CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})"
+
+ def currentdate_sql(self, expression):
+ zone = self.sql(expression, "this")
+ return f"CURRENT_DATE({zone})" if zone else "CURRENT_DATE"
+
+ def command_sql(self, expression):
+ return f"{self.sql(expression, 'this').upper()} {expression.text('expression').strip()}"
+
+ def distinct_sql(self, expression):
+ this = self.sql(expression, "this")
+ this = f" {this}" if this else ""
+
+ on = self.sql(expression, "on")
+ on = f" ON {on}" if on else ""
+ return f"DISTINCT{this}{on}"
+
+ def ignorenulls_sql(self, expression):
+ return f"{self.sql(expression, 'this')} IGNORE NULLS"
+
+ def intdiv_sql(self, expression):
+ return self.sql(
+ exp.Cast(
+ this=exp.Div(
+ this=expression.args["this"],
+ expression=expression.args["expression"],
+ ),
+ to=exp.DataType(this=exp.DataType.Type.INT),
+ )
+ )
+
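intdiv_sql lowers integer division to a CAST around an ordinary division; constructing the node directly shows the rewrite (a sketch):

>>> from sqlglot import exp
>>> exp.IntDiv(this=exp.column("a"), expression=exp.column("b")).sql()
'CAST(a / b AS INT)'
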
+ def dpipe_sql(self, expression):
+ return self.binary(expression, "||")
+
+ def div_sql(self, expression):
+ return self.binary(expression, "/")
+
+ def dot_sql(self, expression):
+ return f"{self.sql(expression, 'this')}.{self.sql(expression, 'expression')}"
+
+ def eq_sql(self, expression):
+ return self.binary(expression, "=")
+
+ def escape_sql(self, expression):
+ return self.binary(expression, "ESCAPE")
+
+ def gt_sql(self, expression):
+ return self.binary(expression, ">")
+
+ def gte_sql(self, expression):
+ return self.binary(expression, ">=")
+
+ def ilike_sql(self, expression):
+ return self.binary(expression, "ILIKE")
+
+ def is_sql(self, expression):
+ return self.binary(expression, "IS")
+
+ def like_sql(self, expression):
+ return self.binary(expression, "LIKE")
+
+ def lt_sql(self, expression):
+ return self.binary(expression, "<")
+
+ def lte_sql(self, expression):
+ return self.binary(expression, "<=")
+
+ def mod_sql(self, expression):
+ return self.binary(expression, "%")
+
+ def mul_sql(self, expression):
+ return self.binary(expression, "*")
+
+ def neq_sql(self, expression):
+ return self.binary(expression, "<>")
+
+ def or_sql(self, expression):
+ return self.connector_sql(expression, "OR")
+
+ def sub_sql(self, expression):
+ return self.binary(expression, "-")
+
+ def trycast_sql(self, expression):
+ return (
+ f"TRY_CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})"
+ )
+
+ def binary(self, expression, op):
+ return (
+ f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}"
+ )
+
+ def function_fallback_sql(self, expression):
+ args = []
+ for arg_key in expression.arg_types:
+ arg_value = ensure_list(expression.args.get(arg_key) or [])
+ for a in arg_value:
+ args.append(self.sql(a))
+
+ args_str = self.indent(", ".join(args), skip_first=True, skip_last=True)
+ return f"{self.normalize_func(expression.sql_name())}({args_str})"
+
+ def format_time(self, expression):
+ return format_time(
+ self.sql(expression, "format"), self.time_mapping, self.time_trie
+ )
+
+ def expressions(self, expression, key=None, flat=False, indent=True, sep=", "):
+ expressions = expression.args.get(key or "expressions")
+
+ if not expressions:
+ return ""
+
+ if flat:
+ return sep.join(self.sql(e) for e in expressions)
+
+ expressions = self.sep(sep).join(self.sql(e) for e in expressions)
+ if indent:
+ return self.indent(expressions, skip_first=False)
+ return expressions
+
+ def op_expressions(self, op, expression, flat=False):
+ expressions_sql = self.expressions(expression, flat=flat)
+ if flat:
+ return f"{op} {expressions_sql}"
+ return f"{self.seg(op)}{self.sep() if expressions_sql else ''}{expressions_sql}"
+
+ def set_operation(self, expression, op):
+ this = self.sql(expression, "this")
+ op = self.seg(op)
+ return self.query_modifiers(
+ expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}"
+ )
+
+ def token_sql(self, token_type):
+ return self.TOKEN_MAPPING.get(token_type, token_type.name)
diff --git a/sqlglot/helper.py b/sqlglot/helper.py
new file mode 100644
index 0000000..5d90c49
--- /dev/null
+++ b/sqlglot/helper.py
@@ -0,0 +1,123 @@
+import logging
+import re
+from contextlib import contextmanager
+from enum import Enum
+
+CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])")
+logger = logging.getLogger("sqlglot")
+
+
+class AutoName(Enum):
+ def _generate_next_value_(name, _start, _count, _last_values):
+ return name
+
+
+def list_get(arr, index):
+ try:
+ return arr[index]
+ except IndexError:
+ return None
+
+
+def ensure_list(value):
+ if value is None:
+ return []
+ return value if isinstance(value, (list, tuple, set)) else [value]
+
+
+def csv(*args, sep=", "):
+ return sep.join(arg for arg in args if arg)
+
+
+def apply_index_offset(expressions, offset):
+ if not offset or len(expressions) != 1:
+ return expressions
+
+ expression = expressions[0]
+
+ if expression.is_int:
+ expression = expression.copy()
+ logger.warning("Applying array index offset (%s)", offset)
+ expression.args["this"] = str(int(expression.args["this"]) + offset)
+ return [expression]
+ return expressions
+
+
+def camel_to_snake_case(name):
+ return CAMEL_CASE_PATTERN.sub("_", name).upper()
+
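A few doctest-style sketches of the small helpers above:

>>> from sqlglot.helper import ensure_list, csv, camel_to_snake_case
>>> ensure_list("x")
['x']
>>> csv("a", "", "b")
'a, b'
>>> camel_to_snake_case("GroupBy")
'GROUP_BY'
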
+
+def while_changing(expression, func):
+ while True:
+ start = hash(expression)
+ expression = func(expression)
+ if start == hash(expression):
+ break
+ return expression
+
+
+def tsort(dag):
+ result = []
+
+ def visit(node, visited):
+ if node in result:
+ return
+ if node in visited:
+ raise ValueError("Cycle error")
+
+ visited.add(node)
+
+ for dep in dag.get(node, []):
+ visit(dep, visited)
+
+ visited.remove(node)
+ result.append(node)
+
+ for node in dag:
+ visit(node, set())
+
+ return result
+
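tsort emits each node only after all of its dependencies, and a cycle raises ValueError. A minimal sketch:

>>> from sqlglot.helper import tsort
>>> tsort({"a": ["b"], "b": ["c"], "c": []})
['c', 'b', 'a']
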
+
+def open_file(file_name):
+ """
+ Open a file that may be gzip-compressed and return it in newline mode.
+ """
+ with open(file_name, "rb") as f:
+ gzipped = f.read(2) == b"\x1f\x8b"
+
+ if gzipped:
+ import gzip
+
+ return gzip.open(file_name, "rt", newline="")
+
+ return open(file_name, "rt", encoding="utf-8", newline="")
+
+
+@contextmanager
+def csv_reader(table):
+ """
+ Returns a csv reader given a table expression of the form READ_CSV(name, ['delimiter', '|', ...]).
+
+ Args:
+ table (sqlglot.expressions.Table): A table whose `this` is a READ_CSV function call
+
+ Returns:
+ A python csv reader.
+ """
+ file, *args = table.this.expressions
+ file = file.name
+ file = open_file(file)
+
+ delimiter = ","
+ args = iter(arg.name for arg in args)
+ for k, v in zip(args, args):
+ if k == "delimiter":
+ delimiter = v
+
+ try:
+ import csv as csv_
+
+ yield csv_.reader(file, delimiter=delimiter)
+ finally:
+ file.close()
diff --git a/sqlglot/optimizer/__init__.py b/sqlglot/optimizer/__init__.py
new file mode 100644
index 0000000..a4c4cc2
--- /dev/null
+++ b/sqlglot/optimizer/__init__.py
@@ -0,0 +1,2 @@
+from sqlglot.optimizer.optimizer import optimize
+from sqlglot.optimizer.schema import Schema
diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py
new file mode 100644
index 0000000..4bfb733
--- /dev/null
+++ b/sqlglot/optimizer/eliminate_subqueries.py
@@ -0,0 +1,48 @@
+import itertools
+
+from sqlglot import alias, exp, select, table
+from sqlglot.optimizer.scope import traverse_scope
+from sqlglot.optimizer.simplify import simplify
+
+
+def eliminate_subqueries(expression):
+ """
+ Rewrite sqlglot AST to factor duplicate subqueries into CTEs.
+
+ Example:
+ >>> import sqlglot
+ >>> expression = sqlglot.parse_one("SELECT 1 AS x, 2 AS y UNION ALL SELECT 1 AS x, 2 AS y")
+ >>> eliminate_subqueries(expression).sql()
+ 'WITH _e_0 AS (SELECT 1 AS x, 2 AS y) SELECT * FROM _e_0 UNION ALL SELECT * FROM _e_0'
+
+ Args:
+ expression (sqlglot.Expression): expression to optimize
+ Returns:
+ sqlglot.Expression: optimized expression
+ """
+ expression = simplify(expression)
+ queries = {}
+
+ for scope in traverse_scope(expression):
+ query = scope.expression
+ queries[query] = queries.get(query, []) + [query]
+
+ sequence = itertools.count()
+
+ for query, duplicates in queries.items():
+ if len(duplicates) == 1:
+ continue
+
+ alias_ = f"_e_{next(sequence)}"
+
+ for dup in duplicates:
+ parent = dup.parent
+ if isinstance(parent, exp.Subquery):
+ parent.replace(alias(table(alias_), parent.alias_or_name, table=True))
+ elif isinstance(parent, exp.Union):
+ dup.replace(select("*").from_(alias_))
+
+ expression.with_(alias_, as_=query, copy=False)
+
+ return expression
diff --git a/sqlglot/optimizer/expand_multi_table_selects.py b/sqlglot/optimizer/expand_multi_table_selects.py
new file mode 100644
index 0000000..ba562df
--- /dev/null
+++ b/sqlglot/optimizer/expand_multi_table_selects.py
@@ -0,0 +1,16 @@
+from sqlglot import exp
+
+
+def expand_multi_table_selects(expression):
+ for from_ in expression.find_all(exp.From):
+ parent = from_.parent
+
+ for query in from_.expressions[1:]:
+ parent.join(
+ query,
+ join_type="CROSS",
+ copy=False,
+ )
+ from_.expressions.remove(query)
+
+ return expression
diff --git a/sqlglot/optimizer/isolate_table_selects.py b/sqlglot/optimizer/isolate_table_selects.py
new file mode 100644
index 0000000..c2e021e
--- /dev/null
+++ b/sqlglot/optimizer/isolate_table_selects.py
@@ -0,0 +1,31 @@
+from sqlglot import alias, exp
+from sqlglot.errors import OptimizeError
+from sqlglot.optimizer.scope import traverse_scope
+
+
+def isolate_table_selects(expression):
+ for scope in traverse_scope(expression):
+ if len(scope.selected_sources) == 1:
+ continue
+
+ for (_, source) in scope.selected_sources.values():
+ if not isinstance(source, exp.Table):
+ continue
+
+ if not isinstance(source.parent, exp.Alias):
+ raise OptimizeError(
+ "Tables require an alias. Run qualify_tables optimization."
+ )
+
+ parent = source.parent
+
+ parent.replace(
+ exp.select("*")
+ .from_(
+ alias(source, source.name or parent.alias, table=True),
+ copy=False,
+ )
+ .subquery(parent.alias, copy=False)
+ )
+
+ return expression
diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py
new file mode 100644
index 0000000..2c9f89c
--- /dev/null
+++ b/sqlglot/optimizer/normalize.py
@@ -0,0 +1,136 @@
+from sqlglot import exp
+from sqlglot.helper import while_changing
+from sqlglot.optimizer.simplify import flatten, simplify, uniq_sort
+
+
+def normalize(expression, dnf=False, max_distance=128):
+ """
+ Rewrite sqlglot AST into conjunctive normal form.
+
+ Example:
+ >>> import sqlglot
+ >>> expression = sqlglot.parse_one("(x AND y) OR z")
+ >>> normalize(expression).sql()
+ '(x OR z) AND (y OR z)'
+
+ Args:
+ expression (sqlglot.Expression): expression to normalize
+ dnf (bool): rewrite in disjunctive normal form instead
+ max_distance (int): the maximal estimated distance from cnf to attempt conversion
+ Returns:
+ sqlglot.Expression: normalized expression
+ """
+ expression = simplify(expression)
+
+ expression = while_changing(
+ expression, lambda e: distributive_law(e, dnf, max_distance)
+ )
+ return simplify(expression)
+
+
+def normalized(expression, dnf=False):
+ ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And)
+
+ return not any(
+ connector.find_ancestor(ancestor) for connector in expression.find_all(root)
+ )
+
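normalized answers whether any AND still sits beneath an OR (or the reverse when dnf=True). A sketch:

>>> import sqlglot
>>> from sqlglot.optimizer.normalize import normalized
>>> normalized(sqlglot.parse_one("(x OR y) AND z"))
True
>>> normalized(sqlglot.parse_one("(x AND y) OR z"))
False
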
+
+def normalization_distance(expression, dnf=False):
+ """
+ The difference in the number of predicates between the current expression and the normalized form.
+
+ This is used as an estimate of the cost of the conversion, which is exponential in complexity.
+
+ Example:
+ >>> import sqlglot
+ >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)")
+ >>> normalization_distance(expression)
+ 4
+
+ Args:
+ expression (sqlglot.Expression): expression to compute distance
+ dnf (bool): compute to dnf distance instead
+ Returns:
+ int: difference
+ """
+ return sum(_predicate_lengths(expression, dnf)) - (
+ len(list(expression.find_all(exp.Connector))) + 1
+ )
+
+
+def _predicate_lengths(expression, dnf):
+ """
+ Returns a list of predicate lengths when expanded to normalized form.
+
+ (A AND B) OR C -> [2, 2] because len(A OR C), len(B OR C).
+ """
+ expression = expression.unnest()
+
+ if not isinstance(expression, exp.Connector):
+ return [1]
+
+ left, right = expression.args.values()
+
+ if isinstance(expression, exp.And if dnf else exp.Or):
+ x = [
+ a + b
+ for a in _predicate_lengths(left, dnf)
+ for b in _predicate_lengths(right, dnf)
+ ]
+ return x
+ return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf)
+
+
+def distributive_law(expression, dnf, max_distance):
+ """
+ x OR (y AND z) -> (x OR y) AND (x OR z)
+ (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)
+ """
+ if isinstance(expression.unnest(), exp.Connector):
+ if normalization_distance(expression, dnf) > max_distance:
+ return expression
+
+ to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or)
+
+ exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance))
+
+ if isinstance(expression, from_exp):
+ a, b = expression.unnest_operands()
+
+ from_func = exp.and_ if from_exp == exp.And else exp.or_
+ to_func = exp.and_ if to_exp == exp.And else exp.or_
+
+ if isinstance(a, to_exp) and isinstance(b, to_exp):
+ if len(tuple(a.find_all(exp.Connector))) > len(
+ tuple(b.find_all(exp.Connector))
+ ):
+ return _distribute(a, b, from_func, to_func)
+ return _distribute(b, a, from_func, to_func)
+ if isinstance(a, to_exp):
+ return _distribute(b, a, from_func, to_func)
+ if isinstance(b, to_exp):
+ return _distribute(a, b, from_func, to_func)
+
+ return expression
+
+
+def _distribute(a, b, from_func, to_func):
+ if isinstance(a, exp.Connector):
+ exp.replace_children(
+ a,
+ lambda c: to_func(
+ exp.paren(from_func(c, b.left)),
+ exp.paren(from_func(c, b.right)),
+ ),
+ )
+ else:
+ a = to_func(from_func(a, b.left), from_func(a, b.right))
+
+ return _simplify(a)
+
+
+def _simplify(node):
+ node = uniq_sort(flatten(node))
+ exp.replace_children(node, _simplify)
+ return node
diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py
new file mode 100644
index 0000000..40e4ab1
--- /dev/null
+++ b/sqlglot/optimizer/optimize_joins.py
@@ -0,0 +1,75 @@
+from sqlglot import exp
+from sqlglot.helper import tsort
+from sqlglot.optimizer.simplify import simplify
+
+
+def optimize_joins(expression):
+ """
+ Removes cross joins if possible and reorders joins based on predicate dependencies.
+ """
+ for select in expression.find_all(exp.Select):
+ references = {}
+ cross_joins = []
+
+ for join in select.args.get("joins", []):
+ name = join.this.alias_or_name
+ tables = other_table_names(join, name)
+
+ if tables:
+ for table in tables:
+ references[table] = references.get(table, []) + [join]
+ else:
+ cross_joins.append((name, join))
+
+ for name, join in cross_joins:
+ for dep in references.get(name, []):
+ on = dep.args["on"]
+ on = on.replace(simplify(on))
+
+ if isinstance(on, exp.Connector):
+ for predicate in on.flatten():
+ if name in exp.column_table_names(predicate):
+ predicate.replace(exp.TRUE)
+ join.on(predicate, copy=False)
+
+ expression = reorder_joins(expression)
+ expression = normalize(expression)
+ return expression
+
+
+def reorder_joins(expression):
+ """
+ Reorder joins by topological sort order based on predicate references.
+ """
+ for from_ in expression.find_all(exp.From):
+ head = from_.expressions[0]
+ parent = from_.parent
+ joins = {join.this.alias_or_name: join for join in parent.args.get("joins", [])}
+ dag = {head.alias_or_name: []}
+
+ for name, join in joins.items():
+ dag[name] = other_table_names(join, name)
+
+ parent.set(
+ "joins",
+ [joins[name] for name in tsort(dag) if name != head.alias_or_name],
+ )
+ return expression
+
+
+def normalize(expression):
+ """
+ Remove INNER and OUTER from joins as they are optional.
+ """
+ for join in expression.find_all(exp.Join):
+ if join.kind != "CROSS":
+ join.set("kind", None)
+ return expression
+
+
+def other_table_names(join, exclude):
+ return [
+ name
+ for name in (exp.column_table_names(join.args.get("on") or exp.TRUE))
+ if name != exclude
+ ]
diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py
new file mode 100644
index 0000000..c03fe3c
--- /dev/null
+++ b/sqlglot/optimizer/optimizer.py
@@ -0,0 +1,43 @@
+from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries
+from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects
+from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
+from sqlglot.optimizer.normalize import normalize
+from sqlglot.optimizer.optimize_joins import optimize_joins
+from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
+from sqlglot.optimizer.pushdown_projections import pushdown_projections
+from sqlglot.optimizer.qualify_columns import qualify_columns
+from sqlglot.optimizer.qualify_tables import qualify_tables
+from sqlglot.optimizer.quote_identities import quote_identities
+from sqlglot.optimizer.unnest_subqueries import unnest_subqueries
+
+
+def optimize(expression, schema=None, db=None, catalog=None):
+ """
+ Rewrite a sqlglot AST into an optimized form.
+
+ Args:
+ expression (sqlglot.Expression): expression to optimize
+ schema (dict|sqlglot.optimizer.Schema): database schema.
+ This can either be an instance of `sqlglot.optimizer.Schema` or a mapping in one of
+ the following forms:
+ 1. {table: {col: type}}
+ 2. {db: {table: {col: type}}}
+ 3. {catalog: {db: {table: {col: type}}}}
+ db (str): specify the default database, as might be set by a `USE DATABASE db` statement
+ catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement
+ Returns:
+ sqlglot.Expression: optimized expression
+ """
+ expression = expression.copy()
+ expression = qualify_tables(expression, db=db, catalog=catalog)
+ expression = isolate_table_selects(expression)
+ expression = qualify_columns(expression, schema)
+ expression = pushdown_projections(expression)
+ expression = normalize(expression)
+ expression = unnest_subqueries(expression)
+ expression = expand_multi_table_selects(expression)
+ expression = pushdown_predicates(expression)
+ expression = optimize_joins(expression)
+ expression = eliminate_subqueries(expression)
+ expression = quote_identities(expression)
+ return expression
diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py
new file mode 100644
index 0000000..e757322
--- /dev/null
+++ b/sqlglot/optimizer/pushdown_predicates.py
@@ -0,0 +1,176 @@
+from sqlglot import exp
+from sqlglot.optimizer.normalize import normalized
+from sqlglot.optimizer.scope import traverse_scope
+from sqlglot.optimizer.simplify import simplify
+
+
+def pushdown_predicates(expression):
+ """
+ Rewrite sqlglot AST to push down predicates into FROM and JOIN clauses.
+
+ Example:
+ >>> import sqlglot
+ >>> sql = "SELECT * FROM (SELECT * FROM x AS x) AS y WHERE y.a = 1"
+ >>> expression = sqlglot.parse_one(sql)
+ >>> pushdown_predicates(expression).sql()
+ 'SELECT * FROM (SELECT * FROM x AS x WHERE y.a = 1) AS y WHERE TRUE'
+
+ Args:
+ expression (sqlglot.Expression): expression to optimize
+ Returns:
+ sqlglot.Expression: optimized expression
+ """
+ for scope in reversed(traverse_scope(expression)):
+ select = scope.expression
+ where = select.args.get("where")
+ if where:
+ pushdown(where.this, scope.selected_sources)
+
+ # a join's ON predicate should only be pushed down into its own source,
+ # not into other joins, so we limit the selected sources to just that join
+ for join in select.args.get("joins") or []:
+ name = join.this.alias_or_name
+ pushdown(join.args.get("on"), {name: scope.selected_sources[name]})
+
+ return expression
+
+
+def pushdown(condition, sources):
+ if not condition:
+ return
+
+ condition = condition.replace(simplify(condition))
+ cnf_like = normalized(condition) or not normalized(condition, dnf=True)
+
+ predicates = list(
+ condition.flatten()
+ if isinstance(condition, exp.And if cnf_like else exp.Or)
+ else [condition]
+ )
+
+ if cnf_like:
+ pushdown_cnf(predicates, sources)
+ else:
+ pushdown_dnf(predicates, sources)
+
+
+def pushdown_cnf(predicates, scope):
+ """
+ If the predicates are in CNF-like form, we can simply replace each block in the parent.
+ """
+ for predicate in predicates:
+ for node in nodes_for_predicate(predicate, scope).values():
+ if isinstance(node, exp.Join):
+ predicate.replace(exp.TRUE)
+ node.on(predicate, copy=False)
+ break
+ if isinstance(node, exp.Select):
+ predicate.replace(exp.TRUE)
+ node.where(replace_aliases(node, predicate), copy=False)
+
+
+def pushdown_dnf(predicates, scope):
+ """
+ If the predicates are in DNF form, we can only push down conditions that are in all blocks.
+ Additionally, we can't remove predicates from their original form.
+ """
+ # find all the tables that predicates can be pushed down to;
+ # these are tables that are referenced in all blocks of a DNF
+ # (a.x AND b.x) OR (a.y AND c.y)
+ # only table a can be pushed down
+ pushdown_tables = set()
+
+ for a in predicates:
+ a_tables = set(exp.column_table_names(a))
+
+ for b in predicates:
+ a_tables &= set(exp.column_table_names(b))
+
+ pushdown_tables.update(a_tables)
+
+ conditions = {}
+
+ # for every pushdown table, find all related conditions in all predicates
+ # and combine them with ORs
+ # (a.x AND a.y AND b.x) OR (a.z AND c.y) -> (a.x AND a.y) OR (a.z)
+ for table in sorted(pushdown_tables):
+ for predicate in predicates:
+ nodes = nodes_for_predicate(predicate, scope)
+
+ if table not in nodes:
+ continue
+
+ predicate_condition = None
+
+ for column in predicate.find_all(exp.Column):
+ if column.table == table:
+ condition = column.find_ancestor(exp.Condition)
+ predicate_condition = (
+ exp.and_(predicate_condition, condition)
+ if predicate_condition
+ else condition
+ )
+
+ if predicate_condition:
+ conditions[table] = (
+ exp.or_(conditions[table], predicate_condition)
+ if table in conditions
+ else predicate_condition
+ )
+
+ for name, node in nodes.items():
+ if name not in conditions:
+ continue
+
+ predicate = conditions[name]
+
+ if isinstance(node, exp.Join):
+ node.on(predicate, copy=False)
+ elif isinstance(node, exp.Select):
+ node.where(replace_aliases(node, predicate), copy=False)
+
+
+def nodes_for_predicate(predicate, sources):
+ nodes = {}
+ tables = exp.column_table_names(predicate)
+ where_condition = isinstance(
+ predicate.find_ancestor(exp.Join, exp.Where), exp.Where
+ )
+
+ for table in tables:
+ node, source = sources.get(table) or (None, None)
+
+ # if the predicate is in a where statement we can try to push it down
+ # we want to find the root join or from statement
+ if node and where_condition:
+ node = node.find_ancestor(exp.Join, exp.From)
+
+ # a node can reference a CTE, which predicates should be pushed down into
+ if isinstance(node, exp.From) and not isinstance(source, exp.Table):
+ node = source.expression
+
+ if isinstance(node, exp.Join):
+ if node.side:
+ return {}
+ nodes[table] = node
+ elif isinstance(node, exp.Select) and len(tables) == 1:
+ if not node.args.get("group"):
+ nodes[table] = node
+ return nodes
+
+
+def replace_aliases(source, predicate):
+ aliases = {}
+
+ for select in source.selects:
+ if isinstance(select, exp.Alias):
+ aliases[select.alias] = select.this
+ else:
+ aliases[select.name] = select
+
+ def _replace_alias(column):
+ if isinstance(column, exp.Column) and column.name in aliases:
+ return aliases[column.name]
+ return column
+
+ return predicate.transform(_replace_alias)
diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py
new file mode 100644
index 0000000..097ce04
--- /dev/null
+++ b/sqlglot/optimizer/pushdown_projections.py
@@ -0,0 +1,85 @@
+from collections import defaultdict
+
+from sqlglot import alias, exp
+from sqlglot.optimizer.scope import Scope, traverse_scope
+
+# Sentinel value that means an outer query selecting ALL columns
+SELECT_ALL = object()
+
+
+def pushdown_projections(expression):
+ """
+ Rewrite sqlglot AST to remove unused column projections.
+
+ Example:
+ >>> import sqlglot
+ >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y"
+ >>> expression = sqlglot.parse_one(sql)
+ >>> pushdown_projections(expression).sql()
+ 'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y'
+
+ Args:
+ expression (sqlglot.Expression): expression to optimize
+ Returns:
+ sqlglot.Expression: optimized expression
+ """
+ # Map of Scope to all columns being selected by outer queries.
+ referenced_columns = defaultdict(set)
+
+ # We build the scope tree (which is traversed in DFS postorder), then iterate
+ # over the result in reverse order. This should ensure that the set of selected
+ # columns for a particular scope is completely built by the time we get to it.
+ for scope in reversed(traverse_scope(expression)):
+ parent_selections = referenced_columns.get(scope, {SELECT_ALL})
+
+ if scope.expression.args.get("distinct"):
+ # We can't remove columns from SELECT DISTINCT or UNION DISTINCT
+ parent_selections = {SELECT_ALL}
+
+ if isinstance(scope.expression, exp.Union):
+ left, right = scope.union
+ referenced_columns[left] = parent_selections
+ referenced_columns[right] = parent_selections
+
+ if isinstance(scope.expression, exp.Select):
+ _remove_unused_selections(scope, parent_selections)
+
+ # Group columns by source name
+ selects = defaultdict(set)
+ for col in scope.columns:
+ table_name = col.table
+ col_name = col.name
+ selects[table_name].add(col_name)
+
+ # Push the selected columns down to the next scope
+ for name, (_, source) in scope.selected_sources.items():
+ if isinstance(source, Scope):
+ columns = selects.get(name) or set()
+ referenced_columns[source].update(columns)
+
+ return expression
+
+
+def _remove_unused_selections(scope, parent_selections):
+ order = scope.expression.args.get("order")
+
+ if order:
+ # Assume columns without a qualified table are references to output columns
+ order_refs = {c.name for c in order.find_all(exp.Column) if not c.table}
+ else:
+ order_refs = set()
+
+ new_selections = []
+ for selection in scope.selects:
+ if (
+ SELECT_ALL in parent_selections
+ or selection.alias_or_name in parent_selections
+ or selection.alias_or_name in order_refs
+ ):
+ new_selections.append(selection)
+
+ # If there are no remaining selections, just select a single constant
+ if not new_selections:
+ new_selections.append(alias("1", "_"))
+
+ scope.expression.set("expressions", new_selections)
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
new file mode 100644
index 0000000..394f49e
--- /dev/null
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -0,0 +1,422 @@
+import itertools
+
+from sqlglot import alias, exp
+from sqlglot.errors import OptimizeError
+from sqlglot.optimizer.schema import ensure_schema
+from sqlglot.optimizer.scope import traverse_scope
+
+SKIP_QUALIFY = (exp.Unnest, exp.Lateral)
+
+
+def qualify_columns(expression, schema):
+ """
+ Rewrite sqlglot AST to have fully qualified columns.
+
+ Example:
+ >>> import sqlglot
+ >>> schema = {"tbl": {"col": "INT"}}
+ >>> expression = sqlglot.parse_one("SELECT col FROM tbl")
+ >>> qualify_columns(expression, schema).sql()
+ 'SELECT tbl.col AS col FROM tbl'
+
+ Args:
+ expression (sqlglot.Expression): expression to qualify
+ schema (dict|sqlglot.optimizer.Schema): Database schema
+ Returns:
+ sqlglot.Expression: qualified expression
+ """
+ schema = ensure_schema(schema)
+
+ for scope in traverse_scope(expression):
+ resolver = _Resolver(scope, schema)
+ _pop_table_column_aliases(scope.ctes)
+ _pop_table_column_aliases(scope.derived_tables)
+ _expand_using(scope, resolver)
+ _expand_group_by(scope, resolver)
+ _expand_order_by(scope)
+ _qualify_columns(scope, resolver)
+ if not isinstance(scope.expression, SKIP_QUALIFY):
+ _expand_stars(scope, resolver)
+ _qualify_outputs(scope)
+ _check_unknown_tables(scope)
+
+ return expression
+
+
+def _pop_table_column_aliases(derived_tables):
+ """
+ Remove table column aliases.
+
+ (e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2))
+ """
+ for derived_table in derived_tables:
+ if isinstance(derived_table, SKIP_QUALIFY):
+ continue
+ table_alias = derived_table.args.get("alias")
+ if table_alias:
+ table_alias.args.pop("columns", None)
+
+
+def _expand_using(scope, resolver):
+ joins = list(scope.expression.find_all(exp.Join))
+ names = {join.this.alias for join in joins}
+ ordered = [key for key in scope.selected_sources if key not in names]
+
+ # Mapping of automatically joined column names to source names
+ column_tables = {}
+
+ for join in joins:
+ using = join.args.get("using")
+
+ if not using:
+ continue
+
+ join_table = join.this.alias_or_name
+
+ columns = {}
+
+ for k in scope.selected_sources:
+ if k in ordered:
+ for column in resolver.get_source_columns(k):
+ if column not in columns:
+ columns[column] = k
+
+ ordered.append(join_table)
+ join_columns = resolver.get_source_columns(join_table)
+ conditions = []
+
+ for identifier in using:
+ identifier = identifier.name
+ table = columns.get(identifier)
+
+ if not table or identifier not in join_columns:
+ raise OptimizeError(f"Cannot automatically join: {identifier}")
+
+ conditions.append(
+ exp.condition(
+ exp.EQ(
+ this=exp.column(identifier, table=table),
+ expression=exp.column(identifier, table=join_table),
+ )
+ )
+ )
+
+ tables = column_tables.setdefault(identifier, [])
+ if table not in tables:
+ tables.append(table)
+ if join_table not in tables:
+ tables.append(join_table)
+
+ join.args.pop("using")
+ join.set("on", exp.and_(*conditions))
+
+ if column_tables:
+ for column in scope.columns:
+ if not column.table and column.name in column_tables:
+ tables = column_tables[column.name]
+ coalesce = [exp.column(column.name, table=table) for table in tables]
+ replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:])
+
+ # Ensure selects keep their output name
+ if isinstance(column.parent, exp.Select):
+ replacement = exp.alias_(replacement, alias=column.name)
+
+ scope.replace(column, replacement)
+
+
+def _expand_group_by(scope, resolver):
+ group = scope.expression.args.get("group")
+ if not group:
+ return
+
+ # Replace references to select aliases
+ def transform(node, *_):
+ if isinstance(node, exp.Column) and not node.table:
+ table = resolver.get_table(node.name)
+
+ # Source columns get priority over select aliases
+ if table:
+ node.set("table", exp.to_identifier(table))
+ return node
+
+ selects = {s.alias_or_name: s for s in scope.selects}
+
+ select = selects.get(node.name)
+ if select:
+ scope.clear_cache()
+ if isinstance(select, exp.Alias):
+ select = select.this
+ return select.copy()
+
+ return node
+
+ group.transform(transform, copy=False)
+ group.set("expressions", _expand_positional_references(scope, group.expressions))
+ scope.expression.set("group", group)
+
+
+def _expand_order_by(scope):
+ order = scope.expression.args.get("order")
+ if not order:
+ return
+
+ ordereds = order.expressions
+ for ordered, new_expression in zip(
+ ordereds,
+ _expand_positional_references(scope, (o.this for o in ordereds)),
+ ):
+ ordered.set("this", new_expression)
+
+
+def _expand_positional_references(scope, expressions):
+ new_nodes = []
+ for node in expressions:
+ if node.is_int:
+ try:
+ select = scope.selects[int(node.name) - 1]
+ except IndexError:
+ raise OptimizeError(f"Unknown output column: {node.name}")
+ if isinstance(select, exp.Alias):
+ select = select.this
+ new_nodes.append(select.copy())
+ scope.clear_cache()
+ else:
+ new_nodes.append(node)
+
+ return new_nodes
+
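Positional references in GROUP BY and ORDER BY are swapped for copies of the corresponding projections before columns are qualified. A sketch against a hypothetical schema (the exact output aliasing is an assumption):

>>> import sqlglot
>>> from sqlglot.optimizer.qualify_columns import qualify_columns
>>> e = sqlglot.parse_one("SELECT a, b FROM t ORDER BY 2")
>>> qualify_columns(e, {"t": {"a": "INT", "b": "INT"}}).sql()
'SELECT t.a AS a, t.b AS b FROM t ORDER BY t.b'
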
+
+def _qualify_columns(scope, resolver):
+ """Disambiguate columns, ensuring each column specifies a source"""
+ for column in scope.columns:
+ column_table = column.table
+ column_name = column.name
+
+ if (
+ column_table
+ and column_table in scope.sources
+ and column_name not in resolver.get_source_columns(column_table)
+ ):
+ raise OptimizeError(f"Unknown column: {column_name}")
+
+ if not column_table:
+ column_table = resolver.get_table(column_name)
+
+ if not scope.is_subquery and not scope.is_unnest:
+ if column_name not in resolver.all_columns:
+ raise OptimizeError(f"Unknown column: {column_name}")
+
+ if column_table is None:
+ raise OptimizeError(f"Ambiguous column: {column_name}")
+
+ # column_table can be '' because BigQuery UNNEST has no table alias
+ if column_table:
+ column.set("table", exp.to_identifier(column_table))
+
+
+def _expand_stars(scope, resolver):
+ """Expand stars to lists of column selections"""
+
+ new_selections = []
+ except_columns = {}
+ replace_columns = {}
+
+ for expression in scope.selects:
+ if isinstance(expression, exp.Star):
+ tables = list(scope.selected_sources)
+ _add_except_columns(expression, tables, except_columns)
+ _add_replace_columns(expression, tables, replace_columns)
+ elif isinstance(expression, exp.Column) and isinstance(
+ expression.this, exp.Star
+ ):
+ tables = [expression.table]
+ _add_except_columns(expression.this, tables, except_columns)
+ _add_replace_columns(expression.this, tables, replace_columns)
+ else:
+ new_selections.append(expression)
+ continue
+
+ for table in tables:
+ if table not in scope.sources:
+ raise OptimizeError(f"Unknown table: {table}")
+ columns = resolver.get_source_columns(table)
+ table_id = id(table)
+ for name in columns:
+ if name not in except_columns.get(table_id, set()):
+ alias_ = replace_columns.get(table_id, {}).get(name, name)
+ column = exp.column(name, table)
+ new_selections.append(
+ alias(column, alias_) if alias_ != name else column
+ )
+
+ scope.expression.set("expressions", new_selections)
+
+
+def _add_except_columns(expression, tables, except_columns):
+ except_ = expression.args.get("except")
+
+ if not except_:
+ return
+
+ columns = {e.name for e in except_}
+
+ for table in tables:
+ except_columns[id(table)] = columns
+
+
+def _add_replace_columns(expression, tables, replace_columns):
+ replace = expression.args.get("replace")
+
+ if not replace:
+ return
+
+ columns = {e.this.name: e.alias for e in replace}
+
+ for table in tables:
+ replace_columns[id(table)] = columns
+
+
+def _qualify_outputs(scope):
+ """Ensure all output columns are aliased"""
+ new_selections = []
+
+ for i, (selection, aliased_column) in enumerate(
+ itertools.zip_longest(scope.selects, scope.outer_column_list)
+ ):
+ if isinstance(selection, exp.Column):
+ # convoluted setter because a simple selection.replace(alias) would require a copy
+ alias_ = alias(exp.column(""), alias=selection.name)
+ alias_.set("this", selection)
+ selection = alias_
+ elif not isinstance(selection, exp.Alias):
+ alias_ = alias(exp.column(""), f"_col_{i}")
+ alias_.set("this", selection)
+ selection = alias_
+
+ if aliased_column:
+ selection.set("alias", exp.to_identifier(aliased_column))
+
+ new_selections.append(selection)
+
+ scope.expression.set("expressions", new_selections)
+
+
+def _check_unknown_tables(scope):
+ if (
+ scope.external_columns
+ and not scope.is_unnest
+ and not scope.is_correlated_subquery
+ ):
+ raise OptimizeError(f"Unknown table: {scope.external_columns[0].text('table')}")
+
+
+class _Resolver:
+ """
+ Helper for resolving columns.
+
+ This is a class so we can lazily load some things and easily share them across functions.
+ """
+
+ def __init__(self, scope, schema):
+ self.scope = scope
+ self.schema = schema
+ self._source_columns = None
+ self._unambiguous_columns = None
+ self._all_columns = None
+
+ def get_table(self, column_name):
+ """
+ Get the table for a column name.
+
+ Args:
+ column_name (str)
+ Returns:
+ (str) table name
+ """
+ if self._unambiguous_columns is None:
+ self._unambiguous_columns = self._get_unambiguous_columns(
+ self._get_all_source_columns()
+ )
+ return self._unambiguous_columns.get(column_name)
+
+ @property
+ def all_columns(self):
+ """All available columns of all sources in this scope"""
+ if self._all_columns is None:
+ self._all_columns = set(
+ column
+ for columns in self._get_all_source_columns().values()
+ for column in columns
+ )
+ return self._all_columns
+
+ def get_source_columns(self, name):
+ """Resolve the source columns for a given source `name`"""
+ if name not in self.scope.sources:
+ raise OptimizeError(f"Unknown table: {name}")
+
+ source = self.scope.sources[name]
+
+ # If referencing a table, return the columns from the schema
+ if isinstance(source, exp.Table):
+ try:
+ return self.schema.column_names(source)
+ except Exception as e:
+ raise OptimizeError(str(e)) from e
+
+ # Otherwise, if referencing another scope, return that scope's named selects
+ return source.expression.named_selects
+
+ def _get_all_source_columns(self):
+ if self._source_columns is None:
+ self._source_columns = {
+ k: self.get_source_columns(k) for k in self.scope.selected_sources
+ }
+ return self._source_columns
+
+ def _get_unambiguous_columns(self, source_columns):
+ """
+ Find all the unambiguous columns in sources.
+
+ Args:
+ source_columns (dict): Mapping of names to source columns
+ Returns:
+ dict: Mapping of column name to source name
+ """
+ if not source_columns:
+ return {}
+
+ source_columns = list(source_columns.items())
+
+ first_table, first_columns = source_columns[0]
+ unambiguous_columns = {
+ col: first_table for col in self._find_unique_columns(first_columns)
+ }
+ all_columns = set(unambiguous_columns)
+
+ for table, columns in source_columns[1:]:
+ unique = self._find_unique_columns(columns)
+ ambiguous = set(all_columns).intersection(unique)
+ all_columns.update(columns)
+ for column in ambiguous:
+ unambiguous_columns.pop(column, None)
+ for column in unique.difference(ambiguous):
+ unambiguous_columns[column] = table
+
+ return unambiguous_columns
+
+ @staticmethod
+ def _find_unique_columns(columns):
+ """
+ Find the unique columns in a list of columns.
+
+ Example:
+ >>> sorted(_Resolver._find_unique_columns(["a", "b", "b", "c"]))
+ ['a', 'c']
+
+ This is necessary because duplicate column names are ambiguous.
+ """
+ counts = {}
+ for column in columns:
+ counts[column] = counts.get(column, 0) + 1
+ return {column for column, count in counts.items() if count == 1}
diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py
new file mode 100644
index 0000000..9f8b9f5
--- /dev/null
+++ b/sqlglot/optimizer/qualify_tables.py
@@ -0,0 +1,54 @@
+import itertools
+
+from sqlglot import alias, exp
+from sqlglot.optimizer.scope import traverse_scope
+
+
+def qualify_tables(expression, db=None, catalog=None):
+ """
+ Rewrite sqlglot AST to have fully qualified tables.
+
+ Example:
+ >>> import sqlglot
+ >>> expression = sqlglot.parse_one("SELECT 1 FROM tbl")
+ >>> qualify_tables(expression, db="db").sql()
+ 'SELECT 1 FROM db.tbl AS tbl'
+
+ Args:
+ expression (sqlglot.Expression): expression to qualify
+ db (str): Database name
+ catalog (str): Catalog name
+ Returns:
+ sqlglot.Expression: qualified expression
+ """
+ sequence = itertools.count()
+
+ for scope in traverse_scope(expression):
+ for derived_table in scope.ctes + scope.derived_tables:
+ if not derived_table.args.get("alias"):
+ alias_ = f"_q_{next(sequence)}"
+ derived_table.set(
+ "alias", exp.TableAlias(this=exp.to_identifier(alias_))
+ )
+ scope.rename_source(None, alias_)
+
+ for source in scope.sources.values():
+ if isinstance(source, exp.Table):
+ identifier = isinstance(source.this, exp.Identifier)
+
+ if identifier:
+ if not source.args.get("db"):
+ source.set("db", exp.to_identifier(db))
+ if not source.args.get("catalog"):
+ source.set("catalog", exp.to_identifier(catalog))
+
+ if not isinstance(source.parent, exp.Alias):
+ source.replace(
+ alias(
+ source.copy(),
+ source.this if identifier else f"_q_{next(sequence)}",
+ table=True,
+ )
+ )
+
+ return expression
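
An extra usage sketch (assuming only the behavior defined above): the
unaliased derived table receives a generated alias, and the inner table is
qualified with the supplied catalog and db before being self-aliased:

    import sqlglot
    from sqlglot.optimizer.qualify_tables import qualify_tables

    expression = sqlglot.parse_one("SELECT * FROM (SELECT 1 FROM tbl)")
    print(qualify_tables(expression, db="db", catalog="c").sql())
    # expected, per the logic above:
    # SELECT * FROM (SELECT 1 FROM c.db.tbl AS tbl) AS _q_0
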
diff --git a/sqlglot/optimizer/quote_identities.py b/sqlglot/optimizer/quote_identities.py
new file mode 100644
index 0000000..17623cc
--- /dev/null
+++ b/sqlglot/optimizer/quote_identities.py
@@ -0,0 +1,25 @@
+from sqlglot import exp
+
+
+def quote_identities(expression):
+ """
+ Rewrite sqlglot AST to ensure all identities are quoted.
+
+ Example:
+ >>> import sqlglot
+ >>> expression = sqlglot.parse_one("SELECT x.a AS a FROM db.x")
+ >>> quote_identities(expression).sql()
+ 'SELECT "x"."a" AS "a" FROM "db"."x"'
+
+ Args:
+ expression (sqlglot.Expression): expression to quote
+ Returns:
+ sqlglot.Expression: quoted expression
+ """
+
+ def qualify(node):
+ if isinstance(node, exp.Identifier):
+ node.set("quoted", True)
+ return node
+
+ return expression.transform(qualify, copy=False)
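
One detail worth noting: transform(..., copy=False) mutates the tree in
place, so the caller's expression is itself rewritten. A quick sketch:

    import sqlglot
    from sqlglot.optimizer.quote_identities import quote_identities

    expression = sqlglot.parse_one("SELECT a FROM x")
    quote_identities(expression)  # the return value is the same (mutated) tree
    print(expression.sql())       # SELECT "a" FROM "x"
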
diff --git a/sqlglot/optimizer/schema.py b/sqlglot/optimizer/schema.py
new file mode 100644
index 0000000..9968108
--- /dev/null
+++ b/sqlglot/optimizer/schema.py
@@ -0,0 +1,129 @@
+import abc
+
+from sqlglot import exp
+from sqlglot.errors import OptimizeError
+from sqlglot.helper import csv_reader
+
+
+class Schema(abc.ABC):
+ """Abstract base class for database schemas"""
+
+ @abc.abstractmethod
+ def column_names(self, table):
+ """
+ Get the column names for a table.
+
+ Args:
+ table (sqlglot.expressions.Table): Table expression instance
+ Returns:
+ list[str]: list of column names
+ """
+
+
+class MappingSchema(Schema):
+ """
+ Schema based on a nested mapping.
+
+ Args:
+ schema (dict): Mapping in one of the following forms:
+ 1. {table: {col: type}}
+ 2. {db: {table: {col: type}}}
+ 3. {catalog: {db: {table: {col: type}}}}
+ """
+
+ def __init__(self, schema):
+ self.schema = schema
+
+ depth = _dict_depth(schema)
+
+ if not depth: # {}
+ self.supported_table_args = []
+ elif depth == 2: # {table: {col: type}}
+ self.supported_table_args = ("this",)
+ elif depth == 3: # {db: {table: {col: type}}}
+ self.supported_table_args = ("db", "this")
+ elif depth == 4: # {catalog: {db: {table: {col: type}}}}
+ self.supported_table_args = ("catalog", "db", "this")
+ else:
+ raise OptimizeError(f"Invalid schema shape. Depth: {depth}")
+
+ self.forbidden_args = {"catalog", "db", "this"} - set(self.supported_table_args)
+
+ def column_names(self, table):
+ if not isinstance(table.this, exp.Identifier):
+ return fs_get(table)
+
+ args = tuple(table.text(p) for p in self.supported_table_args)
+
+ for forbidden in self.forbidden_args:
+ if table.text(forbidden):
+ raise ValueError(
+ f"Schema doesn't support {forbidden}. Received: {table.sql()}"
+ )
+ return list(_nested_get(self.schema, *zip(self.supported_table_args, args)))
+
+
+def ensure_schema(schema):
+ if isinstance(schema, Schema):
+ return schema
+
+ return MappingSchema(schema)
+
+
+def fs_get(table):
+ name = table.this.name.upper()
+
+    if name == "READ_CSV":
+ with csv_reader(table) as reader:
+ return next(reader)
+
+ raise ValueError(f"Cannot read schema for {table}")
+
+
+def _nested_get(d, *path):
+ """
+ Get a value for a nested dictionary.
+
+ Args:
+ d (dict): dictionary
+ *path (tuple[str, str]): tuples of (name, key)
+ `key` is the key in the dictionary to get.
+ `name` is a string to use in the error if `key` isn't found.
+ """
+ for name, key in path:
+ d = d.get(key)
+ if d is None:
+ name = "table" if name == "this" else name
+ raise ValueError(f"Unknown {name}")
+ return d
+
+
+def _dict_depth(d):
+ """
+ Get the nesting depth of a dictionary.
+
+ For example:
+ >>> _dict_depth(None)
+ 0
+ >>> _dict_depth({})
+ 1
+ >>> _dict_depth({"a": "b"})
+ 1
+ >>> _dict_depth({"a": {}})
+ 2
+ >>> _dict_depth({"a": {"b": {}}})
+ 3
+
+ Args:
+ d (dict): dictionary
+ Returns:
+ int: depth
+ """
+ try:
+ return 1 + _dict_depth(next(iter(d.values())))
+ except AttributeError:
+ # d doesn't have attribute "values"
+ return 0
+ except StopIteration:
+ # d.values() returns an empty sequence
+ return 1
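
A small sketch of the depth-3 mapping shape, {db: {table: {col: type}}}, with
made-up names; column_names takes a sqlglot Table expression, so one is built
by hand here:

    from sqlglot import exp
    from sqlglot.optimizer.schema import MappingSchema

    schema = MappingSchema({"db": {"x": {"a": "INT", "b": "TEXT"}}})
    table = exp.Table(this=exp.to_identifier("x"), db=exp.to_identifier("db"))
    print(schema.column_names(table))  # ['a', 'b']
    # a table that also sets a catalog would raise, since this shape only
    # supports the ("db", "this") args
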
diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py
new file mode 100644
index 0000000..f6f59e8
--- /dev/null
+++ b/sqlglot/optimizer/scope.py
@@ -0,0 +1,438 @@
+from copy import copy
+from enum import Enum, auto
+
+from sqlglot import exp
+from sqlglot.errors import OptimizeError
+
+
+class ScopeType(Enum):
+ ROOT = auto()
+ SUBQUERY = auto()
+ DERIVED_TABLE = auto()
+ CTE = auto()
+ UNION = auto()
+ UNNEST = auto()
+
+
+class Scope:
+ """
+ Selection scope.
+
+ Attributes:
+ expression (exp.Select|exp.Union): Root expression of this scope
+ sources (dict[str, exp.Table|Scope]): Mapping of source name to either
+ a Table expression or another Scope instance. For example:
+ SELECT * FROM x {"x": Table(this="x")}
+ SELECT * FROM x AS y {"y": Table(this="x")}
+ SELECT * FROM (SELECT ...) AS y {"y": Scope(...)}
+ outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
+            defines a column list for this scope's alias, this is that list of columns.
+ For example:
+ SELECT * FROM (SELECT ...) AS y(col1, col2)
+ The inner query would have `["col1", "col2"]` for its `outer_column_list`
+ parent (Scope): Parent scope
+        scope_type (ScopeType): Type of this scope, relative to its parent
+ subquery_scopes (list[Scope]): List of all child scopes for subqueries.
+ This does not include derived tables or CTEs.
+ union (tuple[Scope, Scope]): If this Scope is for a Union expression, this will be
+ a tuple of the left and right child scopes.
+ """
+
+ def __init__(
+ self,
+ expression,
+ sources=None,
+ outer_column_list=None,
+ parent=None,
+ scope_type=ScopeType.ROOT,
+ ):
+ self.expression = expression
+ self.sources = sources or {}
+ self.outer_column_list = outer_column_list or []
+ self.parent = parent
+ self.scope_type = scope_type
+ self.subquery_scopes = []
+ self.union = None
+ self.clear_cache()
+
+ def clear_cache(self):
+ self._collected = False
+ self._raw_columns = None
+ self._derived_tables = None
+ self._tables = None
+ self._ctes = None
+ self._subqueries = None
+ self._selected_sources = None
+ self._columns = None
+ self._external_columns = None
+
+ def branch(self, expression, scope_type, add_sources=None, **kwargs):
+ """Branch from the current scope to a new, inner scope"""
+ sources = copy(self.sources)
+ if add_sources:
+ sources.update(add_sources)
+ return Scope(
+ expression=expression.unnest(),
+ sources=sources,
+ parent=self,
+ scope_type=scope_type,
+ **kwargs,
+ )
+
+ def _collect(self):
+ self._tables = []
+ self._ctes = []
+ self._subqueries = []
+ self._derived_tables = []
+ self._raw_columns = []
+
+ # We'll use this variable to pass state into the dfs generator.
+ # Whenever we set it to True, we exclude a subtree from traversal.
+ prune = False
+
+ for node, parent, _ in self.expression.dfs(prune=lambda *_: prune):
+ prune = False
+
+ if node is self.expression:
+ continue
+ if isinstance(node, exp.Column) and not isinstance(node.this, exp.Star):
+ self._raw_columns.append(node)
+ elif isinstance(node, exp.Table):
+ self._tables.append(node)
+ elif isinstance(node, (exp.Unnest, exp.Lateral)):
+ self._derived_tables.append(node)
+ elif isinstance(node, exp.CTE):
+ self._ctes.append(node)
+ prune = True
+ elif isinstance(node, exp.Subquery) and isinstance(
+ parent, (exp.From, exp.Join)
+ ):
+ self._derived_tables.append(node)
+ prune = True
+ elif isinstance(node, exp.Subqueryable):
+ self._subqueries.append(node)
+ prune = True
+
+ self._collected = True
+
+ def _ensure_collected(self):
+ if not self._collected:
+ self._collect()
+
+ def replace(self, old, new):
+ """
+ Replace `old` with `new`.
+
+ This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date.
+
+ Args:
+ old (exp.Expression): old node
+ new (exp.Expression): new node
+ """
+ old.replace(new)
+ self.clear_cache()
+
+ @property
+ def tables(self):
+ """
+ List of tables in this scope.
+
+ Returns:
+ list[exp.Table]: tables
+ """
+ self._ensure_collected()
+ return self._tables
+
+ @property
+ def ctes(self):
+ """
+ List of CTEs in this scope.
+
+ Returns:
+ list[exp.CTE]: ctes
+ """
+ self._ensure_collected()
+ return self._ctes
+
+ @property
+ def derived_tables(self):
+ """
+ List of derived tables in this scope.
+
+ For example:
+ SELECT * FROM (SELECT ...) <- that's a derived table
+
+ Returns:
+ list[exp.Subquery]: derived tables
+ """
+ self._ensure_collected()
+ return self._derived_tables
+
+ @property
+ def subqueries(self):
+ """
+ List of subqueries in this scope.
+
+ For example:
+ SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery
+
+ Returns:
+ list[exp.Subqueryable]: subqueries
+ """
+ self._ensure_collected()
+ return self._subqueries
+
+ @property
+ def columns(self):
+ """
+ List of columns in this scope.
+
+ Returns:
+ list[exp.Column]: Column instances in this scope, plus any
+ Columns that reference this scope from correlated subqueries.
+ """
+ if self._columns is None:
+ self._ensure_collected()
+ columns = self._raw_columns
+
+ external_columns = [
+ column
+ for scope in self.subquery_scopes
+ for column in scope.external_columns
+ ]
+
+ named_outputs = {e.alias_or_name for e in self.expression.expressions}
+
+ self._columns = [
+ c
+ for c in columns + external_columns
+ if not (
+ c.find_ancestor(exp.Qualify, exp.Order) and c.name in named_outputs
+ )
+ ]
+ return self._columns
+
+ @property
+ def selected_sources(self):
+ """
+ Mapping of nodes and sources that are actually selected from in this scope.
+
+ That is, all tables in a schema are selectable at any point. But a
+ table only becomes a selected source if it's included in a FROM or JOIN clause.
+
+ Returns:
+ dict[str, (exp.Table|exp.Subquery, exp.Table|Scope)]: selected sources and nodes
+ """
+ if self._selected_sources is None:
+ referenced_names = []
+
+ for table in self.tables:
+ referenced_names.append(
+ (
+ table.parent.alias
+ if isinstance(table.parent, exp.Alias)
+ else table.name,
+ table,
+ )
+ )
+ for derived_table in self.derived_tables:
+ referenced_names.append((derived_table.alias, derived_table.unnest()))
+
+ result = {}
+
+ for name, node in referenced_names:
+ if name in self.sources:
+ result[name] = (node, self.sources[name])
+
+ self._selected_sources = result
+ return self._selected_sources
+
+ @property
+ def selects(self):
+ """
+ Select expressions of this scope.
+
+ For example, for the following expression:
+ SELECT 1 as a, 2 as b FROM x
+
+ The outputs are the "1 as a" and "2 as b" expressions.
+
+ Returns:
+ list[exp.Expression]: expressions
+ """
+ if isinstance(self.expression, exp.Union):
+ return []
+ return self.expression.selects
+
+ @property
+ def external_columns(self):
+ """
+ Columns that appear to reference sources in outer scopes.
+
+ Returns:
+ list[exp.Column]: Column instances that don't reference
+ sources in the current scope.
+ """
+ if self._external_columns is None:
+ self._external_columns = [
+ c for c in self.columns if c.table not in self.selected_sources
+ ]
+ return self._external_columns
+
+ def source_columns(self, source_name):
+ """
+ Get all columns in the current scope for a particular source.
+
+ Args:
+ source_name (str): Name of the source
+ Returns:
+ list[exp.Column]: Column instances that reference `source_name`
+ """
+ return [column for column in self.columns if column.table == source_name]
+
+ @property
+ def is_subquery(self):
+ """Determine if this scope is a subquery"""
+ return self.scope_type == ScopeType.SUBQUERY
+
+ @property
+ def is_unnest(self):
+ """Determine if this scope is an unnest"""
+ return self.scope_type == ScopeType.UNNEST
+
+ @property
+ def is_correlated_subquery(self):
+ """Determine if this scope is a correlated subquery"""
+ return bool(self.is_subquery and self.external_columns)
+
+ def rename_source(self, old_name, new_name):
+ """Rename a source in this scope"""
+ columns = self.sources.pop(old_name or "", [])
+ self.sources[new_name] = columns
+
+
+def traverse_scope(expression):
+ """
+    Traverse an expression by its "scopes".
+
+ "Scope" represents the current context of a Select statement.
+
+ This is helpful for optimizing queries, where we need more information than
+ the expression tree itself. For example, we might care about the source
+    names within a subquery. Returns a list (rather than a generator) because a
+    generator could result in incomplete properties, which is confusing.
+
+ Examples:
+ >>> import sqlglot
+ >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y")
+ >>> scopes = traverse_scope(expression)
+ >>> scopes[0].expression.sql(), list(scopes[0].sources)
+ ('SELECT a FROM x', ['x'])
+ >>> scopes[1].expression.sql(), list(scopes[1].sources)
+ ('SELECT a FROM (SELECT a FROM x) AS y', ['y'])
+
+ Args:
+ expression (exp.Expression): expression to traverse
+ Returns:
+ List[Scope]: scope instances
+ """
+ return list(_traverse_scope(Scope(expression)))
+
+
+def _traverse_scope(scope):
+ if isinstance(scope.expression, exp.Select):
+ yield from _traverse_select(scope)
+ elif isinstance(scope.expression, exp.Union):
+ yield from _traverse_union(scope)
+ elif isinstance(scope.expression, (exp.Lateral, exp.Unnest)):
+ pass
+ elif isinstance(scope.expression, exp.Subquery):
+ yield from _traverse_subqueries(scope)
+ else:
+ raise OptimizeError(f"Unexpected expression type: {type(scope.expression)}")
+ yield scope
+
+
+def _traverse_select(scope):
+ yield from _traverse_derived_tables(scope.ctes, scope, ScopeType.CTE)
+ yield from _traverse_subqueries(scope)
+ yield from _traverse_derived_tables(
+ scope.derived_tables, scope, ScopeType.DERIVED_TABLE
+ )
+ _add_table_sources(scope)
+
+
+def _traverse_union(scope):
+ yield from _traverse_derived_tables(scope.ctes, scope, scope_type=ScopeType.CTE)
+
+    # The last scope to be yielded should be the topmost scope
+ left = None
+ for left in _traverse_scope(
+ scope.branch(scope.expression.left, scope_type=ScopeType.UNION)
+ ):
+ yield left
+
+ right = None
+ for right in _traverse_scope(
+ scope.branch(scope.expression.right, scope_type=ScopeType.UNION)
+ ):
+ yield right
+
+ scope.union = (left, right)
+
+
+def _traverse_derived_tables(derived_tables, scope, scope_type):
+ sources = {}
+
+ for derived_table in derived_tables:
+ for child_scope in _traverse_scope(
+ scope.branch(
+ derived_table
+ if isinstance(derived_table, (exp.Unnest, exp.Lateral))
+ else derived_table.this,
+ add_sources=sources if scope_type == ScopeType.CTE else None,
+ outer_column_list=derived_table.alias_column_names,
+ scope_type=ScopeType.UNNEST
+ if isinstance(derived_table, exp.Unnest)
+ else scope_type,
+ )
+ ):
+ yield child_scope
+ # Tables without aliases will be set as ""
+            # This shouldn't be a problem once qualify_columns runs, as it adds aliases to everything.
+            # Until then, this means that only a single, unaliased derived table is allowed (rather,
+            # the latest one wins).
+ sources[derived_table.alias] = child_scope
+ scope.sources.update(sources)
+
+
+def _add_table_sources(scope):
+ sources = {}
+ for table in scope.tables:
+ table_name = table.name
+
+ if isinstance(table.parent, exp.Alias):
+ source_name = table.parent.alias
+ else:
+ source_name = table_name
+
+ if table_name in scope.sources:
+ # This is a reference to a parent source (e.g. a CTE), not an actual table.
+ scope.sources[source_name] = scope.sources[table_name]
+ elif source_name in scope.sources:
+ raise OptimizeError(f"Duplicate table name: {source_name}")
+ else:
+ sources[source_name] = table
+
+ scope.sources.update(sources)
+
+
+def _traverse_subqueries(scope):
+ for subquery in scope.subqueries:
+ top = None
+ for child_scope in _traverse_scope(
+ scope.branch(subquery, scope_type=ScopeType.SUBQUERY)
+ ):
+ yield child_scope
+ top = child_scope
+ scope.subquery_scopes.append(top)
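
A sketch of the correlation machinery above, using fully qualified columns
(external_columns compares raw table names, so it is most meaningful once
columns are qualified). Subquery scopes are yielded before their parents, and
the reference to x.a surfaces as external in the inner scope:

    import sqlglot
    from sqlglot.optimizer.scope import traverse_scope

    sql = "SELECT * FROM x AS x WHERE x.a IN (SELECT y.b FROM y AS y WHERE y.b = x.a)"
    inner = traverse_scope(sqlglot.parse_one(sql))[0]  # inner scope comes first
    print([c.sql() for c in inner.external_columns])   # ['x.a']
    print(inner.is_correlated_subquery)                # True
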
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
new file mode 100644
index 0000000..6771153
--- /dev/null
+++ b/sqlglot/optimizer/simplify.py
@@ -0,0 +1,383 @@
+import datetime
+import functools
+import itertools
+from collections import deque
+from decimal import Decimal
+
+from sqlglot import exp
+from sqlglot.expressions import FALSE, NULL, TRUE
+from sqlglot.generator import Generator
+from sqlglot.helper import while_changing
+
+GENERATOR = Generator(normalize=True, identify=True)
+
+
+def simplify(expression):
+ """
+ Rewrite sqlglot AST to simplify expressions.
+
+ Example:
+ >>> import sqlglot
+ >>> expression = sqlglot.parse_one("TRUE AND TRUE")
+ >>> simplify(expression).sql()
+ 'TRUE'
+
+ Args:
+ expression (sqlglot.Expression): expression to simplify
+ Returns:
+ sqlglot.Expression: simplified expression
+ """
+
+ def _simplify(expression, root=True):
+ node = expression
+ node = uniq_sort(node)
+ node = absorb_and_eliminate(node)
+ exp.replace_children(node, lambda e: _simplify(e, False))
+ node = simplify_not(node)
+ node = flatten(node)
+ node = simplify_connectors(node)
+ node = remove_compliments(node)
+ node.parent = expression.parent
+ node = simplify_literals(node)
+ node = simplify_parens(node)
+ if root:
+ expression.replace(node)
+ return node
+
+ expression = while_changing(expression, _simplify)
+ remove_where_true(expression)
+ return expression
+
+
+def simplify_not(expression):
+ """
+    De Morgan's laws:
+ NOT (x OR y) -> NOT x AND NOT y
+ NOT (x AND y) -> NOT x OR NOT y
+ """
+ if isinstance(expression, exp.Not):
+ if isinstance(expression.this, exp.Paren):
+ condition = expression.this.unnest()
+ if isinstance(condition, exp.And):
+ return exp.or_(exp.not_(condition.left), exp.not_(condition.right))
+ if isinstance(condition, exp.Or):
+ return exp.and_(exp.not_(condition.left), exp.not_(condition.right))
+ if always_true(expression.this):
+ return FALSE
+ if expression.this == FALSE:
+ return TRUE
+ if isinstance(expression.this, exp.Not):
+ # double negation
+ # NOT NOT x -> x
+ return expression.this.this
+ return expression
+
+
+def flatten(expression):
+ """
+ A AND (B AND C) -> A AND B AND C
+ A OR (B OR C) -> A OR B OR C
+ """
+ if isinstance(expression, exp.Connector):
+ for node in expression.args.values():
+ child = node.unnest()
+ if isinstance(child, expression.__class__):
+ node.replace(child)
+ return expression
+
+
+def simplify_connectors(expression):
+ if isinstance(expression, exp.Connector):
+ left = expression.left
+ right = expression.right
+
+ if left == right:
+ return left
+
+ if isinstance(expression, exp.And):
+ if NULL in (left, right):
+ return NULL
+ if FALSE in (left, right):
+ return FALSE
+ if always_true(left) and always_true(right):
+ return TRUE
+ if always_true(left):
+ return right
+ if always_true(right):
+ return left
+ elif isinstance(expression, exp.Or):
+ if always_true(left) or always_true(right):
+ return TRUE
+ if left == FALSE and right == FALSE:
+ return FALSE
+ if (
+ (left == NULL and right == NULL)
+ or (left == NULL and right == FALSE)
+ or (left == FALSE and right == NULL)
+ ):
+ return NULL
+ if left == FALSE:
+ return right
+ if right == FALSE:
+ return left
+ return expression
+
+
+def remove_compliments(expression):
+ """
+    Remove complements.
+
+ A AND NOT A -> FALSE
+ A OR NOT A -> TRUE
+ """
+ if isinstance(expression, exp.Connector):
+ compliment = FALSE if isinstance(expression, exp.And) else TRUE
+
+ for a, b in itertools.permutations(expression.flatten(), 2):
+ if is_complement(a, b):
+ return compliment
+ return expression
+
+
+def uniq_sort(expression):
+ """
+ Uniq and sort a connector.
+
+ C AND A AND B AND B -> A AND B AND C
+ """
+ if isinstance(expression, exp.Connector):
+ result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
+ flattened = tuple(expression.flatten())
+ deduped = {GENERATOR.generate(e): e for e in flattened}
+ arr = tuple(deduped.items())
+
+ # check if the operands are already sorted, if not sort them
+ # A AND C AND B -> A AND B AND C
+ for i, (sql, e) in enumerate(arr[1:]):
+ if sql < arr[i][0]:
+ expression = result_func(*(deduped[sql] for sql in sorted(deduped)))
+ break
+ else:
+ # we didn't have to sort but maybe we need to dedup
+ if len(deduped) < len(flattened):
+ expression = result_func(*deduped.values())
+
+ return expression
+
+
+def absorb_and_eliminate(expression):
+ """
+ absorption:
+ A AND (A OR B) -> A
+ A OR (A AND B) -> A
+ A AND (NOT A OR B) -> A AND B
+ A OR (NOT A AND B) -> A OR B
+ elimination:
+ (A AND B) OR (A AND NOT B) -> A
+ (A OR B) AND (A OR NOT B) -> A
+ """
+ if isinstance(expression, exp.Connector):
+ kind = exp.Or if isinstance(expression, exp.And) else exp.And
+
+ for a, b in itertools.permutations(expression.flatten(), 2):
+ if isinstance(a, kind):
+ aa, ab = a.unnest_operands()
+
+ # absorb
+ if is_complement(b, aa):
+ aa.replace(exp.TRUE if kind == exp.And else exp.FALSE)
+ elif is_complement(b, ab):
+ ab.replace(exp.TRUE if kind == exp.And else exp.FALSE)
+ elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(
+ a.flatten()
+ ):
+ a.replace(exp.FALSE if kind == exp.And else exp.TRUE)
+ elif isinstance(b, kind):
+ # eliminate
+ rhs = b.unnest_operands()
+ ba, bb = rhs
+
+ if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
+ a.replace(aa)
+ b.replace(aa)
+ elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
+ a.replace(ab)
+ b.replace(ab)
+
+ return expression
+
+
+def simplify_literals(expression):
+ if isinstance(expression, exp.Binary):
+ operands = []
+ queue = deque(expression.flatten(unnest=False))
+ size = len(queue)
+
+ while queue:
+ a = queue.popleft()
+
+ for b in queue:
+ result = _simplify_binary(expression, a, b)
+
+ if result:
+ queue.remove(b)
+ queue.append(result)
+ break
+ else:
+ operands.append(a)
+
+ if len(operands) < size:
+ return functools.reduce(
+ lambda a, b: expression.__class__(this=a, expression=b), operands
+ )
+ elif isinstance(expression, exp.Neg):
+ this = expression.this
+ if this.is_number:
+ value = this.name
+ if value[0] == "-":
+ return exp.Literal.number(value[1:])
+ return exp.Literal.number(f"-{value}")
+
+ return expression
+
+
+def _simplify_binary(expression, a, b):
+ if isinstance(expression, exp.Is):
+ if isinstance(b, exp.Not):
+ c = b.this
+ not_ = True
+ else:
+ c = b
+ not_ = False
+
+ if c == NULL:
+ if isinstance(a, exp.Literal):
+ return TRUE if not_ else FALSE
+ if a == NULL:
+ return FALSE if not_ else TRUE
+ elif NULL in (a, b):
+ return NULL
+
+ if isinstance(expression, exp.EQ) and a == b:
+ return TRUE
+
+ if a.is_number and b.is_number:
+ a = int(a.name) if a.is_int else Decimal(a.name)
+ b = int(b.name) if b.is_int else Decimal(b.name)
+
+ if isinstance(expression, exp.Add):
+ return exp.Literal.number(a + b)
+ if isinstance(expression, exp.Sub):
+ return exp.Literal.number(a - b)
+ if isinstance(expression, exp.Mul):
+ return exp.Literal.number(a * b)
+ if isinstance(expression, exp.Div):
+ if isinstance(a, int) and isinstance(b, int):
+ return exp.Literal.number(a // b)
+ return exp.Literal.number(a / b)
+
+ boolean = eval_boolean(expression, a, b)
+
+ if boolean:
+ return boolean
+ elif a.is_string and b.is_string:
+ boolean = eval_boolean(expression, a, b)
+
+ if boolean:
+ return boolean
+ elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval):
+ a, b = extract_date(a), extract_interval(b)
+ if b:
+ if isinstance(expression, exp.Add):
+ return date_literal(a + b)
+ if isinstance(expression, exp.Sub):
+ return date_literal(a - b)
+ elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast):
+ a, b = extract_interval(a), extract_date(b)
+ # you cannot subtract a date from an interval
+ if a and isinstance(expression, exp.Add):
+ return date_literal(a + b)
+
+ return None
+
+
+def simplify_parens(expression):
+ if (
+ isinstance(expression, exp.Paren)
+ and not isinstance(expression.this, exp.Select)
+ and (
+ not isinstance(expression.parent, (exp.Condition, exp.Binary))
+ or isinstance(expression.this, (exp.Is, exp.Like))
+ or not isinstance(expression.this, exp.Binary)
+ )
+ ):
+ return expression.this
+ return expression
+
+
+def remove_where_true(expression):
+ for where in expression.find_all(exp.Where):
+ if always_true(where.this):
+ where.parent.set("where", None)
+ for join in expression.find_all(exp.Join):
+ if always_true(join.args.get("on")):
+ join.set("kind", "CROSS")
+ join.set("on", None)
+
+
+def always_true(expression):
+ return expression == TRUE or isinstance(expression, exp.Literal)
+
+
+def is_complement(a, b):
+ return isinstance(b, exp.Not) and b.this == a
+
+
+def eval_boolean(expression, a, b):
+ if isinstance(expression, (exp.EQ, exp.Is)):
+ return boolean_literal(a == b)
+ if isinstance(expression, exp.NEQ):
+ return boolean_literal(a != b)
+ if isinstance(expression, exp.GT):
+ return boolean_literal(a > b)
+ if isinstance(expression, exp.GTE):
+ return boolean_literal(a >= b)
+ if isinstance(expression, exp.LT):
+ return boolean_literal(a < b)
+ if isinstance(expression, exp.LTE):
+ return boolean_literal(a <= b)
+ return None
+
+
+def extract_date(cast):
+ if cast.args["to"].this == exp.DataType.Type.DATE:
+ return datetime.date.fromisoformat(cast.name)
+ return None
+
+
+def extract_interval(interval):
+ try:
+ from dateutil.relativedelta import relativedelta
+ except ModuleNotFoundError:
+ return None
+
+ n = int(interval.name)
+ unit = interval.text("unit").lower()
+
+ if unit == "year":
+ return relativedelta(years=n)
+ if unit == "month":
+ return relativedelta(months=n)
+ if unit == "week":
+ return relativedelta(weeks=n)
+ if unit == "day":
+ return relativedelta(days=n)
+ return None
+
+
+def date_literal(date):
+ return exp.Cast(this=exp.Literal.string(date), to=exp.DataType.build("DATE"))
+
+
+def boolean_literal(condition):
+ return TRUE if condition else FALSE
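
A sketch combining several of the rules above: simplify_literals folds the
arithmetic, eval_boolean decides the comparison, remove_compliments collapses
x OR NOT x, and simplify_connectors finishes off the AND:

    import sqlglot
    from sqlglot.optimizer.simplify import simplify

    print(simplify(sqlglot.parse_one("1 + 1 = 2 AND (x OR NOT x)")).sql())
    # expected, per the rules above: TRUE
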
diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py
new file mode 100644
index 0000000..55c81c5
--- /dev/null
+++ b/sqlglot/optimizer/unnest_subqueries.py
@@ -0,0 +1,220 @@
+import itertools
+
+from sqlglot import exp
+from sqlglot.optimizer.scope import traverse_scope
+
+
+def unnest_subqueries(expression):
+ """
+ Rewrite sqlglot AST to convert some predicates with subqueries into joins.
+
+    Convert the subquery into a GROUP BY so it is not a many-to-many LEFT JOIN.
+    Unnesting can only occur if the subquery does not have a LIMIT or OFFSET.
+    Unnesting non-correlated subqueries only happens on IN statements or = ANY statements.
+
+ Example:
+ >>> import sqlglot
+ >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ")
+ >>> unnest_subqueries(expression).sql()
+ 'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a)\
+ AS "_u_0" ON x.a = "_u_0".a WHERE ("_u_0".a = 1 AND NOT "_u_0".a IS NULL)'
+
+ Args:
+ expression (sqlglot.Expression): expression to unnest
+ Returns:
+ sqlglot.Expression: unnested expression
+ """
+ sequence = itertools.count()
+
+ for scope in traverse_scope(expression):
+ select = scope.expression
+ parent = select.parent_select
+ if scope.external_columns:
+ decorrelate(select, parent, scope.external_columns, sequence)
+ else:
+ unnest(select, parent, sequence)
+
+ return expression
+
+
+def unnest(select, parent_select, sequence):
+ predicate = select.find_ancestor(exp.In, exp.Any)
+
+ if not predicate or parent_select is not predicate.parent_select:
+ return
+
+ if len(select.selects) > 1 or select.find(exp.Limit, exp.Offset):
+ return
+
+ if isinstance(predicate, exp.Any):
+ predicate = predicate.find_ancestor(exp.EQ)
+
+ if not predicate or parent_select is not predicate.parent_select:
+ return
+
+ column = _other_operand(predicate)
+ value = select.selects[0]
+ alias = _alias(sequence)
+
+ on = exp.condition(f'{column} = "{alias}"."{value.alias}"')
+ _replace(predicate, f"NOT {on.right} IS NULL")
+
+ parent_select.join(
+ select.group_by(value.this, copy=False),
+ on=on,
+ join_type="LEFT",
+ join_alias=alias,
+ copy=False,
+ )
+
+
+def decorrelate(select, parent_select, external_columns, sequence):
+ where = select.args.get("where")
+
+ if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset):
+ return
+
+ table_alias = _alias(sequence)
+ keys = []
+
+ # for all external columns in the where statement,
+ # split out the relevant data to convert it into a join
+ for column in external_columns:
+ if column.find_ancestor(exp.Where) is not where:
+ return
+
+ predicate = column.find_ancestor(exp.Predicate)
+
+ if not predicate or predicate.find_ancestor(exp.Where) is not where:
+ return
+
+ if isinstance(predicate, exp.Binary):
+ key = (
+ predicate.right
+ if any(node is column for node, *_ in predicate.left.walk())
+ else predicate.left
+ )
+ else:
+ return
+
+ keys.append((key, column, predicate))
+
+ if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys):
+ return
+
+ value = select.selects[0]
+ key_aliases = {}
+ group_by = []
+
+ for key, _, predicate in keys:
+ # if we filter on the value of the subquery, it needs to be unique
+ if key == value.this:
+ key_aliases[key] = value.alias
+ group_by.append(key)
+ else:
+ if key not in key_aliases:
+ key_aliases[key] = _alias(sequence)
+            # all keys from equality predicates must also be in the group by
+            # so that we don't do a many-to-many join
+ if isinstance(predicate, exp.EQ) and key not in group_by:
+ group_by.append(key)
+
+ parent_predicate = select.find_ancestor(exp.Predicate)
+
+ # if the value of the subquery is not an agg or a key, we need to collect it into an array
+ # so that it can be grouped
+ if not value.find(exp.AggFunc) and value.this not in group_by:
+ select.select(
+ f"ARRAY_AGG({value.this}) AS {value.alias}", append=False, copy=False
+ )
+
+    # EXISTS queries should not have any selects, as they only check whether any rows exist;
+    # all selects will be added by the optimizer and used only as join keys
+ if isinstance(parent_predicate, exp.Exists):
+ select.args["expressions"] = []
+
+ for key, alias in key_aliases.items():
+ if key in group_by:
+ # add all keys to the projections of the subquery
+            # so that we can use them as join keys
+ if isinstance(parent_predicate, exp.Exists) or key != value.this:
+ select.select(f"{key} AS {alias}", copy=False)
+ else:
+ select.select(f"ARRAY_AGG({key}) AS {alias}", copy=False)
+
+ alias = exp.column(value.alias, table_alias)
+ other = _other_operand(parent_predicate)
+
+ if isinstance(parent_predicate, exp.Exists):
+ if value.this in group_by:
+ parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
+ else:
+ parent_predicate = _replace(parent_predicate, "TRUE")
+ elif isinstance(parent_predicate, exp.All):
+ parent_predicate = _replace(
+ parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
+ )
+ elif isinstance(parent_predicate, exp.Any):
+ if value.this in group_by:
+ parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
+ else:
+ parent_predicate = _replace(
+ parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})"
+ )
+ elif isinstance(parent_predicate, exp.In):
+ if value.this in group_by:
+ parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
+ else:
+ parent_predicate = _replace(
+ parent_predicate,
+ f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
+ )
+ else:
+ select.parent.replace(alias)
+
+ for key, column, predicate in keys:
+ predicate.replace(exp.TRUE)
+ nested = exp.column(key_aliases[key], table_alias)
+
+ if key in group_by:
+ key.replace(nested)
+ parent_predicate = _replace(
+ parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)"
+ )
+ elif isinstance(predicate, exp.EQ):
+ parent_predicate = _replace(
+ parent_predicate,
+ f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))",
+ )
+ else:
+ key.replace(exp.to_identifier("_x"))
+ parent_predicate = _replace(
+ parent_predicate,
+ f'({parent_predicate} AND ARRAY_ANY({nested}, "_x" -> {predicate}))',
+ )
+
+ parent_select.join(
+ select.group_by(*group_by, copy=False),
+ on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)],
+ join_type="LEFT",
+ join_alias=table_alias,
+ copy=False,
+ )
+
+
+def _alias(sequence):
+ return f"_u_{next(sequence)}"
+
+
+def _replace(expression, condition):
+ return expression.replace(exp.condition(condition))
+
+
+def _other_operand(expression):
+ if isinstance(expression, exp.In):
+ return expression.this
+
+ if isinstance(expression, exp.Binary):
+ return expression.right if expression.arg_key == "this" else expression.left
+
+ return None
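
The module doctest above exercises decorrelate; here is a sketch of the other
path, unnest, where a non-correlated IN subquery becomes a grouped LEFT JOIN
(expected output reconstructed from the logic above; exact quoting may differ):

    import sqlglot
    from sqlglot.optimizer.unnest_subqueries import unnest_subqueries

    sql = "SELECT * FROM x AS x WHERE x.a IN (SELECT y.a AS a FROM y AS y)"
    print(unnest_subqueries(sqlglot.parse_one(sql)).sql())
    # roughly: SELECT * FROM x AS x
    #   LEFT JOIN (SELECT y.a AS a FROM y AS y GROUP BY y.a) AS "_u_0"
    #   ON x.a = "_u_0"."a" WHERE NOT "_u_0"."a" IS NULL
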
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
new file mode 100644
index 0000000..9396c50
--- /dev/null
+++ b/sqlglot/parser.py
@@ -0,0 +1,2190 @@
+import logging
+
+from sqlglot import exp
+from sqlglot.errors import ErrorLevel, ParseError, concat_errors
+from sqlglot.helper import apply_index_offset, ensure_list, list_get
+from sqlglot.tokens import Token, Tokenizer, TokenType
+
+logger = logging.getLogger("sqlglot")
+
+
+class Parser:
+ """
+ Parser consumes a list of tokens produced by the :class:`~sqlglot.tokens.Tokenizer`
+ and produces a parsed syntax tree.
+
+    Args:
+ error_level (ErrorLevel): the desired error level. Default: ErrorLevel.RAISE.
+ error_message_context (int): determines the amount of context to capture from
+ a query string when displaying the error message (in number of characters).
+ Default: 50.
+        index_offset (int): Index offset for arrays, e.g. ARRAY[0] vs ARRAY[1] as the head of a list.
+            Default: 0
+        alias_post_tablesample (bool): Whether the table alias comes after TABLESAMPLE.
+ Default: False
+ max_errors (int): Maximum number of error messages to include in a raised ParseError.
+ This is only relevant if error_level is ErrorLevel.RAISE.
+ Default: 3
+ null_ordering (str): Indicates the default null ordering method to use if not explicitly set.
+ Options are "nulls_are_small", "nulls_are_large", "nulls_are_last".
+ Default: "nulls_are_small"
+ """
+
+ FUNCTIONS = {
+ **{name: f.from_arg_list for f in exp.ALL_FUNCTIONS for name in f.sql_names()},
+ "DATE_TO_DATE_STR": lambda args: exp.Cast(
+ this=list_get(args, 0),
+ to=exp.DataType(this=exp.DataType.Type.TEXT),
+ ),
+ "TIME_TO_TIME_STR": lambda args: exp.Cast(
+ this=list_get(args, 0),
+ to=exp.DataType(this=exp.DataType.Type.TEXT),
+ ),
+ "TS_OR_DS_TO_DATE_STR": lambda args: exp.Substring(
+ this=exp.Cast(
+ this=list_get(args, 0),
+ to=exp.DataType(this=exp.DataType.Type.TEXT),
+ ),
+ start=exp.Literal.number(1),
+ length=exp.Literal.number(10),
+ ),
+ }
+
+ NO_PAREN_FUNCTIONS = {
+ TokenType.CURRENT_DATE: exp.CurrentDate,
+ TokenType.CURRENT_DATETIME: exp.CurrentDate,
+ TokenType.CURRENT_TIMESTAMP: exp.CurrentTimestamp,
+ }
+
+ NESTED_TYPE_TOKENS = {
+ TokenType.ARRAY,
+ TokenType.MAP,
+ TokenType.STRUCT,
+ TokenType.NULLABLE,
+ }
+
+ TYPE_TOKENS = {
+ TokenType.BOOLEAN,
+ TokenType.TINYINT,
+ TokenType.SMALLINT,
+ TokenType.INT,
+ TokenType.BIGINT,
+ TokenType.FLOAT,
+ TokenType.DOUBLE,
+ TokenType.CHAR,
+ TokenType.NCHAR,
+ TokenType.VARCHAR,
+ TokenType.NVARCHAR,
+ TokenType.TEXT,
+ TokenType.BINARY,
+ TokenType.JSON,
+ TokenType.TIMESTAMP,
+ TokenType.TIMESTAMPTZ,
+ TokenType.DATETIME,
+ TokenType.DATE,
+ TokenType.DECIMAL,
+ TokenType.UUID,
+ TokenType.GEOGRAPHY,
+ *NESTED_TYPE_TOKENS,
+ }
+
+ SUBQUERY_PREDICATES = {
+ TokenType.ANY: exp.Any,
+ TokenType.ALL: exp.All,
+ TokenType.EXISTS: exp.Exists,
+ TokenType.SOME: exp.Any,
+ }
+
+ RESERVED_KEYWORDS = {*Tokenizer.SINGLE_TOKENS.values(), TokenType.SELECT}
+
+ ID_VAR_TOKENS = {
+ TokenType.VAR,
+ TokenType.ALTER,
+ TokenType.BEGIN,
+ TokenType.BUCKET,
+ TokenType.CACHE,
+ TokenType.COLLATE,
+ TokenType.COMMIT,
+ TokenType.CONSTRAINT,
+ TokenType.CONVERT,
+ TokenType.DEFAULT,
+ TokenType.DELETE,
+ TokenType.ENGINE,
+ TokenType.ESCAPE,
+ TokenType.EXPLAIN,
+ TokenType.FALSE,
+ TokenType.FIRST,
+ TokenType.FOLLOWING,
+ TokenType.FORMAT,
+ TokenType.FUNCTION,
+ TokenType.IF,
+ TokenType.INDEX,
+ TokenType.ISNULL,
+ TokenType.INTERVAL,
+ TokenType.LAZY,
+ TokenType.LOCATION,
+ TokenType.NEXT,
+ TokenType.ONLY,
+ TokenType.OPTIMIZE,
+ TokenType.OPTIONS,
+ TokenType.ORDINALITY,
+ TokenType.PERCENT,
+ TokenType.PRECEDING,
+ TokenType.RANGE,
+ TokenType.REFERENCES,
+ TokenType.ROWS,
+ TokenType.SCHEMA_COMMENT,
+ TokenType.SET,
+ TokenType.SHOW,
+ TokenType.STORED,
+ TokenType.TABLE,
+ TokenType.TABLE_FORMAT,
+ TokenType.TEMPORARY,
+ TokenType.TOP,
+ TokenType.TRUNCATE,
+ TokenType.TRUE,
+ TokenType.UNBOUNDED,
+ TokenType.UNIQUE,
+ TokenType.PROPERTIES,
+ *SUBQUERY_PREDICATES,
+ *TYPE_TOKENS,
+ }
+
+ CASTS = {
+ TokenType.CAST,
+ TokenType.TRY_CAST,
+ }
+
+ FUNC_TOKENS = {
+ TokenType.CONVERT,
+ TokenType.CURRENT_DATE,
+ TokenType.CURRENT_DATETIME,
+ TokenType.CURRENT_TIMESTAMP,
+ TokenType.CURRENT_TIME,
+ TokenType.EXTRACT,
+ TokenType.FILTER,
+ TokenType.FIRST,
+ TokenType.FORMAT,
+ TokenType.ISNULL,
+ TokenType.OFFSET,
+ TokenType.PRIMARY_KEY,
+ TokenType.REPLACE,
+ TokenType.ROW,
+ TokenType.UNNEST,
+ TokenType.VAR,
+ TokenType.LEFT,
+ TokenType.RIGHT,
+ TokenType.DATE,
+ TokenType.DATETIME,
+ TokenType.TIMESTAMP,
+ TokenType.TIMESTAMPTZ,
+ *CASTS,
+ *NESTED_TYPE_TOKENS,
+ *SUBQUERY_PREDICATES,
+ }
+
+ CONJUNCTION = {
+ TokenType.AND: exp.And,
+ TokenType.OR: exp.Or,
+ }
+
+ EQUALITY = {
+ TokenType.EQ: exp.EQ,
+ TokenType.NEQ: exp.NEQ,
+ }
+
+ COMPARISON = {
+ TokenType.GT: exp.GT,
+ TokenType.GTE: exp.GTE,
+ TokenType.LT: exp.LT,
+ TokenType.LTE: exp.LTE,
+ }
+
+ BITWISE = {
+ TokenType.AMP: exp.BitwiseAnd,
+ TokenType.CARET: exp.BitwiseXor,
+ TokenType.PIPE: exp.BitwiseOr,
+ TokenType.DPIPE: exp.DPipe,
+ }
+
+ TERM = {
+ TokenType.DASH: exp.Sub,
+ TokenType.PLUS: exp.Add,
+ TokenType.MOD: exp.Mod,
+ }
+
+ FACTOR = {
+ TokenType.DIV: exp.IntDiv,
+ TokenType.SLASH: exp.Div,
+ TokenType.STAR: exp.Mul,
+ }
+
+ TIMESTAMPS = {
+ TokenType.TIMESTAMP,
+ TokenType.TIMESTAMPTZ,
+ }
+
+ SET_OPERATIONS = {
+ TokenType.UNION,
+ TokenType.INTERSECT,
+ TokenType.EXCEPT,
+ }
+
+ JOIN_SIDES = {
+ TokenType.LEFT,
+ TokenType.RIGHT,
+ TokenType.FULL,
+ }
+
+ JOIN_KINDS = {
+ TokenType.INNER,
+ TokenType.OUTER,
+ TokenType.CROSS,
+ }
+
+ COLUMN_OPERATORS = {
+ TokenType.DOT: None,
+ TokenType.ARROW: lambda self, this, path: self.expression(
+ exp.JSONExtract,
+ this=this,
+ path=path,
+ ),
+ TokenType.DARROW: lambda self, this, path: self.expression(
+ exp.JSONExtractScalar,
+ this=this,
+ path=path,
+ ),
+ TokenType.HASH_ARROW: lambda self, this, path: self.expression(
+ exp.JSONBExtract,
+ this=this,
+ path=path,
+ ),
+ TokenType.DHASH_ARROW: lambda self, this, path: self.expression(
+ exp.JSONBExtractScalar,
+ this=this,
+ path=path,
+ ),
+ }
+
+ EXPRESSION_PARSERS = {
+ exp.DataType: lambda self: self._parse_types(),
+ exp.From: lambda self: self._parse_from(),
+ exp.Group: lambda self: self._parse_group(),
+ exp.Lateral: lambda self: self._parse_lateral(),
+ exp.Join: lambda self: self._parse_join(),
+ exp.Order: lambda self: self._parse_order(),
+ exp.Cluster: lambda self: self._parse_sort(TokenType.CLUSTER_BY, exp.Cluster),
+ exp.Sort: lambda self: self._parse_sort(TokenType.SORT_BY, exp.Sort),
+ exp.Lambda: lambda self: self._parse_lambda(),
+ exp.Limit: lambda self: self._parse_limit(),
+ exp.Offset: lambda self: self._parse_offset(),
+ exp.TableAlias: lambda self: self._parse_table_alias(),
+ exp.Table: lambda self: self._parse_table(),
+ exp.Condition: lambda self: self._parse_conjunction(),
+ exp.Expression: lambda self: self._parse_statement(),
+ exp.Properties: lambda self: self._parse_properties(),
+ "JOIN_TYPE": lambda self: self._parse_join_side_and_kind(),
+ }
+
+ STATEMENT_PARSERS = {
+ TokenType.CREATE: lambda self: self._parse_create(),
+ TokenType.DROP: lambda self: self._parse_drop(),
+ TokenType.INSERT: lambda self: self._parse_insert(),
+ TokenType.UPDATE: lambda self: self._parse_update(),
+ TokenType.DELETE: lambda self: self._parse_delete(),
+ TokenType.CACHE: lambda self: self._parse_cache(),
+ TokenType.UNCACHE: lambda self: self._parse_uncache(),
+ }
+
+ PRIMARY_PARSERS = {
+ TokenType.STRING: lambda _, token: exp.Literal.string(token.text),
+ TokenType.NUMBER: lambda _, token: exp.Literal.number(token.text),
+ TokenType.STAR: lambda self, _: exp.Star(
+ **{"except": self._parse_except(), "replace": self._parse_replace()}
+ ),
+ TokenType.NULL: lambda *_: exp.Null(),
+ TokenType.TRUE: lambda *_: exp.Boolean(this=True),
+ TokenType.FALSE: lambda *_: exp.Boolean(this=False),
+ TokenType.PLACEHOLDER: lambda *_: exp.Placeholder(),
+ TokenType.BIT_STRING: lambda _, token: exp.BitString(this=token.text),
+ TokenType.INTRODUCER: lambda self, token: self.expression(
+ exp.Introducer,
+ this=token.text,
+ expression=self._parse_var_or_string(),
+ ),
+ }
+
+ RANGE_PARSERS = {
+ TokenType.BETWEEN: lambda self, this: self._parse_between(this),
+ TokenType.IN: lambda self, this: self._parse_in(this),
+ TokenType.IS: lambda self, this: self._parse_is(this),
+ TokenType.LIKE: lambda self, this: self._parse_escape(
+ self.expression(exp.Like, this=this, expression=self._parse_type())
+ ),
+ TokenType.ILIKE: lambda self, this: self._parse_escape(
+ self.expression(exp.ILike, this=this, expression=self._parse_type())
+ ),
+ TokenType.RLIKE: lambda self, this: self.expression(
+ exp.RegexpLike, this=this, expression=self._parse_type()
+ ),
+ }
+
+ PROPERTY_PARSERS = {
+ TokenType.AUTO_INCREMENT: lambda self: self._parse_auto_increment(),
+ TokenType.CHARACTER_SET: lambda self: self._parse_character_set(),
+ TokenType.COLLATE: lambda self: self._parse_collate(),
+ TokenType.ENGINE: lambda self: self._parse_engine(),
+ TokenType.FORMAT: lambda self: self._parse_format(),
+ TokenType.LOCATION: lambda self: self.expression(
+ exp.LocationProperty,
+ this=exp.Literal.string("LOCATION"),
+ value=self._parse_string(),
+ ),
+ TokenType.PARTITIONED_BY: lambda self: self.expression(
+ exp.PartitionedByProperty,
+ this=exp.Literal.string("PARTITIONED_BY"),
+ value=self._parse_schema(),
+ ),
+ TokenType.SCHEMA_COMMENT: lambda self: self._parse_schema_comment(),
+ TokenType.STORED: lambda self: self._parse_stored(),
+ TokenType.TABLE_FORMAT: lambda self: self._parse_table_format(),
+ TokenType.USING: lambda self: self._parse_table_format(),
+ }
+
+ CONSTRAINT_PARSERS = {
+ TokenType.CHECK: lambda self: self._parse_check(),
+ TokenType.FOREIGN_KEY: lambda self: self._parse_foreign_key(),
+ TokenType.UNIQUE: lambda self: self._parse_unique(),
+ }
+
+ NO_PAREN_FUNCTION_PARSERS = {
+ TokenType.CASE: lambda self: self._parse_case(),
+ TokenType.IF: lambda self: self._parse_if(),
+ }
+
+ FUNCTION_PARSERS = {
+ TokenType.CONVERT: lambda self, _: self._parse_convert(),
+ TokenType.EXTRACT: lambda self, _: self._parse_extract(),
+ **{
+ token_type: lambda self, token_type: self._parse_cast(
+ self.STRICT_CAST and token_type == TokenType.CAST
+ )
+ for token_type in CASTS
+ },
+ }
+
+ QUERY_MODIFIER_PARSERS = {
+ "laterals": lambda self: self._parse_laterals(),
+ "joins": lambda self: self._parse_joins(),
+ "where": lambda self: self._parse_where(),
+ "group": lambda self: self._parse_group(),
+ "having": lambda self: self._parse_having(),
+ "qualify": lambda self: self._parse_qualify(),
+ "window": lambda self: self._match(TokenType.WINDOW)
+ and self._parse_window(self._parse_id_var(), alias=True),
+ "distribute": lambda self: self._parse_sort(
+ TokenType.DISTRIBUTE_BY, exp.Distribute
+ ),
+ "sort": lambda self: self._parse_sort(TokenType.SORT_BY, exp.Sort),
+ "cluster": lambda self: self._parse_sort(TokenType.CLUSTER_BY, exp.Cluster),
+ "order": lambda self: self._parse_order(),
+ "limit": lambda self: self._parse_limit(),
+ "offset": lambda self: self._parse_offset(),
+ }
+
+ CREATABLES = {TokenType.TABLE, TokenType.VIEW, TokenType.FUNCTION, TokenType.INDEX}
+
+ STRICT_CAST = True
+
+ __slots__ = (
+ "error_level",
+ "error_message_context",
+ "sql",
+ "errors",
+ "index_offset",
+ "unnest_column_only",
+ "alias_post_tablesample",
+ "max_errors",
+ "null_ordering",
+ "_tokens",
+ "_chunks",
+ "_index",
+ "_curr",
+ "_next",
+ "_prev",
+ "_greedy_subqueries",
+ )
+
+ def __init__(
+ self,
+ error_level=None,
+ error_message_context=100,
+ index_offset=0,
+ unnest_column_only=False,
+ alias_post_tablesample=False,
+ max_errors=3,
+ null_ordering=None,
+ ):
+ self.error_level = error_level or ErrorLevel.RAISE
+ self.error_message_context = error_message_context
+ self.index_offset = index_offset
+ self.unnest_column_only = unnest_column_only
+ self.alias_post_tablesample = alias_post_tablesample
+ self.max_errors = max_errors
+ self.null_ordering = null_ordering
+ self.reset()
+
+ def reset(self):
+ self.sql = ""
+ self.errors = []
+ self._tokens = []
+ self._chunks = [[]]
+ self._index = 0
+ self._curr = None
+ self._next = None
+ self._prev = None
+ self._greedy_subqueries = False
+
+ def parse(self, raw_tokens, sql=None):
+ """
+ Parses the given list of tokens and returns a list of syntax trees, one tree
+ per parsed SQL statement.
+
+        Args:
+ raw_tokens (list): the list of tokens (:class:`~sqlglot.tokens.Token`).
+ sql (str): the original SQL string. Used to produce helpful debug messages.
+
+        Returns:
+ the list of syntax trees (:class:`~sqlglot.expressions.Expression`).
+ """
+ return self._parse(
+ parse_method=self.__class__._parse_statement, raw_tokens=raw_tokens, sql=sql
+ )
+
+ def parse_into(self, expression_types, raw_tokens, sql=None):
+ for expression_type in ensure_list(expression_types):
+ parser = self.EXPRESSION_PARSERS.get(expression_type)
+ if not parser:
+ raise TypeError(f"No parser registered for {expression_type}")
+ try:
+ return self._parse(parser, raw_tokens, sql)
+ except ParseError as e:
+ error = e
+ raise ParseError(f"Failed to parse into {expression_types}") from error
+
+ def _parse(self, parse_method, raw_tokens, sql=None):
+ self.reset()
+ self.sql = sql or ""
+ total = len(raw_tokens)
+
+ for i, token in enumerate(raw_tokens):
+ if token.token_type == TokenType.SEMICOLON:
+ if i < total - 1:
+ self._chunks.append([])
+ else:
+ self._chunks[-1].append(token)
+
+ expressions = []
+
+ for tokens in self._chunks:
+ self._index = -1
+ self._tokens = tokens
+ self._advance()
+ expressions.append(parse_method(self))
+
+ if self._index < len(self._tokens):
+ self.raise_error("Invalid expression / Unexpected token")
+
+ self.check_errors()
+
+ return expressions
+
+ def check_errors(self):
+ if self.error_level == ErrorLevel.WARN:
+ for error in self.errors:
+ logger.error(str(error))
+ elif self.error_level == ErrorLevel.RAISE and self.errors:
+ raise ParseError(concat_errors(self.errors, self.max_errors))
+
+ def raise_error(self, message, token=None):
+ token = token or self._curr or self._prev or Token.string("")
+ start = self._find_token(token, self.sql)
+ end = start + len(token.text)
+ start_context = self.sql[max(start - self.error_message_context, 0) : start]
+ highlight = self.sql[start:end]
+ end_context = self.sql[end : end + self.error_message_context]
+ error = ParseError(
+ f"{message}. Line {token.line}, Col: {token.col}.\n"
+ f" {start_context}\033[4m{highlight}\033[0m{end_context}"
+ )
+ if self.error_level == ErrorLevel.IMMEDIATE:
+ raise error
+ self.errors.append(error)
+
+ def expression(self, exp_class, **kwargs):
+ instance = exp_class(**kwargs)
+ self.validate_expression(instance)
+ return instance
+
+ def validate_expression(self, expression, args=None):
+ if self.error_level == ErrorLevel.IGNORE:
+ return
+
+ for k in expression.args:
+ if k not in expression.arg_types:
+ self.raise_error(
+ f"Unexpected keyword: '{k}' for {expression.__class__}"
+ )
+ for k, mandatory in expression.arg_types.items():
+ v = expression.args.get(k)
+ if mandatory and (v is None or (isinstance(v, list) and not v)):
+ self.raise_error(
+ f"Required keyword: '{k}' missing for {expression.__class__}"
+ )
+
+ if (
+ args
+ and len(args) > len(expression.arg_types)
+ and not expression.is_var_len_args
+ ):
+ self.raise_error(
+ f"The number of provided arguments ({len(args)}) is greater than "
+ f"the maximum number of supported arguments ({len(expression.arg_types)})"
+ )
+
+ def _find_token(self, token, sql):
+ line = 1
+ col = 1
+ index = 0
+
+ while line < token.line or col < token.col:
+ if Tokenizer.WHITE_SPACE.get(sql[index]) == TokenType.BREAK:
+ line += 1
+ col = 1
+ else:
+ col += 1
+ index += 1
+
+ return index
+
+ def _get_token(self, index):
+ return list_get(self._tokens, index)
+
+ def _advance(self, times=1):
+ self._index += times
+ self._curr = self._get_token(self._index)
+ self._next = self._get_token(self._index + 1)
+ self._prev = self._get_token(self._index - 1) if self._index > 0 else None
+
+ def _retreat(self, index):
+ self._advance(index - self._index)
+
+ def _parse_statement(self):
+ if self._curr is None:
+ return None
+
+ if self._match_set(self.STATEMENT_PARSERS):
+ return self.STATEMENT_PARSERS[self._prev.token_type](self)
+
+ if self._match_set(Tokenizer.COMMANDS):
+ return self.expression(
+ exp.Command,
+ this=self._prev.text,
+ expression=self._parse_string(),
+ )
+
+ expression = self._parse_expression()
+ expression = (
+ self._parse_set_operations(expression)
+ if expression
+ else self._parse_select()
+ )
+ self._parse_query_modifiers(expression)
+ return expression
+
+ def _parse_drop(self):
+ if self._match(TokenType.TABLE):
+ kind = "TABLE"
+ elif self._match(TokenType.VIEW):
+ kind = "VIEW"
+ else:
+ self.raise_error("Expected TABLE or View")
+
+ return self.expression(
+ exp.Drop,
+ exists=self._parse_exists(),
+ this=self._parse_table(schema=True),
+ kind=kind,
+ )
+
+ def _parse_exists(self, not_=False):
+ return (
+ self._match(TokenType.IF)
+ and (not not_ or self._match(TokenType.NOT))
+ and self._match(TokenType.EXISTS)
+ )
+
+ def _parse_create(self):
+ replace = self._match(TokenType.OR) and self._match(TokenType.REPLACE)
+ temporary = self._match(TokenType.TEMPORARY)
+ unique = self._match(TokenType.UNIQUE)
+
+ create_token = self._match_set(self.CREATABLES) and self._prev
+
+ if not create_token:
+ self.raise_error("Expected TABLE, VIEW, INDEX, or FUNCTION")
+
+ exists = self._parse_exists(not_=True)
+ this = None
+ expression = None
+ properties = None
+
+ if create_token.token_type == TokenType.FUNCTION:
+ this = self._parse_var()
+ if self._match(TokenType.ALIAS):
+ expression = self._parse_string()
+ elif create_token.token_type == TokenType.INDEX:
+ this = self._parse_index()
+ elif create_token.token_type in (TokenType.TABLE, TokenType.VIEW):
+ this = self._parse_table(schema=True)
+ properties = self._parse_properties(
+ this if isinstance(this, exp.Schema) else None
+ )
+ if self._match(TokenType.ALIAS):
+ expression = self._parse_select()
+
+ return self.expression(
+ exp.Create,
+ this=this,
+ kind=create_token.text,
+ expression=expression,
+ exists=exists,
+ properties=properties,
+ temporary=temporary,
+ replace=replace,
+ unique=unique,
+ )
+
+ def _parse_property(self, schema):
+ if self._match_set(self.PROPERTY_PARSERS):
+ return self.PROPERTY_PARSERS[self._prev.token_type](self)
+ if self._match_pair(TokenType.DEFAULT, TokenType.CHARACTER_SET):
+ return self._parse_character_set(True)
+
+ if self._match_pair(TokenType.VAR, TokenType.EQ, advance=False):
+ key = self._parse_var().this
+ self._match(TokenType.EQ)
+
+ if key.upper() == "PARTITIONED_BY":
+ expression = exp.PartitionedByProperty
+ value = self._parse_schema() or self._parse_bracket(self._parse_field())
+
+ if schema and not isinstance(value, exp.Schema):
+ columns = {v.name.upper() for v in value.expressions}
+ partitions = [
+ expression
+ for expression in schema.expressions
+ if expression.this.name.upper() in columns
+ ]
+ schema.set(
+ "expressions",
+ [e for e in schema.expressions if e not in partitions],
+ )
+ value = self.expression(exp.Schema, expressions=partitions)
+ else:
+ value = self._parse_column()
+ expression = exp.AnonymousProperty
+
+ return self.expression(
+ expression,
+ this=exp.Literal.string(key),
+ value=value,
+ )
+ return None
+
+ def _parse_stored(self):
+ self._match(TokenType.ALIAS)
+ self._match(TokenType.EQ)
+ return self.expression(
+ exp.FileFormatProperty,
+ this=exp.Literal.string("FORMAT"),
+ value=exp.Literal.string(self._parse_var().name),
+ )
+
+ def _parse_format(self):
+ self._match(TokenType.EQ)
+ return self.expression(
+ exp.FileFormatProperty,
+ this=exp.Literal.string("FORMAT"),
+ value=self._parse_string() or self._parse_var(),
+ )
+
+ def _parse_engine(self):
+ self._match(TokenType.EQ)
+ return self.expression(
+ exp.EngineProperty,
+ this=exp.Literal.string("ENGINE"),
+ value=self._parse_var_or_string(),
+ )
+
+ def _parse_auto_increment(self):
+ self._match(TokenType.EQ)
+ return self.expression(
+ exp.AutoIncrementProperty,
+ this=exp.Literal.string("AUTO_INCREMENT"),
+ value=self._parse_var() or self._parse_number(),
+ )
+
+ def _parse_collate(self):
+ self._match(TokenType.EQ)
+ return self.expression(
+ exp.CollateProperty,
+ this=exp.Literal.string("COLLATE"),
+ value=self._parse_var_or_string(),
+ )
+
+ def _parse_schema_comment(self):
+ self._match(TokenType.EQ)
+ return self.expression(
+ exp.SchemaCommentProperty,
+ this=exp.Literal.string("COMMENT"),
+ value=self._parse_string(),
+ )
+
+ def _parse_character_set(self, default=False):
+ self._match(TokenType.EQ)
+ return self.expression(
+ exp.CharacterSetProperty,
+ this=exp.Literal.string("CHARACTER_SET"),
+ value=self._parse_var_or_string(),
+ default=default,
+ )
+
+ def _parse_table_format(self):
+ self._match(TokenType.EQ)
+ return self.expression(
+ exp.TableFormatProperty,
+ this=exp.Literal.string("TABLE_FORMAT"),
+ value=self._parse_var_or_string(),
+ )
+
+ def _parse_properties(self, schema=None):
+ """
+        The schema is included because, if the table schema is defined and we later get a
+        PARTITIONED BY expression, those columns are defined in the PARTITIONED BY section
+        and not with the rest of the columns.
+ """
+ properties = []
+
+ while True:
+ if self._match(TokenType.WITH):
+ self._match_l_paren()
+ properties.extend(self._parse_csv(lambda: self._parse_property(schema)))
+ self._match_r_paren()
+ elif self._match(TokenType.PROPERTIES):
+ self._match_l_paren()
+ properties.extend(
+ self._parse_csv(
+ lambda: self.expression(
+ exp.AnonymousProperty,
+ this=self._parse_string(),
+ value=self._match(TokenType.EQ) and self._parse_string(),
+ )
+ )
+ )
+ self._match_r_paren()
+ else:
+ identified_property = self._parse_property(schema)
+ if not identified_property:
+ break
+ properties.append(identified_property)
+ if properties:
+ return self.expression(exp.Properties, expressions=properties)
+ return None
+
+ def _parse_insert(self):
+ overwrite = self._match(TokenType.OVERWRITE)
+ self._match(TokenType.INTO)
+ self._match(TokenType.TABLE)
+ return self.expression(
+ exp.Insert,
+ this=self._parse_table(schema=True),
+ exists=self._parse_exists(),
+ partition=self._parse_partition(),
+ expression=self._parse_select(),
+ overwrite=overwrite,
+ )
+
+ def _parse_delete(self):
+ self._match(TokenType.FROM)
+
+ return self.expression(
+ exp.Delete,
+ this=self._parse_table(schema=True),
+ where=self._parse_where(),
+ )
+
+ def _parse_update(self):
+ return self.expression(
+ exp.Update,
+ **{
+ "this": self._parse_table(schema=True),
+ "expressions": self._match(TokenType.SET)
+ and self._parse_csv(self._parse_equality),
+ "from": self._parse_from(),
+ "where": self._parse_where(),
+ },
+ )
+
+ def _parse_uncache(self):
+ if not self._match(TokenType.TABLE):
+ self.raise_error("Expecting TABLE after UNCACHE")
+ return self.expression(
+ exp.Uncache,
+ exists=self._parse_exists(),
+ this=self._parse_table(schema=True),
+ )
+
+ def _parse_cache(self):
+ lazy = self._match(TokenType.LAZY)
+ self._match(TokenType.TABLE)
+ table = self._parse_table(schema=True)
+ options = []
+
+ if self._match(TokenType.OPTIONS):
+ self._match_l_paren()
+ k = self._parse_string()
+ self._match(TokenType.EQ)
+ v = self._parse_string()
+ options = [k, v]
+ self._match_r_paren()
+
+ self._match(TokenType.ALIAS)
+ return self.expression(
+ exp.Cache,
+ this=table,
+ lazy=lazy,
+ options=options,
+ expression=self._parse_select(),
+ )
+
+ def _parse_partition(self):
+ if not self._match(TokenType.PARTITION):
+ return None
+
+ def parse_values():
+ k = self._parse_var()
+ if self._match(TokenType.EQ):
+ v = self._parse_string()
+ return (k, v)
+ return (k, None)
+
+ self._match_l_paren()
+ values = self._parse_csv(parse_values)
+ self._match_r_paren()
+
+ return self.expression(
+ exp.Partition,
+ this=values,
+ )
+
+ def _parse_value(self):
+ self._match_l_paren()
+ expressions = self._parse_csv(self._parse_conjunction)
+ self._match_r_paren()
+ return self.expression(exp.Tuple, expressions=expressions)
+
+ def _parse_select(self, table=None):
+ index = self._index
+
+ if self._match(TokenType.SELECT):
+ hint = self._parse_hint()
+ all_ = self._match(TokenType.ALL)
+ distinct = self._match(TokenType.DISTINCT)
+
+ if distinct:
+ distinct = self.expression(
+ exp.Distinct,
+ on=self._parse_value() if self._match(TokenType.ON) else None,
+ )
+
+ if all_ and distinct:
+ self.raise_error("Cannot specify both ALL and DISTINCT after SELECT")
+
+ limit = self._parse_limit(top=True)
+ expressions = self._parse_csv(
+ lambda: self._parse_annotation(self._parse_expression())
+ )
+
+ this = self.expression(
+ exp.Select,
+ hint=hint,
+ distinct=distinct,
+ expressions=expressions,
+ limit=limit,
+ )
+ from_ = self._parse_from()
+ if from_:
+ this.set("from", from_)
+ self._parse_query_modifiers(this)
+ elif self._match(TokenType.WITH):
+ recursive = self._match(TokenType.RECURSIVE)
+
+ expressions = []
+
+ while True:
+ expressions.append(self._parse_cte())
+
+ if not self._match(TokenType.COMMA):
+ break
+
+ cte = self.expression(
+ exp.With,
+ expressions=expressions,
+ recursive=recursive,
+ )
+ this = self._parse_statement()
+
+ if not this:
+ self.raise_error("Failed to parse any statement following CTE")
+ return cte
+
+ if "with" in this.arg_types:
+ this.set(
+ "with",
+ self.expression(
+ exp.With,
+ expressions=expressions,
+ recursive=recursive,
+ ),
+ )
+ else:
+ self.raise_error(f"{this.key} does not support CTE")
+ elif self._match(TokenType.L_PAREN):
+ this = self._parse_table() if table else self._parse_select()
+
+ if this:
+ self._parse_query_modifiers(this)
+ self._match_r_paren()
+ this = self._parse_subquery(this)
+ else:
+ self._retreat(index)
+ elif self._match(TokenType.VALUES):
+ this = self.expression(
+ exp.Values, expressions=self._parse_csv(self._parse_value)
+ )
+ alias = self._parse_table_alias()
+ if alias:
+ this = self.expression(exp.Subquery, this=this, alias=alias)
+ else:
+ this = None
+
+ return self._parse_set_operations(this) if this else None
+
+ def _parse_cte(self):
+ alias = self._parse_table_alias()
+ if not alias or not alias.this:
+ self.raise_error("Expected CTE to have alias")
+
+ if not self._match(TokenType.ALIAS):
+ self.raise_error("Expected AS in CTE")
+
+ self._match_l_paren()
+ expression = self._parse_statement()
+ self._match_r_paren()
+
+ return self.expression(
+ exp.CTE,
+ this=expression,
+ alias=alias,
+ )
+
+ def _parse_table_alias(self):
+ any_token = self._match(TokenType.ALIAS)
+ alias = self._parse_id_var(any_token)
+ columns = None
+
+ if self._match(TokenType.L_PAREN):
+ columns = self._parse_csv(lambda: self._parse_id_var(any_token))
+ self._match_r_paren()
+
+ if not alias and not columns:
+ return None
+
+ return self.expression(
+ exp.TableAlias,
+ this=alias,
+ columns=columns,
+ )
+
+ def _parse_subquery(self, this):
+ return self.expression(exp.Subquery, this=this, alias=self._parse_table_alias())
+
+ def _parse_query_modifiers(self, this):
+ if not isinstance(this, (exp.Subquery, exp.Subqueryable)):
+ return
+
+ for key, parser in self.QUERY_MODIFIER_PARSERS.items():
+ expression = parser(self)
+
+ if expression:
+ this.set(key, expression)
+
+ def _parse_annotation(self, expression):
+ if self._match(TokenType.ANNOTATION):
+ return self.expression(
+ exp.Annotation, this=self._prev.text, expression=expression
+ )
+
+ return expression
+
+ def _parse_hint(self):
+ if self._match(TokenType.HINT):
+ hints = self._parse_csv(self._parse_function)
+ if not self._match(TokenType.HINT):
+ self.raise_error("Expected */ after HINT")
+ return self.expression(exp.Hint, expressions=hints)
+ return None
+
+ def _parse_from(self):
+ if not self._match(TokenType.FROM):
+ return None
+
+ return self.expression(exp.From, expressions=self._parse_csv(self._parse_table))
+
+ def _parse_laterals(self):
+ return self._parse_all(self._parse_lateral)
+
+ def _parse_lateral(self):
+ if not self._match(TokenType.LATERAL):
+ return None
+
+ if not self._match(TokenType.VIEW):
+ self.raise_error("Expected VIEW after LATERAL")
+
+ outer = self._match(TokenType.OUTER)
+
+ return self.expression(
+ exp.Lateral,
+ this=self._parse_function(),
+ outer=outer,
+ alias=self.expression(
+ exp.TableAlias,
+ this=self._parse_id_var(any_token=False),
+ columns=(
+ self._parse_csv(self._parse_id_var)
+ if self._match(TokenType.ALIAS)
+ else None
+ ),
+ ),
+ )
+
+ def _parse_joins(self):
+ return self._parse_all(self._parse_join)
+
+ def _parse_join_side_and_kind(self):
+ return (
+ self._match_set(self.JOIN_SIDES) and self._prev,
+ self._match_set(self.JOIN_KINDS) and self._prev,
+ )
+
+ def _parse_join(self):
+ side, kind = self._parse_join_side_and_kind()
+
+ if not self._match(TokenType.JOIN):
+ return None
+
+ kwargs = {"this": self._parse_table()}
+
+ if side:
+ kwargs["side"] = side.text
+ if kind:
+ kwargs["kind"] = kind.text
+
+ if self._match(TokenType.ON):
+ kwargs["on"] = self._parse_conjunction()
+ elif self._match(TokenType.USING):
+ kwargs["using"] = self._parse_wrapped_id_vars()
+
+ return self.expression(exp.Join, **kwargs)
+
+ def _parse_index(self):
+ index = self._parse_id_var()
+ self._match(TokenType.ON)
+ self._match(TokenType.TABLE) # hive
+ return self.expression(
+ exp.Index,
+ this=index,
+ table=self.expression(exp.Table, this=self._parse_id_var()),
+ columns=self._parse_expression(),
+ )
+
+ def _parse_table(self, schema=False):
+ unnest = self._parse_unnest()
+
+ if unnest:
+ return unnest
+
+ subquery = self._parse_select(table=True)
+
+ if subquery:
+ return subquery
+
+ catalog = None
+ db = None
+ table = (not schema and self._parse_function()) or self._parse_id_var(False)
+
+ while self._match(TokenType.DOT):
+ catalog = db
+ db = table
+ table = self._parse_id_var()
+
+ if not table:
+ self.raise_error("Expected table name")
+
+ this = self.expression(exp.Table, this=table, db=db, catalog=catalog)
+
+ if schema:
+ return self._parse_schema(this=this)
+
+ if self.alias_post_tablesample:
+ table_sample = self._parse_table_sample()
+
+ alias = self._parse_table_alias()
+
+ if alias:
+ this = self.expression(exp.Alias, this=this, alias=alias)
+
+ if not self.alias_post_tablesample:
+ table_sample = self._parse_table_sample()
+
+ if table_sample:
+ table_sample.set("this", this)
+ this = table_sample
+
+ return this
+
+ def _parse_unnest(self):
+ if not self._match(TokenType.UNNEST):
+ return None
+
+ self._match_l_paren()
+ expressions = self._parse_csv(self._parse_column)
+ self._match_r_paren()
+
+ ordinality = bool(
+ self._match(TokenType.WITH) and self._match(TokenType.ORDINALITY)
+ )
+
+ alias = self._parse_table_alias()
+
+ if alias and self.unnest_column_only:
+ if alias.args.get("columns"):
+ self.raise_error("Unexpected extra column alias in unnest.")
+ alias.set("columns", [alias.this])
+ alias.set("this", None)
+
+ return self.expression(
+ exp.Unnest,
+ expressions=expressions,
+ ordinality=ordinality,
+ alias=alias,
+ )
+
+ def _parse_table_sample(self):
+ if not self._match(TokenType.TABLE_SAMPLE):
+ return None
+
+ method = self._parse_var()
+ bucket_numerator = None
+ bucket_denominator = None
+ bucket_field = None
+ percent = None
+ rows = None
+ size = None
+
+ self._match_l_paren()
+
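+        # hive-style bucket sampling, e.g. TABLESAMPLE (BUCKET 3 OUT OF 16 ON id);
+        # otherwise a number followed by PERCENT, ROWS, or nothing (a size)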
+ if self._match(TokenType.BUCKET):
+ bucket_numerator = self._parse_number()
+ self._match(TokenType.OUT_OF)
+            bucket_denominator = self._parse_number()
+ self._match(TokenType.ON)
+ bucket_field = self._parse_field()
+ else:
+ num = self._parse_number()
+
+ if self._match(TokenType.PERCENT):
+ percent = num
+ elif self._match(TokenType.ROWS):
+ rows = num
+ else:
+ size = num
+
+ self._match_r_paren()
+
+ return self.expression(
+ exp.TableSample,
+ method=method,
+ bucket_numerator=bucket_numerator,
+ bucket_denominator=bucket_denominator,
+ bucket_field=bucket_field,
+ percent=percent,
+ rows=rows,
+ size=size,
+ )
+
+ def _parse_where(self):
+ if not self._match(TokenType.WHERE):
+ return None
+ return self.expression(exp.Where, this=self._parse_conjunction())
+
+ def _parse_group(self):
+ if not self._match(TokenType.GROUP_BY):
+ return None
+ return self.expression(
+ exp.Group,
+ expressions=self._parse_csv(self._parse_conjunction),
+ grouping_sets=self._parse_grouping_sets(),
+ cube=self._match(TokenType.CUBE) and self._parse_wrapped_id_vars(),
+ rollup=self._match(TokenType.ROLLUP) and self._parse_wrapped_id_vars(),
+ )
+
+ def _parse_grouping_sets(self):
+ if not self._match(TokenType.GROUPING_SETS):
+ return None
+
+ self._match_l_paren()
+ grouping_sets = self._parse_csv(self._parse_grouping_set)
+ self._match_r_paren()
+ return grouping_sets
+
+ def _parse_grouping_set(self):
+ if self._match(TokenType.L_PAREN):
+ grouping_set = self._parse_csv(self._parse_id_var)
+ self._match_r_paren()
+ return self.expression(exp.Tuple, expressions=grouping_set)
+ return self._parse_id_var()
+
+ def _parse_having(self):
+ if not self._match(TokenType.HAVING):
+ return None
+ return self.expression(exp.Having, this=self._parse_conjunction())
+
+ def _parse_qualify(self):
+ if not self._match(TokenType.QUALIFY):
+ return None
+ return self.expression(exp.Qualify, this=self._parse_conjunction())
+
+ def _parse_order(self, this=None):
+ if not self._match(TokenType.ORDER_BY):
+ return this
+
+ return self.expression(
+ exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered)
+ )
+
+ def _parse_sort(self, token_type, exp_class):
+ if not self._match(token_type):
+ return None
+
+ return self.expression(
+ exp_class, expressions=self._parse_csv(self._parse_ordered)
+ )
+
+ def _parse_ordered(self):
+ this = self._parse_conjunction()
+ self._match(TokenType.ASC)
+ is_desc = self._match(TokenType.DESC)
+ is_nulls_first = self._match(TokenType.NULLS_FIRST)
+ is_nulls_last = self._match(TokenType.NULLS_LAST)
+ desc = is_desc or False
+ asc = not desc
+ nulls_first = is_nulls_first or False
+ explicitly_null_ordered = is_nulls_first or is_nulls_last
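+        # when the query doesn't spell out NULLS FIRST/LAST, infer it from the
+        # dialect: "nulls_are_small" puts NULLs first when ascending, otherwise
+        # NULLs come first when descending (unless nulls always sort last)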
+ if (
+ not explicitly_null_ordered
+ and (
+ (asc and self.null_ordering == "nulls_are_small")
+ or (desc and self.null_ordering != "nulls_are_small")
+ )
+ and self.null_ordering != "nulls_are_last"
+ ):
+ nulls_first = True
+
+ return self.expression(
+ exp.Ordered, this=this, desc=desc, nulls_first=nulls_first
+ )
+
+ def _parse_limit(self, this=None, top=False):
+ if self._match(TokenType.TOP if top else TokenType.LIMIT):
+ return self.expression(
+ exp.Limit, this=this, expression=self._parse_number()
+ )
+ if self._match(TokenType.FETCH):
+ direction = self._match_set((TokenType.FIRST, TokenType.NEXT))
+ direction = self._prev.text if direction else "FIRST"
+ count = self._parse_number()
+ self._match_set((TokenType.ROW, TokenType.ROWS))
+ self._match(TokenType.ONLY)
+ return self.expression(exp.Fetch, direction=direction, count=count)
+ return this
+
+ def _parse_offset(self, this=None):
+ if not self._match(TokenType.OFFSET):
+ return this
+ count = self._parse_number()
+ self._match_set((TokenType.ROW, TokenType.ROWS))
+ return self.expression(exp.Offset, this=this, expression=count)
+
+ def _parse_set_operations(self, this):
+ if not self._match_set(self.SET_OPERATIONS):
+ return this
+
+ token_type = self._prev.token_type
+
+ if token_type == TokenType.UNION:
+ expression = exp.Union
+ elif token_type == TokenType.EXCEPT:
+ expression = exp.Except
+ else:
+ expression = exp.Intersect
+
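+        # set operations are DISTINCT by default; only an explicit ALL
+        # keeps duplicates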
+ return self.expression(
+ expression,
+ this=this,
+ distinct=self._match(TokenType.DISTINCT) or not self._match(TokenType.ALL),
+ expression=self._parse_select(),
+ )
+
+ def _parse_expression(self):
+ return self._parse_alias(self._parse_conjunction())
+
+ def _parse_conjunction(self):
+ return self._parse_tokens(self._parse_equality, self.CONJUNCTION)
+
+ def _parse_equality(self):
+ return self._parse_tokens(self._parse_comparison, self.EQUALITY)
+
+ def _parse_comparison(self):
+ return self._parse_tokens(self._parse_range, self.COMPARISON)
+
+ def _parse_range(self):
+ this = self._parse_bitwise()
+ negate = self._match(TokenType.NOT)
+
+ if self._match_set(self.RANGE_PARSERS):
+ this = self.RANGE_PARSERS[self._prev.token_type](self, this)
+
+ if negate:
+ this = self.expression(exp.Not, this=this)
+
+ return this
+
+ def _parse_is(self, this):
+ negate = self._match(TokenType.NOT)
+ this = self.expression(
+ exp.Is,
+ this=this,
+ expression=self._parse_null() or self._parse_boolean(),
+ )
+ return self.expression(exp.Not, this=this) if negate else this
+
+ def _parse_in(self, this):
+ unnest = self._parse_unnest()
+ if unnest:
+ this = self.expression(exp.In, this=this, unnest=unnest)
+ else:
+ self._match_l_paren()
+ expressions = self._parse_csv(
+ lambda: self._parse_select() or self._parse_expression()
+ )
+
+ if len(expressions) == 1 and isinstance(expressions[0], exp.Subqueryable):
+ this = self.expression(exp.In, this=this, query=expressions[0])
+ else:
+ this = self.expression(exp.In, this=this, expressions=expressions)
+
+ self._match_r_paren()
+ return this
+
+ def _parse_between(self, this):
+ low = self._parse_bitwise()
+ self._match(TokenType.AND)
+ high = self._parse_bitwise()
+ return self.expression(exp.Between, this=this, low=low, high=high)
+
+ def _parse_escape(self, this):
+ if not self._match(TokenType.ESCAPE):
+ return this
+ return self.expression(exp.Escape, this=this, expression=self._parse_string())
+
+ def _parse_bitwise(self):
+ this = self._parse_term()
+
+ while True:
+ if self._match_set(self.BITWISE):
+ this = self.expression(
+ self.BITWISE[self._prev.token_type],
+ this=this,
+ expression=self._parse_term(),
+ )
+ elif self._match_pair(TokenType.LT, TokenType.LT):
+ this = self.expression(
+ exp.BitwiseLeftShift, this=this, expression=self._parse_term()
+ )
+ elif self._match_pair(TokenType.GT, TokenType.GT):
+ this = self.expression(
+ exp.BitwiseRightShift, this=this, expression=self._parse_term()
+ )
+ else:
+ break
+
+ return this
+
+ def _parse_term(self):
+ return self._parse_tokens(self._parse_factor, self.TERM)
+
+ def _parse_factor(self):
+ return self._parse_tokens(self._parse_unary, self.FACTOR)
+
+ def _parse_unary(self):
+ if self._match(TokenType.NOT):
+ return self.expression(exp.Not, this=self._parse_equality())
+ if self._match(TokenType.TILDA):
+ return self.expression(exp.BitwiseNot, this=self._parse_unary())
+ if self._match(TokenType.DASH):
+ return self.expression(exp.Neg, this=self._parse_unary())
+ return self._parse_at_time_zone(self._parse_type())
+
+ def _parse_type(self):
+ if self._match(TokenType.INTERVAL):
+ return self.expression(
+ exp.Interval,
+ this=self._parse_term(),
+ unit=self._parse_var(),
+ )
+
+ index = self._index
+ type_token = self._parse_types()
+ this = self._parse_column()
+
+ if type_token:
+ if this:
+ return self.expression(exp.Cast, this=this, to=type_token)
+ if not type_token.args.get("expressions"):
+ self._retreat(index)
+ return self._parse_column()
+ return type_token
+
+ while self._match(TokenType.DCOLON):
+ type_token = self._parse_types()
+ if not type_token:
+ self.raise_error("Expected type")
+ this = self.expression(exp.Cast, this=this, to=type_token)
+
+ return this
+
+ def _parse_types(self):
+ index = self._index
+
+ if not self._match_set(self.TYPE_TOKENS):
+ return None
+
+ type_token = self._prev.token_type
+ nested = type_token in self.NESTED_TYPE_TOKENS
+ is_struct = type_token == TokenType.STRUCT
+ expressions = None
+
+ if self._match(TokenType.L_BRACKET):
+ self._retreat(index)
+ return None
+
+ if self._match(TokenType.L_PAREN):
+ if is_struct:
+ expressions = self._parse_csv(self._parse_struct_kwargs)
+ elif nested:
+ expressions = self._parse_csv(self._parse_types)
+ else:
+ expressions = self._parse_csv(self._parse_number)
+
+ if not expressions:
+ self._retreat(index)
+ return None
+
+ self._match_r_paren()
+
+ if nested and self._match(TokenType.LT):
+ if is_struct:
+ expressions = self._parse_csv(self._parse_struct_kwargs)
+ else:
+ expressions = self._parse_csv(self._parse_types)
+
+ if not self._match(TokenType.GT):
+ self.raise_error("Expecting >")
+
+ if type_token in self.TIMESTAMPS:
+ tz = self._match(TokenType.WITH_TIME_ZONE)
+ self._match(TokenType.WITHOUT_TIME_ZONE)
+ if tz:
+ return exp.DataType(
+ this=exp.DataType.Type.TIMESTAMPTZ,
+ expressions=expressions,
+ )
+ return exp.DataType(
+ this=exp.DataType.Type.TIMESTAMP,
+ expressions=expressions,
+ )
+
+ return exp.DataType(
+ this=exp.DataType.Type[type_token.value.upper()],
+ expressions=expressions,
+ nested=nested,
+ )
+
+ def _parse_struct_kwargs(self):
+ this = self._parse_id_var()
+ self._match(TokenType.COLON)
+ data_type = self._parse_types()
+ if not data_type:
+ return None
+ return self.expression(exp.StructKwarg, this=this, expression=data_type)
+
+ def _parse_at_time_zone(self, this):
+ if not self._match(TokenType.AT_TIME_ZONE):
+ return this
+
+ return self.expression(exp.AtTimeZone, this=this, zone=self._parse_unary())
+
+ def _parse_column(self):
+ this = self._parse_field()
+ if isinstance(this, exp.Identifier):
+ this = self.expression(exp.Column, this=this)
+ elif not this:
+ return self._parse_bracket(this)
+ this = self._parse_bracket(this)
+
+ while self._match_set(self.COLUMN_OPERATORS):
+ op = self.COLUMN_OPERATORS.get(self._prev.token_type)
+ field = self._parse_star() or self._parse_function() or self._parse_id_var()
+
+ if isinstance(field, exp.Func):
+ # bigquery allows function calls like x.y.count(...)
+ # SAFE.SUBSTR(...)
+ # https://cloud.google.com/bigquery/docs/reference/standard-sql/functions-reference#function_call_rules
+ this = self._replace_columns_with_dots(this)
+
+ if op:
+ this = op(self, this, exp.Literal.string(field.name))
+ elif isinstance(this, exp.Column) and not this.table:
+ this = self.expression(exp.Column, this=field, table=this.this)
+ else:
+ this = self.expression(exp.Dot, this=this, expression=field)
+ this = self._parse_bracket(this)
+
+ return this
+
+ def _parse_primary(self):
+ if self._match_set(self.PRIMARY_PARSERS):
+ return self.PRIMARY_PARSERS[self._prev.token_type](self, self._prev)
+
+ if self._match(TokenType.L_PAREN):
+ query = self._parse_select()
+
+ if query:
+ expressions = [query]
+ else:
+ expressions = self._parse_csv(
+ lambda: self._parse_alias(self._parse_conjunction(), explicit=True)
+ )
+
+ this = list_get(expressions, 0)
+ self._parse_query_modifiers(this)
+ self._match_r_paren()
+
+ if isinstance(this, exp.Subqueryable):
+ return self._parse_subquery(this)
+ if len(expressions) > 1:
+ return self.expression(exp.Tuple, expressions=expressions)
+ return self.expression(exp.Paren, this=this)
+
+ return None
+
+ def _parse_field(self, any_token=False):
+ return (
+ self._parse_primary()
+ or self._parse_function()
+ or self._parse_id_var(any_token)
+ )
+
+ def _parse_function(self):
+ if not self._curr:
+ return None
+
+ token_type = self._curr.token_type
+
+ if self._match_set(self.NO_PAREN_FUNCTION_PARSERS):
+ return self.NO_PAREN_FUNCTION_PARSERS[token_type](self)
+
+ if not self._next or self._next.token_type != TokenType.L_PAREN:
+ if token_type in self.NO_PAREN_FUNCTIONS:
+ return self.expression(
+ self._advance() or self.NO_PAREN_FUNCTIONS[token_type]
+ )
+ return None
+
+ if token_type not in self.FUNC_TOKENS:
+ return None
+
+ if self._match_set(self.FUNCTION_PARSERS):
+ self._advance()
+ this = self.FUNCTION_PARSERS[token_type](self, token_type)
+ else:
+ subquery_predicate = self.SUBQUERY_PREDICATES.get(token_type)
+ this = self._curr.text
+ self._advance(2)
+
+ if subquery_predicate and self._curr.token_type in (
+ TokenType.SELECT,
+ TokenType.WITH,
+ ):
+ this = self.expression(subquery_predicate, this=self._parse_select())
+ self._match_r_paren()
+ return this
+
+ function = self.FUNCTIONS.get(this.upper())
+ args = self._parse_csv(self._parse_lambda)
+
+ if function:
+ this = function(args)
+ self.validate_expression(this, args)
+ else:
+ this = self.expression(exp.Anonymous, this=this, expressions=args)
+ self._match_r_paren()
+ return self._parse_window(this)
+
+ def _parse_lambda(self):
+ index = self._index
+
+ if self._match(TokenType.L_PAREN):
+ expressions = self._parse_csv(self._parse_id_var)
+ self._match(TokenType.R_PAREN)
+ else:
+ expressions = [self._parse_id_var()]
+
+ if not self._match(TokenType.ARROW):
+ self._retreat(index)
+
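+            # no ARROW, so this isn't a lambda: rewind and parse a regular
+            # function argument, which may carry DISTINCT and null-handling modifiers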
+ distinct = self._match(TokenType.DISTINCT)
+ this = self._parse_conjunction()
+
+ if distinct:
+ this = self.expression(exp.Distinct, this=this)
+
+ if self._match(TokenType.IGNORE_NULLS):
+ this = self.expression(exp.IgnoreNulls, this=this)
+ else:
+ self._match(TokenType.RESPECT_NULLS)
+
+ return self._parse_alias(self._parse_limit(self._parse_order(this)))
+
+ return self.expression(
+ exp.Lambda,
+ this=self._parse_conjunction(),
+ expressions=expressions,
+ )
+
+ def _parse_schema(self, this=None):
+ index = self._index
+ if not self._match(TokenType.L_PAREN) or self._match(TokenType.SELECT):
+ self._retreat(index)
+ return this
+
+ args = self._parse_csv(
+ lambda: self._parse_constraint()
+ or self._parse_column_def(self._parse_field())
+ )
+ self._match_r_paren()
+ return self.expression(exp.Schema, this=this, expressions=args)
+
+ def _parse_column_def(self, this):
+ kind = self._parse_types()
+
+ if not kind:
+ return this
+
+ constraints = []
+ while True:
+ constraint = self._parse_column_constraint()
+ if not constraint:
+ break
+ constraints.append(constraint)
+
+ return self.expression(
+ exp.ColumnDef, this=this, kind=kind, constraints=constraints
+ )
+
+ def _parse_column_constraint(self):
+ kind = None
+ this = None
+
+ if self._match(TokenType.CONSTRAINT):
+ this = self._parse_id_var()
+
+ if self._match(TokenType.AUTO_INCREMENT):
+ kind = exp.AutoIncrementColumnConstraint()
+ elif self._match(TokenType.CHECK):
+ self._match_l_paren()
+ kind = self.expression(
+ exp.CheckColumnConstraint, this=self._parse_conjunction()
+ )
+ self._match_r_paren()
+ elif self._match(TokenType.COLLATE):
+ kind = self.expression(exp.CollateColumnConstraint, this=self._parse_var())
+ elif self._match(TokenType.DEFAULT):
+ kind = self.expression(
+ exp.DefaultColumnConstraint, this=self._parse_field()
+ )
+ elif self._match(TokenType.NOT) and self._match(TokenType.NULL):
+ kind = exp.NotNullColumnConstraint()
+ elif self._match(TokenType.SCHEMA_COMMENT):
+ kind = self.expression(
+ exp.CommentColumnConstraint, this=self._parse_string()
+ )
+ elif self._match(TokenType.PRIMARY_KEY):
+ kind = exp.PrimaryKeyColumnConstraint()
+ elif self._match(TokenType.UNIQUE):
+ kind = exp.UniqueColumnConstraint()
+
+ if kind is None:
+ return None
+
+ return self.expression(exp.ColumnConstraint, this=this, kind=kind)
+
+ def _parse_constraint(self):
+ if not self._match(TokenType.CONSTRAINT):
+ return self._parse_unnamed_constraint()
+
+ this = self._parse_id_var()
+ expressions = []
+
+ while True:
+ constraint = self._parse_unnamed_constraint() or self._parse_function()
+ if not constraint:
+ break
+ expressions.append(constraint)
+
+ return self.expression(exp.Constraint, this=this, expressions=expressions)
+
+ def _parse_unnamed_constraint(self):
+ if not self._match_set(self.CONSTRAINT_PARSERS):
+ return None
+
+ return self.CONSTRAINT_PARSERS[self._prev.token_type](self)
+
+ def _parse_check(self):
+ self._match(TokenType.CHECK)
+ self._match_l_paren()
+ expression = self._parse_conjunction()
+ self._match_r_paren()
+
+ return self.expression(exp.Check, this=expression)
+
+ def _parse_unique(self):
+ self._match(TokenType.UNIQUE)
+ columns = self._parse_wrapped_id_vars()
+
+ return self.expression(exp.Unique, expressions=columns)
+
+ def _parse_foreign_key(self):
+ self._match(TokenType.FOREIGN_KEY)
+
+ expressions = self._parse_wrapped_id_vars()
+ reference = self._match(TokenType.REFERENCES) and self.expression(
+ exp.Reference,
+ this=self._parse_id_var(),
+ expressions=self._parse_wrapped_id_vars(),
+ )
+ options = {}
+
+ while self._match(TokenType.ON):
+ if not self._match_set((TokenType.DELETE, TokenType.UPDATE)):
+ self.raise_error("Expected DELETE or UPDATE")
+ kind = self._prev.text.lower()
+
+ if self._match(TokenType.NO_ACTION):
+ action = "NO ACTION"
+ elif self._match(TokenType.SET):
+ self._match_set((TokenType.NULL, TokenType.DEFAULT))
+ action = "SET " + self._prev.text.upper()
+ else:
+ self._advance()
+ action = self._prev.text.upper()
+ options[kind] = action
+
+ return self.expression(
+ exp.ForeignKey,
+ expressions=expressions,
+ reference=reference,
+ **options,
+ )
+
+ def _parse_bracket(self, this):
+ if not self._match(TokenType.L_BRACKET):
+ return this
+
+ expressions = self._parse_csv(self._parse_conjunction)
+
+ if not this or this.name.upper() == "ARRAY":
+ this = self.expression(exp.Array, expressions=expressions)
+ else:
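+            # normalize the dialect's index base (e.g. 1-based arrays) away so
+            # bracket subscripts are stored zero-based internally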
+ expressions = apply_index_offset(expressions, -self.index_offset)
+ this = self.expression(exp.Bracket, this=this, expressions=expressions)
+
+ if not self._match(TokenType.R_BRACKET):
+ self.raise_error("Expected ]")
+
+ return self._parse_bracket(this)
+
+ def _parse_case(self):
+ ifs = []
+ default = None
+
+ expression = self._parse_conjunction()
+
+ while self._match(TokenType.WHEN):
+ this = self._parse_conjunction()
+ self._match(TokenType.THEN)
+ then = self._parse_conjunction()
+ ifs.append(self.expression(exp.If, this=this, true=then))
+
+ if self._match(TokenType.ELSE):
+ default = self._parse_conjunction()
+
+ if not self._match(TokenType.END):
+ self.raise_error("Expected END after CASE", self._prev)
+
+ return self._parse_window(
+ self.expression(exp.Case, this=expression, ifs=ifs, default=default)
+ )
+
+ def _parse_if(self):
+ if self._match(TokenType.L_PAREN):
+ args = self._parse_csv(self._parse_conjunction)
+ this = exp.If.from_arg_list(args)
+ self.validate_expression(this, args)
+ self._match_r_paren()
+ else:
+ condition = self._parse_conjunction()
+ self._match(TokenType.THEN)
+ true = self._parse_conjunction()
+ false = self._parse_conjunction() if self._match(TokenType.ELSE) else None
+ self._match(TokenType.END)
+ this = self.expression(exp.If, this=condition, true=true, false=false)
+ return self._parse_window(this)
+
+ def _parse_extract(self):
+ this = self._parse_var() or self._parse_type()
+
+ if not self._match(TokenType.FROM):
+ self.raise_error("Expected FROM after EXTRACT", self._prev)
+
+ return self.expression(exp.Extract, this=this, expression=self._parse_type())
+
+ def _parse_cast(self, strict):
+ this = self._parse_conjunction()
+
+ if not self._match(TokenType.ALIAS):
+ self.raise_error("Expected AS after CAST")
+
+ to = self._parse_types()
+
+ if not to:
+ self.raise_error("Expected TYPE after CAST")
+ elif to.this == exp.DataType.Type.CHAR:
+ if self._match(TokenType.CHARACTER_SET):
+ to = self.expression(exp.CharacterSet, this=self._parse_var_or_string())
+
+ return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
+
+ def _parse_convert(self):
+ this = self._parse_field()
+ if self._match(TokenType.USING):
+ to = self.expression(exp.CharacterSet, this=self._parse_var())
+ elif self._match(TokenType.COMMA):
+ to = self._parse_types()
+ else:
+ to = None
+ return self.expression(exp.Cast, this=this, to=to)
+
+ def _parse_window(self, this, alias=False):
+ if self._match(TokenType.FILTER):
+ self._match_l_paren()
+ this = self.expression(
+ exp.Filter, this=this, expression=self._parse_where()
+ )
+ self._match_r_paren()
+
+ if self._match(TokenType.WITHIN_GROUP):
+ self._match_l_paren()
+ this = self.expression(
+ exp.WithinGroup,
+ this=this,
+ expression=self._parse_order(),
+ )
+ self._match_r_paren()
+ return this
+
+        # bigquery named windows: SELECT ... FROM t WINDOW x AS (PARTITION BY ...)
+ if alias:
+ self._match(TokenType.ALIAS)
+ elif not self._match(TokenType.OVER):
+ return this
+
+ if not self._match(TokenType.L_PAREN):
+ alias = self._parse_id_var(False)
+
+ return self.expression(
+ exp.Window,
+ this=this,
+ alias=alias,
+ )
+
+ partition = None
+
+ alias = self._parse_id_var(False)
+
+ if self._match(TokenType.PARTITION_BY):
+ partition = self._parse_csv(self._parse_conjunction)
+
+ order = self._parse_order()
+
+ spec = None
+ kind = self._match_set((TokenType.ROWS, TokenType.RANGE)) and self._prev.text
+
+ if kind:
+ self._match(TokenType.BETWEEN)
+ start = self._parse_window_spec()
+ self._match(TokenType.AND)
+ end = self._parse_window_spec()
+
+ spec = self.expression(
+ exp.WindowSpec,
+ kind=kind,
+ start=start["value"],
+ start_side=start["side"],
+ end=end["value"],
+ end_side=end["side"],
+ )
+
+ self._match_r_paren()
+
+ return self.expression(
+ exp.Window,
+ this=this,
+ partition_by=partition,
+ order=order,
+ spec=spec,
+ alias=alias,
+ )
+
+ def _parse_window_spec(self):
+ self._match(TokenType.BETWEEN)
+
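+        # a frame bound is either UNBOUNDED / CURRENT ROW or an expression,
+        # optionally followed by PRECEDING or FOLLOWING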
+ return {
+ "value": (
+ self._match_set((TokenType.UNBOUNDED, TokenType.CURRENT_ROW))
+ and self._prev.text
+ )
+ or self._parse_bitwise(),
+ "side": self._match_set((TokenType.PRECEDING, TokenType.FOLLOWING))
+ and self._prev.text,
+ }
+
+ def _parse_alias(self, this, explicit=False):
+ any_token = self._match(TokenType.ALIAS)
+
+ if explicit and not any_token:
+ return this
+
+ if self._match(TokenType.L_PAREN):
+ aliases = self.expression(
+ exp.Aliases,
+ this=this,
+ expressions=self._parse_csv(lambda: self._parse_id_var(any_token)),
+ )
+ self._match_r_paren()
+ return aliases
+
+ alias = self._parse_id_var(any_token)
+
+ if alias:
+ return self.expression(exp.Alias, this=this, alias=alias)
+
+ return this
+
+ def _parse_id_var(self, any_token=True):
+ identifier = self._parse_identifier()
+
+ if identifier:
+ return identifier
+
+ if (
+ any_token
+ and self._curr
+ and self._curr.token_type not in self.RESERVED_KEYWORDS
+ ):
+ return self._advance() or exp.Identifier(this=self._prev.text, quoted=False)
+
+ return self._match_set(self.ID_VAR_TOKENS) and exp.Identifier(
+ this=self._prev.text, quoted=False
+ )
+
+ def _parse_string(self):
+ if self._match(TokenType.STRING):
+ return exp.Literal.string(self._prev.text)
+ return self._parse_placeholder()
+
+ def _parse_number(self):
+ if self._match(TokenType.NUMBER):
+ return exp.Literal.number(self._prev.text)
+ return self._parse_placeholder()
+
+ def _parse_identifier(self):
+ if self._match(TokenType.IDENTIFIER):
+ return exp.Identifier(this=self._prev.text, quoted=True)
+ return self._parse_placeholder()
+
+ def _parse_var(self):
+ if self._match(TokenType.VAR):
+ return exp.Var(this=self._prev.text)
+ return self._parse_placeholder()
+
+ def _parse_var_or_string(self):
+ return self._parse_var() or self._parse_string()
+
+ def _parse_null(self):
+ if self._match(TokenType.NULL):
+ return exp.Null()
+ return None
+
+ def _parse_boolean(self):
+ if self._match(TokenType.TRUE):
+ return exp.Boolean(this=True)
+ if self._match(TokenType.FALSE):
+ return exp.Boolean(this=False)
+ return None
+
+ def _parse_star(self):
+ if self._match(TokenType.STAR):
+ return exp.Star(
+ **{"except": self._parse_except(), "replace": self._parse_replace()}
+ )
+ return None
+
+ def _parse_placeholder(self):
+ if self._match(TokenType.PLACEHOLDER):
+ return exp.Placeholder()
+ return None
+
+ def _parse_except(self):
+ if not self._match(TokenType.EXCEPT):
+ return None
+
+ return self._parse_wrapped_id_vars()
+
+ def _parse_replace(self):
+ if not self._match(TokenType.REPLACE):
+ return None
+
+ self._match_l_paren()
+ columns = self._parse_csv(lambda: self._parse_alias(self._parse_expression()))
+ self._match_r_paren()
+ return columns
+
+ def _parse_csv(self, parse):
+ parse_result = parse()
+ items = [parse_result] if parse_result is not None else []
+
+ while self._match(TokenType.COMMA):
+ parse_result = parse()
+ if parse_result is not None:
+ items.append(parse_result)
+
+ return items
+
+ def _parse_tokens(self, parse, expressions):
+ this = parse()
+
+ while self._match_set(expressions):
+ this = self.expression(
+ expressions[self._prev.token_type], this=this, expression=parse()
+ )
+
+ return this
+
+ def _parse_all(self, parse):
+ return list(iter(parse, None))
+
+ def _parse_wrapped_id_vars(self):
+ self._match_l_paren()
+ expressions = self._parse_csv(self._parse_id_var)
+ self._match_r_paren()
+ return expressions
+
+ def _match(self, token_type):
+ if not self._curr:
+ return None
+
+ if self._curr.token_type == token_type:
+ self._advance()
+ return True
+
+ return None
+
+ def _match_set(self, types):
+ if not self._curr:
+ return None
+
+ if self._curr.token_type in types:
+ self._advance()
+ return True
+
+ return None
+
+ def _match_pair(self, token_type_a, token_type_b, advance=True):
+ if not self._curr or not self._next:
+ return None
+
+ if (
+ self._curr.token_type == token_type_a
+ and self._next.token_type == token_type_b
+ ):
+ if advance:
+ self._advance(2)
+ return True
+
+ return None
+
+ def _match_l_paren(self):
+ if not self._match(TokenType.L_PAREN):
+ self.raise_error("Expecting (")
+
+ def _match_r_paren(self):
+ if not self._match(TokenType.R_PAREN):
+ self.raise_error("Expecting )")
+
+ def _replace_columns_with_dots(self, this):
+ if isinstance(this, exp.Dot):
+ exp.replace_children(this, self._replace_columns_with_dots)
+ elif isinstance(this, exp.Column):
+ exp.replace_children(this, self._replace_columns_with_dots)
+ table = this.args.get("table")
+ this = (
+ self.expression(exp.Dot, this=table, expression=this.this)
+ if table
+ else self.expression(exp.Var, this=this.name)
+ )
+ elif isinstance(this, exp.Identifier):
+ this = self.expression(exp.Var, this=this.name)
+ return this
diff --git a/sqlglot/planner.py b/sqlglot/planner.py
new file mode 100644
index 0000000..2006a75
--- /dev/null
+++ b/sqlglot/planner.py
@@ -0,0 +1,340 @@
+import itertools
+import math
+
+from sqlglot import alias, exp
+from sqlglot.errors import UnsupportedError
+from sqlglot.optimizer.simplify import simplify
+
+
+class Plan:
+ def __init__(self, expression):
+ self.expression = expression
+ self.root = Step.from_expression(self.expression)
+ self._dag = {}
+
+ @property
+ def dag(self):
+ if not self._dag:
+ dag = {}
+ nodes = {self.root}
+
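+            # walk from the root, recording each step's dependencies as edges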
+ while nodes:
+ node = nodes.pop()
+ dag[node] = set()
+ for dep in node.dependencies:
+ dag[node].add(dep)
+ nodes.add(dep)
+ self._dag = dag
+
+ return self._dag
+
+ @property
+ def leaves(self):
+ return (node for node, deps in self.dag.items() if not deps)
+
+
+class Step:
+ @classmethod
+ def from_expression(cls, expression, ctes=None):
+ """
+ Build a DAG of Steps from a SQL expression.
+
+        Given an expression like:
+
+ SELECT x.a, SUM(x.b)
+ FROM x
+ JOIN y
+ ON x.a = y.a
+ GROUP BY x.a
+
+        it is transformed into a DAG of the form:
+
+ Aggregate(x.a, SUM(x.b))
+ Join(y)
+ Scan(x)
+ Scan(y)
+
+        This can then be executed more easily by an engine.
+ """
+ ctes = ctes or {}
+ with_ = expression.args.get("with")
+
+        # CTEs break the usual scoping rules: they are visible to everything parsed in this context.
+ if with_:
+ ctes = ctes.copy()
+ for cte in with_.expressions:
+ step = Step.from_expression(cte.this, ctes)
+ step.name = cte.alias
+ ctes[step.name] = step
+
+ from_ = expression.args.get("from")
+
+ if from_:
+ from_ = from_.expressions
+ if len(from_) > 1:
+ raise UnsupportedError(
+ "Multi-from statements are unsupported. Run it through the optimizer"
+ )
+
+ step = Scan.from_expression(from_[0], ctes)
+ else:
+ raise UnsupportedError("Static selects are unsupported.")
+
+ joins = expression.args.get("joins")
+
+ if joins:
+ join = Join.from_joins(joins, ctes)
+ join.name = step.name
+ join.add_dependency(step)
+ step = join
+
+        projections = []  # final projections of the select this chain of steps represents
+        operands = {}  # intermediate computations of agg funcs, e.g. x + 1 in SUM(x + 1)
+ aggregations = []
+ sequence = itertools.count()
+
+ for e in expression.expressions:
+ aggregation = e.find(exp.AggFunc)
+
+ if aggregation:
+ projections.append(exp.column(e.alias_or_name, step.name, quoted=True))
+ aggregations.append(e)
+ for operand in aggregation.unnest_operands():
+ if isinstance(operand, exp.Column):
+ continue
+ if operand not in operands:
+ operands[operand] = f"_a_{next(sequence)}"
+ operand.replace(
+ exp.column(operands[operand], step.name, quoted=True)
+ )
+ else:
+ projections.append(e)
+
+ where = expression.args.get("where")
+
+ if where:
+ step.condition = where.this
+
+ group = expression.args.get("group")
+
+ if group:
+ aggregate = Aggregate()
+ aggregate.source = step.name
+ aggregate.name = step.name
+ aggregate.operands = tuple(
+ alias(operand, alias_) for operand, alias_ in operands.items()
+ )
+ aggregate.aggregations = aggregations
+ aggregate.group = [
+ exp.column(e.alias_or_name, step.name, quoted=True)
+ for e in group.expressions
+ ]
+ aggregate.add_dependency(step)
+ step = aggregate
+
+ having = expression.args.get("having")
+
+ if having:
+ step.condition = having.this
+
+ order = expression.args.get("order")
+
+ if order:
+ sort = Sort()
+ sort.name = step.name
+ sort.key = order.expressions
+ sort.add_dependency(step)
+ step = sort
+ for k in sort.key + projections:
+ for column in k.find_all(exp.Column):
+ column.set("table", exp.to_identifier(step.name, quoted=True))
+
+ step.projections = projections
+
+ limit = expression.args.get("limit")
+
+ if limit:
+ step.limit = int(limit.text("expression"))
+
+ return step
+
+ def __init__(self):
+ self.name = None
+ self.dependencies = set()
+ self.dependents = set()
+ self.projections = []
+ self.limit = math.inf
+ self.condition = None
+
+ def add_dependency(self, dependency):
+ self.dependencies.add(dependency)
+ dependency.dependents.add(self)
+
+ def __repr__(self):
+ return self.to_s()
+
+ def to_s(self, level=0):
+ indent = " " * level
+ nested = f"{indent} "
+
+ context = self._to_s(f"{nested} ")
+
+ if context:
+ context = [f"{nested}Context:"] + context
+
+ lines = [
+ f"{indent}- {self.__class__.__name__}: {self.name}",
+ *context,
+ f"{nested}Projections:",
+ ]
+
+ for expression in self.projections:
+ lines.append(f"{nested} - {expression.sql()}")
+
+ if self.condition:
+ lines.append(f"{nested}Condition: {self.condition.sql()}")
+
+ if self.dependencies:
+ lines.append(f"{nested}Dependencies:")
+ for dependency in self.dependencies:
+ lines.append(" " + dependency.to_s(level + 1))
+
+ return "\n".join(lines)
+
+ def _to_s(self, _indent):
+ return []
+
+
+class Scan(Step):
+ @classmethod
+ def from_expression(cls, expression, ctes=None):
+ table = expression.this
+ alias_ = expression.alias
+
+ if not alias_:
+ raise UnsupportedError(
+ "Tables/Subqueries must be aliased. Run it through the optimizer"
+ )
+
+ if isinstance(expression, exp.Subquery):
+ step = Step.from_expression(table, ctes)
+ step.name = alias_
+ return step
+
+ step = Scan()
+ step.name = alias_
+ step.source = expression
+ if table.name in ctes:
+ step.add_dependency(ctes[table.name])
+
+ return step
+
+ def __init__(self):
+ super().__init__()
+ self.source = None
+
+ def _to_s(self, indent):
+ return [f"{indent}Source: {self.source.sql()}"]
+
+
+class Write(Step):
+ pass
+
+
+class Join(Step):
+ @classmethod
+ def from_joins(cls, joins, ctes=None):
+ step = Join()
+
+ for join in joins:
+ name = join.this.alias
+ on = join.args.get("on") or exp.TRUE
+ source_key = []
+ join_key = []
+
+ # find the join keys
+ # SELECT
+ # FROM x
+ # JOIN y
+ # ON x.a = y.b AND y.b > 1
+ #
+ # should pull y.b as the join key and x.a as the source key
+ for condition in on.flatten() if isinstance(on, exp.And) else [on]:
+ if isinstance(condition, exp.EQ):
+ left, right = condition.unnest_operands()
+ left_tables = exp.column_table_names(left)
+ right_tables = exp.column_table_names(right)
+
+ if name in left_tables and name not in right_tables:
+ join_key.append(left)
+ source_key.append(right)
+ condition.replace(exp.TRUE)
+ elif name in right_tables and name not in left_tables:
+ join_key.append(right)
+ source_key.append(left)
+ condition.replace(exp.TRUE)
+
+ on = simplify(on)
+
+ step.joins[name] = {
+ "side": join.side,
+ "join_key": join_key,
+ "source_key": source_key,
+ "condition": None if on == exp.TRUE else on,
+ }
+
+ step.add_dependency(Scan.from_expression(join.this, ctes))
+
+ return step
+
+ def __init__(self):
+ super().__init__()
+ self.joins = {}
+
+ def _to_s(self, indent):
+ lines = []
+ for name, join in self.joins.items():
+ lines.append(f"{indent}{name}: {join['side']}")
+ if join.get("condition"):
+ lines.append(f"{indent}On: {join['condition'].sql()}")
+ return lines
+
+
+class Aggregate(Step):
+ def __init__(self):
+ super().__init__()
+ self.aggregations = []
+ self.operands = []
+ self.group = []
+ self.source = None
+
+ def _to_s(self, indent):
+ lines = [f"{indent}Aggregations:"]
+
+ for expression in self.aggregations:
+ lines.append(f"{indent} - {expression.sql()}")
+
+ if self.group:
+ lines.append(f"{indent}Group:")
+ for expression in self.group:
+ lines.append(f"{indent} - {expression.sql()}")
+ if self.operands:
+ lines.append(f"{indent}Operands:")
+ for expression in self.operands:
+ lines.append(f"{indent} - {expression.sql()}")
+
+ return lines
+
+
+class Sort(Step):
+ def __init__(self):
+ super().__init__()
+ self.key = None
+
+ def _to_s(self, indent):
+ lines = [f"{indent}Key:"]
+
+ for expression in self.key:
+ lines.append(f"{indent} - {expression.sql()}")
+
+ return lines
diff --git a/sqlglot/time.py b/sqlglot/time.py
new file mode 100644
index 0000000..16314c5
--- /dev/null
+++ b/sqlglot/time.py
@@ -0,0 +1,45 @@
+# the generic time format is based on python time.strftime
+# https://docs.python.org/3/library/time.html#time.strftime
+from sqlglot.trie import in_trie, new_trie
+
+
+def format_time(string, mapping, trie=None):
+ """
+ Converts a time string given a mapping.
+
+ Examples:
+ >>> format_time("%Y", {"%Y": "YYYY"})
+ 'YYYY'
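+
+    A pattern mixing mapped and literal characters:
+    >>> format_time("%Y-%m", {"%Y": "YYYY", "%m": "MM"})
+    'YYYY-MM'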
+
+    mapping: Dictionary mapping source time-format tokens to target tokens
+    trie: Optional trie built from the mapping's keys, can be passed in for performance
+ """
+ start = 0
+ end = 1
+ size = len(string)
+ trie = trie or new_trie(mapping)
+ current = trie
+ chunks = []
+ sym = None
+
+ while end <= size:
+ chars = string[start:end]
+ result, current = in_trie(current, chars[-1])
+
+ if result == 0:
+ if sym:
+ end -= 1
+ chars = sym
+ sym = None
+ start += len(chars)
+ chunks.append(chars)
+ current = trie
+ elif result == 2:
+ sym = chars
+
+ end += 1
+
+ if result and end > size:
+ chunks.append(chars)
+
+ return "".join(mapping.get(chars, chars) for chars in chunks)
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
new file mode 100644
index 0000000..e4b754d
--- /dev/null
+++ b/sqlglot/tokens.py
@@ -0,0 +1,853 @@
+from enum import auto
+
+from sqlglot.helper import AutoName
+from sqlglot.trie import in_trie, new_trie
+
+
+class TokenType(AutoName):
+ L_PAREN = auto()
+ R_PAREN = auto()
+ L_BRACKET = auto()
+ R_BRACKET = auto()
+ L_BRACE = auto()
+ R_BRACE = auto()
+ COMMA = auto()
+ DOT = auto()
+ DASH = auto()
+ PLUS = auto()
+ COLON = auto()
+ DCOLON = auto()
+ SEMICOLON = auto()
+ STAR = auto()
+ SLASH = auto()
+ LT = auto()
+ LTE = auto()
+ GT = auto()
+ GTE = auto()
+ NOT = auto()
+ EQ = auto()
+ NEQ = auto()
+ AND = auto()
+ OR = auto()
+ AMP = auto()
+ DPIPE = auto()
+ PIPE = auto()
+ CARET = auto()
+ TILDA = auto()
+ ARROW = auto()
+ DARROW = auto()
+ HASH_ARROW = auto()
+ DHASH_ARROW = auto()
+ ANNOTATION = auto()
+ DOLLAR = auto()
+
+ SPACE = auto()
+ BREAK = auto()
+
+ STRING = auto()
+ NUMBER = auto()
+ IDENTIFIER = auto()
+ COLUMN = auto()
+ COLUMN_DEF = auto()
+ SCHEMA = auto()
+ TABLE = auto()
+ VAR = auto()
+ BIT_STRING = auto()
+
+ # types
+ BOOLEAN = auto()
+ TINYINT = auto()
+ SMALLINT = auto()
+ INT = auto()
+ BIGINT = auto()
+ FLOAT = auto()
+ DOUBLE = auto()
+ DECIMAL = auto()
+ CHAR = auto()
+ NCHAR = auto()
+ VARCHAR = auto()
+ NVARCHAR = auto()
+ TEXT = auto()
+ BINARY = auto()
+ BYTEA = auto()
+ JSON = auto()
+ TIMESTAMP = auto()
+ TIMESTAMPTZ = auto()
+ DATETIME = auto()
+ DATE = auto()
+ UUID = auto()
+ GEOGRAPHY = auto()
+ NULLABLE = auto()
+
+ # keywords
+ ADD_FILE = auto()
+ ALIAS = auto()
+ ALL = auto()
+ ALTER = auto()
+ ANALYZE = auto()
+ ANY = auto()
+ ARRAY = auto()
+ ASC = auto()
+ AT_TIME_ZONE = auto()
+ AUTO_INCREMENT = auto()
+ BEGIN = auto()
+ BETWEEN = auto()
+ BUCKET = auto()
+ CACHE = auto()
+ CALL = auto()
+ CASE = auto()
+ CAST = auto()
+ CHARACTER_SET = auto()
+ CHECK = auto()
+ CLUSTER_BY = auto()
+ COLLATE = auto()
+ COMMENT = auto()
+ COMMIT = auto()
+ CONSTRAINT = auto()
+ CONVERT = auto()
+ CREATE = auto()
+ CROSS = auto()
+ CUBE = auto()
+ CURRENT_DATE = auto()
+ CURRENT_DATETIME = auto()
+ CURRENT_ROW = auto()
+ CURRENT_TIME = auto()
+ CURRENT_TIMESTAMP = auto()
+ DIV = auto()
+ DEFAULT = auto()
+ DELETE = auto()
+ DESC = auto()
+ DISTINCT = auto()
+ DISTRIBUTE_BY = auto()
+ DROP = auto()
+ ELSE = auto()
+ END = auto()
+ ENGINE = auto()
+ ESCAPE = auto()
+ EXCEPT = auto()
+ EXISTS = auto()
+ EXPLAIN = auto()
+ EXTRACT = auto()
+ FALSE = auto()
+ FETCH = auto()
+ FILTER = auto()
+ FINAL = auto()
+ FIRST = auto()
+ FOLLOWING = auto()
+ FOREIGN_KEY = auto()
+ FORMAT = auto()
+ FULL = auto()
+ FUNCTION = auto()
+ FROM = auto()
+ GROUP_BY = auto()
+ GROUPING_SETS = auto()
+ HAVING = auto()
+ HINT = auto()
+ IF = auto()
+ IGNORE_NULLS = auto()
+ ILIKE = auto()
+ IN = auto()
+ INDEX = auto()
+ INNER = auto()
+ INSERT = auto()
+ INTERSECT = auto()
+ INTERVAL = auto()
+ INTO = auto()
+ INTRODUCER = auto()
+ IS = auto()
+ ISNULL = auto()
+ JOIN = auto()
+ LATERAL = auto()
+ LAZY = auto()
+ LEFT = auto()
+ LIKE = auto()
+ LIMIT = auto()
+ LOCATION = auto()
+ MAP = auto()
+ MOD = auto()
+ NEXT = auto()
+ NO_ACTION = auto()
+ NULL = auto()
+ NULLS_FIRST = auto()
+ NULLS_LAST = auto()
+ OFFSET = auto()
+ ON = auto()
+ ONLY = auto()
+ OPTIMIZE = auto()
+ OPTIONS = auto()
+ ORDER_BY = auto()
+ ORDERED = auto()
+ ORDINALITY = auto()
+ OUTER = auto()
+ OUT_OF = auto()
+ OVER = auto()
+ OVERWRITE = auto()
+ PARTITION = auto()
+ PARTITION_BY = auto()
+ PARTITIONED_BY = auto()
+ PERCENT = auto()
+ PLACEHOLDER = auto()
+ PRECEDING = auto()
+ PRIMARY_KEY = auto()
+ PROPERTIES = auto()
+ QUALIFY = auto()
+ QUOTE = auto()
+ RANGE = auto()
+ RECURSIVE = auto()
+ REPLACE = auto()
+ RESPECT_NULLS = auto()
+ REFERENCES = auto()
+ RIGHT = auto()
+ RLIKE = auto()
+ ROLLUP = auto()
+ ROW = auto()
+ ROWS = auto()
+ SCHEMA_COMMENT = auto()
+ SELECT = auto()
+ SET = auto()
+ SHOW = auto()
+ SOME = auto()
+ SORT_BY = auto()
+ STORED = auto()
+ STRUCT = auto()
+ TABLE_FORMAT = auto()
+ TABLE_SAMPLE = auto()
+ TEMPORARY = auto()
+ TIME = auto()
+ TOP = auto()
+ THEN = auto()
+ TRUE = auto()
+ TRUNCATE = auto()
+ TRY_CAST = auto()
+ UNBOUNDED = auto()
+ UNCACHE = auto()
+ UNION = auto()
+ UNNEST = auto()
+ UPDATE = auto()
+ USE = auto()
+ USING = auto()
+ VALUES = auto()
+ VIEW = auto()
+ WHEN = auto()
+ WHERE = auto()
+ WINDOW = auto()
+ WITH = auto()
+ WITH_TIME_ZONE = auto()
+ WITHIN_GROUP = auto()
+ WITHOUT_TIME_ZONE = auto()
+ UNIQUE = auto()
+
+
+class Token:
+ __slots__ = ("token_type", "text", "line", "col")
+
+ @classmethod
+ def number(cls, number):
+ return cls(TokenType.NUMBER, str(number))
+
+ @classmethod
+ def string(cls, string):
+ return cls(TokenType.STRING, string)
+
+ @classmethod
+ def identifier(cls, identifier):
+ return cls(TokenType.IDENTIFIER, identifier)
+
+ @classmethod
+ def var(cls, var):
+ return cls(TokenType.VAR, var)
+
+ def __init__(self, token_type, text, line=1, col=1):
+ self.token_type = token_type
+ self.text = text
+ self.line = line
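+        # `col` comes in pointing just past the token's text, so back up
+        # to the column where the token starts (clamped to 1)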
+ self.col = max(col - len(text), 1)
+
+ def __repr__(self):
+ attributes = ", ".join(f"{k}: {getattr(self, k)}" for k in self.__slots__)
+ return f"<Token {attributes}>"
+
+
+class _Tokenizer(type):
+ def __new__(cls, clsname, bases, attrs):
+ klass = super().__new__(cls, clsname, bases, attrs)
+
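+        # normalize delimiter specs: a bare string means the start and end
+        # markers are the same, a pair means distinct start/end markers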
+ klass.QUOTES = dict(
+ (quote, quote) if isinstance(quote, str) else (quote[0], quote[1])
+ for quote in klass.QUOTES
+ )
+
+ klass.IDENTIFIERS = dict(
+ (identifier, identifier)
+ if isinstance(identifier, str)
+ else (identifier[0], identifier[1])
+ for identifier in klass.IDENTIFIERS
+ )
+
+ klass.COMMENTS = dict(
+ (comment, None) if isinstance(comment, str) else (comment[0], comment[1])
+ for comment in klass.COMMENTS
+ )
+
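+        # only keywords that span multiple words or contain single-token
+        # characters need trie-based multi-character lookahead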
+ klass.KEYWORD_TRIE = new_trie(
+ key.upper()
+ for key, value in {
+ **klass.KEYWORDS,
+ **{comment: TokenType.COMMENT for comment in klass.COMMENTS},
+ **{quote: TokenType.QUOTE for quote in klass.QUOTES},
+ }.items()
+ if " " in key or any(single in key for single in klass.SINGLE_TOKENS)
+ )
+
+ return klass
+
+
+class Tokenizer(metaclass=_Tokenizer):
+ SINGLE_TOKENS = {
+ "(": TokenType.L_PAREN,
+ ")": TokenType.R_PAREN,
+ "[": TokenType.L_BRACKET,
+ "]": TokenType.R_BRACKET,
+ "{": TokenType.L_BRACE,
+ "}": TokenType.R_BRACE,
+ "&": TokenType.AMP,
+ "^": TokenType.CARET,
+ ":": TokenType.COLON,
+ ",": TokenType.COMMA,
+ ".": TokenType.DOT,
+ "-": TokenType.DASH,
+ "=": TokenType.EQ,
+ ">": TokenType.GT,
+ "<": TokenType.LT,
+ "%": TokenType.MOD,
+ "!": TokenType.NOT,
+ "|": TokenType.PIPE,
+ "+": TokenType.PLUS,
+ ";": TokenType.SEMICOLON,
+ "/": TokenType.SLASH,
+ "*": TokenType.STAR,
+ "~": TokenType.TILDA,
+ "?": TokenType.PLACEHOLDER,
+ "#": TokenType.ANNOTATION,
+ "$": TokenType.DOLLAR,
+        # only used to break up a var like x'y'; the token type
+        # itself doesn't matter
+ "'": TokenType.QUOTE,
+ "`": TokenType.IDENTIFIER,
+ '"': TokenType.IDENTIFIER,
+ }
+
+ QUOTES = ["'"]
+
+ IDENTIFIERS = ['"']
+
+ ESCAPE = "'"
+
+ KEYWORDS = {
+ "/*+": TokenType.HINT,
+ "*/": TokenType.HINT,
+ "==": TokenType.EQ,
+ "::": TokenType.DCOLON,
+ "||": TokenType.DPIPE,
+ ">=": TokenType.GTE,
+ "<=": TokenType.LTE,
+ "<>": TokenType.NEQ,
+ "!=": TokenType.NEQ,
+ "->": TokenType.ARROW,
+ "->>": TokenType.DARROW,
+ "#>": TokenType.HASH_ARROW,
+ "#>>": TokenType.DHASH_ARROW,
+ "ADD ARCHIVE": TokenType.ADD_FILE,
+ "ADD ARCHIVES": TokenType.ADD_FILE,
+ "ADD FILE": TokenType.ADD_FILE,
+ "ADD FILES": TokenType.ADD_FILE,
+ "ADD JAR": TokenType.ADD_FILE,
+ "ADD JARS": TokenType.ADD_FILE,
+ "ALL": TokenType.ALL,
+ "ALTER": TokenType.ALTER,
+ "ANALYZE": TokenType.ANALYZE,
+ "AND": TokenType.AND,
+ "ANY": TokenType.ANY,
+ "ASC": TokenType.ASC,
+ "AS": TokenType.ALIAS,
+ "AT TIME ZONE": TokenType.AT_TIME_ZONE,
+ "AUTO_INCREMENT": TokenType.AUTO_INCREMENT,
+ "BEGIN": TokenType.BEGIN,
+ "BETWEEN": TokenType.BETWEEN,
+ "BUCKET": TokenType.BUCKET,
+ "CALL": TokenType.CALL,
+ "CACHE": TokenType.CACHE,
+ "UNCACHE": TokenType.UNCACHE,
+ "CASE": TokenType.CASE,
+ "CAST": TokenType.CAST,
+ "CHARACTER SET": TokenType.CHARACTER_SET,
+ "CHECK": TokenType.CHECK,
+ "CLUSTER BY": TokenType.CLUSTER_BY,
+ "COLLATE": TokenType.COLLATE,
+ "COMMENT": TokenType.SCHEMA_COMMENT,
+ "COMMIT": TokenType.COMMIT,
+ "CONSTRAINT": TokenType.CONSTRAINT,
+ "CONVERT": TokenType.CONVERT,
+ "CREATE": TokenType.CREATE,
+ "CROSS": TokenType.CROSS,
+ "CUBE": TokenType.CUBE,
+ "CURRENT_DATE": TokenType.CURRENT_DATE,
+ "CURRENT ROW": TokenType.CURRENT_ROW,
+ "CURRENT_TIMESTAMP": TokenType.CURRENT_TIMESTAMP,
+ "DIV": TokenType.DIV,
+ "DEFAULT": TokenType.DEFAULT,
+ "DELETE": TokenType.DELETE,
+ "DESC": TokenType.DESC,
+ "DISTINCT": TokenType.DISTINCT,
+ "DISTRIBUTE BY": TokenType.DISTRIBUTE_BY,
+ "DROP": TokenType.DROP,
+ "ELSE": TokenType.ELSE,
+ "END": TokenType.END,
+ "ENGINE": TokenType.ENGINE,
+ "ESCAPE": TokenType.ESCAPE,
+ "EXCEPT": TokenType.EXCEPT,
+ "EXISTS": TokenType.EXISTS,
+ "EXPLAIN": TokenType.EXPLAIN,
+ "EXTRACT": TokenType.EXTRACT,
+ "FALSE": TokenType.FALSE,
+ "FETCH": TokenType.FETCH,
+ "FILTER": TokenType.FILTER,
+ "FIRST": TokenType.FIRST,
+ "FULL": TokenType.FULL,
+ "FUNCTION": TokenType.FUNCTION,
+ "FOLLOWING": TokenType.FOLLOWING,
+ "FOREIGN KEY": TokenType.FOREIGN_KEY,
+ "FORMAT": TokenType.FORMAT,
+ "FROM": TokenType.FROM,
+ "GROUP BY": TokenType.GROUP_BY,
+ "GROUPING SETS": TokenType.GROUPING_SETS,
+ "HAVING": TokenType.HAVING,
+ "IF": TokenType.IF,
+ "ILIKE": TokenType.ILIKE,
+ "IGNORE NULLS": TokenType.IGNORE_NULLS,
+ "IN": TokenType.IN,
+ "INDEX": TokenType.INDEX,
+ "INNER": TokenType.INNER,
+ "INSERT": TokenType.INSERT,
+ "INTERVAL": TokenType.INTERVAL,
+ "INTERSECT": TokenType.INTERSECT,
+ "INTO": TokenType.INTO,
+ "IS": TokenType.IS,
+ "ISNULL": TokenType.ISNULL,
+ "JOIN": TokenType.JOIN,
+ "LATERAL": TokenType.LATERAL,
+ "LAZY": TokenType.LAZY,
+ "LEFT": TokenType.LEFT,
+ "LIKE": TokenType.LIKE,
+ "LIMIT": TokenType.LIMIT,
+ "LOCATION": TokenType.LOCATION,
+ "NEXT": TokenType.NEXT,
+ "NO ACTION": TokenType.NO_ACTION,
+ "NOT": TokenType.NOT,
+ "NULL": TokenType.NULL,
+ "NULLS FIRST": TokenType.NULLS_FIRST,
+ "NULLS LAST": TokenType.NULLS_LAST,
+ "OFFSET": TokenType.OFFSET,
+ "ON": TokenType.ON,
+ "ONLY": TokenType.ONLY,
+ "OPTIMIZE": TokenType.OPTIMIZE,
+ "OPTIONS": TokenType.OPTIONS,
+ "OR": TokenType.OR,
+ "ORDER BY": TokenType.ORDER_BY,
+ "ORDINALITY": TokenType.ORDINALITY,
+ "OUTER": TokenType.OUTER,
+ "OUT OF": TokenType.OUT_OF,
+ "OVER": TokenType.OVER,
+ "OVERWRITE": TokenType.OVERWRITE,
+ "PARTITION": TokenType.PARTITION,
+ "PARTITION BY": TokenType.PARTITION_BY,
+ "PARTITIONED BY": TokenType.PARTITIONED_BY,
+ "PERCENT": TokenType.PERCENT,
+ "PRECEDING": TokenType.PRECEDING,
+ "PRIMARY KEY": TokenType.PRIMARY_KEY,
+ "RANGE": TokenType.RANGE,
+ "RECURSIVE": TokenType.RECURSIVE,
+ "REGEXP": TokenType.RLIKE,
+ "REPLACE": TokenType.REPLACE,
+ "RESPECT NULLS": TokenType.RESPECT_NULLS,
+ "REFERENCES": TokenType.REFERENCES,
+ "RIGHT": TokenType.RIGHT,
+ "RLIKE": TokenType.RLIKE,
+ "ROLLUP": TokenType.ROLLUP,
+ "ROW": TokenType.ROW,
+ "ROWS": TokenType.ROWS,
+ "SELECT": TokenType.SELECT,
+ "SET": TokenType.SET,
+ "SHOW": TokenType.SHOW,
+ "SOME": TokenType.SOME,
+ "SORT BY": TokenType.SORT_BY,
+ "STORED": TokenType.STORED,
+ "TABLE": TokenType.TABLE,
+ "TABLE_FORMAT": TokenType.TABLE_FORMAT,
+ "TBLPROPERTIES": TokenType.PROPERTIES,
+ "TABLESAMPLE": TokenType.TABLE_SAMPLE,
+ "TEMP": TokenType.TEMPORARY,
+ "TEMPORARY": TokenType.TEMPORARY,
+ "THEN": TokenType.THEN,
+ "TRUE": TokenType.TRUE,
+ "TRUNCATE": TokenType.TRUNCATE,
+ "TRY_CAST": TokenType.TRY_CAST,
+ "UNBOUNDED": TokenType.UNBOUNDED,
+ "UNION": TokenType.UNION,
+ "UNNEST": TokenType.UNNEST,
+ "UPDATE": TokenType.UPDATE,
+ "USE": TokenType.USE,
+ "USING": TokenType.USING,
+ "VALUES": TokenType.VALUES,
+ "VIEW": TokenType.VIEW,
+ "WHEN": TokenType.WHEN,
+ "WHERE": TokenType.WHERE,
+ "WITH": TokenType.WITH,
+ "WITH TIME ZONE": TokenType.WITH_TIME_ZONE,
+ "WITHIN GROUP": TokenType.WITHIN_GROUP,
+ "WITHOUT TIME ZONE": TokenType.WITHOUT_TIME_ZONE,
+ "ARRAY": TokenType.ARRAY,
+ "BOOL": TokenType.BOOLEAN,
+ "BOOLEAN": TokenType.BOOLEAN,
+ "BYTE": TokenType.TINYINT,
+ "TINYINT": TokenType.TINYINT,
+ "SHORT": TokenType.SMALLINT,
+ "SMALLINT": TokenType.SMALLINT,
+ "INT2": TokenType.SMALLINT,
+ "INTEGER": TokenType.INT,
+ "INT": TokenType.INT,
+ "INT4": TokenType.INT,
+ "LONG": TokenType.BIGINT,
+ "BIGINT": TokenType.BIGINT,
+ "INT8": TokenType.BIGINT,
+ "DECIMAL": TokenType.DECIMAL,
+ "MAP": TokenType.MAP,
+ "NUMBER": TokenType.DECIMAL,
+ "NUMERIC": TokenType.DECIMAL,
+ "FIXED": TokenType.DECIMAL,
+ "REAL": TokenType.FLOAT,
+ "FLOAT": TokenType.FLOAT,
+ "FLOAT4": TokenType.FLOAT,
+ "FLOAT8": TokenType.DOUBLE,
+ "DOUBLE": TokenType.DOUBLE,
+ "JSON": TokenType.JSON,
+ "CHAR": TokenType.CHAR,
+ "NCHAR": TokenType.NCHAR,
+ "VARCHAR": TokenType.VARCHAR,
+ "VARCHAR2": TokenType.VARCHAR,
+ "NVARCHAR": TokenType.NVARCHAR,
+ "NVARCHAR2": TokenType.NVARCHAR,
+ "STRING": TokenType.TEXT,
+ "TEXT": TokenType.TEXT,
+ "CLOB": TokenType.TEXT,
+ "BINARY": TokenType.BINARY,
+ "BLOB": TokenType.BINARY,
+ "BYTEA": TokenType.BINARY,
+ "TIMESTAMP": TokenType.TIMESTAMP,
+ "TIMESTAMPTZ": TokenType.TIMESTAMPTZ,
+ "DATE": TokenType.DATE,
+ "DATETIME": TokenType.DATETIME,
+ "UNIQUE": TokenType.UNIQUE,
+ "STRUCT": TokenType.STRUCT,
+ }
+
+ WHITE_SPACE = {
+ " ": TokenType.SPACE,
+ "\t": TokenType.SPACE,
+ "\n": TokenType.BREAK,
+ "\r": TokenType.BREAK,
+ "\r\n": TokenType.BREAK,
+ }
+
+ COMMANDS = {
+ TokenType.ALTER,
+ TokenType.ADD_FILE,
+ TokenType.ANALYZE,
+ TokenType.BEGIN,
+ TokenType.CALL,
+ TokenType.COMMIT,
+ TokenType.EXPLAIN,
+ TokenType.OPTIMIZE,
+ TokenType.SET,
+ TokenType.SHOW,
+ TokenType.TRUNCATE,
+ TokenType.USE,
+ }
+
+ # handle numeric literals like in hive (3L = BIGINT)
+ NUMERIC_LITERALS = {}
+ ENCODE = None
+
+ COMMENTS = ["--", ("/*", "*/")]
+ KEYWORD_TRIE = None # autofilled
+
+ __slots__ = (
+ "sql",
+ "size",
+ "tokens",
+ "_start",
+ "_current",
+ "_line",
+ "_col",
+ "_char",
+ "_end",
+ "_peek",
+ )
+
+ def __init__(self):
+ """
+        Tokenizer consumes a SQL string and produces a list of :class:`~sqlglot.tokens.Token` objects
+ """
+ self.reset()
+
+ def reset(self):
+ self.sql = ""
+ self.size = 0
+ self.tokens = []
+ self._start = 0
+ self._current = 0
+ self._line = 1
+ self._col = 1
+
+ self._char = None
+ self._end = None
+ self._peek = None
+
+ def tokenize(self, sql):
+ self.reset()
+ self.sql = sql
+ self.size = len(sql)
+
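+        # scan loop: dispatch on the current character; whitespace only updates
+        # line/column bookkeeping, everything else goes to a dedicated scanner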
+ while self.size and not self._end:
+ self._start = self._current
+ self._advance()
+
+ if not self._char:
+ break
+
+ white_space = self.WHITE_SPACE.get(self._char)
+ identifier_end = self.IDENTIFIERS.get(self._char)
+
+ if white_space:
+ if white_space == TokenType.BREAK:
+ self._col = 1
+ self._line += 1
+ elif self._char == "0" and self._peek == "x":
+ self._scan_hex()
+ elif self._char.isdigit():
+ self._scan_number()
+ elif identifier_end:
+ self._scan_identifier(identifier_end)
+ else:
+ self._scan_keywords()
+ return self.tokens
+
+ def _chars(self, size):
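+        # peek at `size` characters beginning with the current one, without advancing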
+ if size == 1:
+ return self._char
+ start = self._current - 1
+ end = start + size
+ if end <= self.size:
+ return self.sql[start:end]
+ return ""
+
+ def _advance(self, i=1):
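+        # consume i characters; _char is the last consumed character and _peek
+        # the next unconsumed one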
+ self._col += i
+ self._current += i
+ self._end = self._current >= self.size
+ self._char = self.sql[self._current - 1]
+ self._peek = self.sql[self._current] if self._current < self.size else ""
+
+ @property
+ def _text(self):
+ return self.sql[self._start : self._current]
+
+ def _add(self, token_type, text=None):
+ text = self._text if text is None else text
+ self.tokens.append(Token(token_type, text, self._line, self._col))
+
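+        # when a command keyword starts a statement, everything up to the next
+        # semicolon is captured as one raw STRING token instead of being tokenized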
+ if token_type in self.COMMANDS and (
+ len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON
+ ):
+ self._start = self._current
+ while not self._end and self._peek != ";":
+ self._advance()
+ if self._start < self._current:
+ self._add(TokenType.STRING)
+
+ def _scan_keywords(self):
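+        # greedily walk the keyword trie over the upcoming characters to find the
+        # longest matching keyword: in_trie returns 0 (dead end), 1 (valid prefix)
+        # or 2 (complete keyword so far); runs of whitespace are collapsed to a
+        # single space so multi-word keywords like "GROUP BY" still match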
+ size = 0
+ word = None
+ chars = self._text
+ char = chars
+ prev_space = False
+ skip = False
+ trie = self.KEYWORD_TRIE
+
+ while chars:
+ if skip:
+ result = 1
+ else:
+ result, trie = in_trie(trie, char.upper())
+
+ if result == 0:
+ break
+ if result == 2:
+ word = chars
+ size += 1
+ end = self._current - 1 + size
+
+ if end < self.size:
+ char = self.sql[end]
+ is_space = char in self.WHITE_SPACE
+
+ if not is_space or not prev_space:
+ if is_space:
+ char = " "
+ chars += char
+ prev_space = is_space
+ skip = False
+ else:
+ skip = True
+ else:
+ chars = None
+
+ if not word:
+ if self._char in self.SINGLE_TOKENS:
+ token = self.SINGLE_TOKENS[self._char]
+ if token == TokenType.ANNOTATION:
+ self._scan_annotation()
+ return
+ self._add(token)
+ return
+ self._scan_var()
+ return
+
+ if self._scan_string(word):
+ return
+ if self._scan_comment(word):
+ return
+
+ self._advance(size - 1)
+ self._add(self.KEYWORDS[word.upper()])
+
+ def _scan_comment(self, comment_start):
+ if comment_start not in self.COMMENTS:
+ return False
+
+ comment_end = self.COMMENTS[comment_start]
+
+ if comment_end:
+ comment_end_size = len(comment_end)
+
+ while not self._end and self._chars(comment_end_size) != comment_end:
+ self._advance()
+ self._advance(comment_end_size - 1)
+ else:
+ while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK:
+ self._advance()
+ return True
+
+ def _scan_annotation(self):
+ while (
+ not self._end
+ and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK
+ and self._peek != ","
+ ):
+ self._advance()
+ self._add(TokenType.ANNOTATION, self._text[1:])
+
+ def _scan_number(self):
+ decimal = False
+ scientific = 0
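+        # scientific tracks exponent parsing: 0 = no exponent yet, 1 = "E" seen
+        # (a sign may follow), 2 = the exponent sign has been consumed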
+
+ while True:
+ if self._peek.isdigit():
+ self._advance()
+ elif self._peek == "." and not decimal:
+ decimal = True
+ self._advance()
+ elif self._peek in ("-", "+") and scientific == 1:
+ scientific += 1
+ self._advance()
+ elif self._peek.upper() == "E" and not scientific:
+ scientific += 1
+ self._advance()
+ elif self._peek.isalpha():
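+                # the NUMBER token is emitted first; an alphabetic suffix that maps
+                # to a type via NUMERIC_LITERALS (e.g. Hive's 3L) then becomes
+                # ":: TYPE", otherwise the scanner rewinds past the suffix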
+ self._add(TokenType.NUMBER)
+ literal = []
+ while self._peek.isalpha():
+ literal.append(self._peek.upper())
+ self._advance()
+ literal = "".join(literal)
+ token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal))
+ if token_type:
+ self._add(TokenType.DCOLON, "::")
+ return self._add(token_type, literal)
+ return self._advance(-len(literal))
+ else:
+ return self._add(TokenType.NUMBER)
+
+ def _scan_hex(self):
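+        # consume a 0x... literal: valid hex is emitted as a BIT_STRING token
+        # holding its binary digits, anything else falls back to an IDENTIFIER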
+ self._advance()
+
+ while True:
+ char = self._peek.strip()
+ if char and char not in self.SINGLE_TOKENS:
+ self._advance()
+ else:
+ break
+ try:
+ self._add(TokenType.BIT_STRING, f"{int(self._text, 16):b}")
+ except ValueError:
+ self._add(TokenType.IDENTIFIER)
+
+ def _scan_string(self, quote):
+ quote_end = self.QUOTES.get(quote)
+ if quote_end is None:
+ return False
+
+ text = ""
+ self._advance(len(quote))
+ quote_end_size = len(quote_end)
+
+ while True:
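+            # an escape character directly before the closing quote keeps the
+            # quote literal instead of terminating the string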
+ if self._char == self.ESCAPE and self._peek == quote_end:
+ text += quote
+ self._advance(2)
+ else:
+ if self._chars(quote_end_size) == quote_end:
+ if quote_end_size > 1:
+ self._advance(quote_end_size - 1)
+ break
+
+ if self._end:
+ raise RuntimeError(
+ f"Missing {quote} from {self._line}:{self._start}"
+ )
+ text += self._char
+ self._advance()
+
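+        # apply the dialect's encoding, if any, then collapse doubled backslashes
+        # when backslash is the escape character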
+ text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text
+ text = text.replace("\\\\", "\\") if self.ESCAPE == "\\" else text
+ self._add(TokenType.STRING, text)
+ return True
+
+ def _scan_identifier(self, identifier_end):
+ while self._peek != identifier_end:
+ if self._end:
+ raise RuntimeError(
+ f"Missing {identifier_end} from {self._line}:{self._start}"
+ )
+ self._advance()
+ self._advance()
+ self._add(TokenType.IDENTIFIER, self._text[1:-1])
+
+ def _scan_var(self):
+ while True:
+ char = self._peek.strip()
+ if char and char not in self.SINGLE_TOKENS:
+ self._advance()
+ else:
+ break
+ self._add(self.KEYWORDS.get(self._text.upper(), TokenType.VAR))
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py
new file mode 100644
index 0000000..e7ccb8e
--- /dev/null
+++ b/sqlglot/transforms.py
@@ -0,0 +1,68 @@
+from sqlglot import expressions as exp
+
+
+def unalias_group(expression):
+ """
+    Replace references to select aliases in GROUP BY clauses with the 1-based
+    ordinal position of the aliased select expression.
+
+ Example:
+ >>> import sqlglot
+ >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
+ 'SELECT a AS b FROM x GROUP BY 1'
+ """
+ if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
+ aliased_selects = {
+ e.alias: i
+ for i, e in enumerate(expression.parent.expressions, start=1)
+ if isinstance(e, exp.Alias)
+ }
+
+ expression = expression.copy()
+
+ for col in expression.find_all(exp.Column):
+ alias_index = aliased_selects.get(col.name)
+ if not col.table and alias_index:
+ col.replace(exp.Literal.number(alias_index))
+
+ return expression
+
+
+def preprocess(transforms, to_sql):
+ """
+    Create a new transform function that can be used as a value in `Generator.TRANSFORMS`
+ to convert expressions to SQL.
+
+ Args:
+ transforms (list[(exp.Expression) -> exp.Expression]):
+ Sequence of transform functions. These will be called in order.
+ to_sql ((sqlglot.generator.Generator, exp.Expression) -> str):
+ Final transform that converts the resulting expression to a SQL string.
+ Returns:
+ (sqlglot.generator.Generator, exp.Expression) -> str:
+ Function that can be used as a generator transform.
+ """
+
+ def _to_sql(self, expression):
+ expression = transforms[0](expression)
+ for t in transforms[1:]:
+ expression = t(expression)
+ return to_sql(self, expression)
+
+ return _to_sql
+
+
+def delegate(attr):
+ """
+ Create a new method that delegates to `attr`.
+
+ This is useful for creating `Generator.TRANSFORMS` functions that delegate
+ to existing generator methods.
+ """
+
+ def _transform(self, *args, **kwargs):
+ return getattr(self, attr)(*args, **kwargs)
+
+ return _transform
+
+
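+# ready-made Generator.TRANSFORMS entry: rewrite alias references in GROUP BY to
+# ordinal positions, then delegate to the generator's standard group_sql method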
+UNALIAS_GROUP = {exp.Group: preprocess([unalias_group], delegate("group_sql"))}
diff --git a/sqlglot/trie.py b/sqlglot/trie.py
new file mode 100644
index 0000000..a234107
--- /dev/null
+++ b/sqlglot/trie.py
@@ -0,0 +1,27 @@
+def new_trie(keywords):
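+    """
+    Build a character-level trie from `keywords`; terminal nodes are marked with
+    the key 0.
+
+    Example:
+        >>> new_trie(["to", "top"])
+        {'t': {'o': {0: True, 'p': {0: True}}}}
+    """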
+ trie = {}
+
+ for key in keywords:
+ current = trie
+
+ for char in key:
+ current = current.setdefault(char, {})
+ current[0] = True
+
+ return trie
+
+
+def in_trie(trie, key):
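+    """
+    Walk `trie` along the characters of `key`.
+
+    Returns a (result, node) pair where result is 0 if the walk hit a dead end,
+    2 if `key` is a complete keyword, and 1 if `key` is only a prefix of some
+    longer keyword.
+
+    Example:
+        >>> trie = new_trie(["GROUP BY"])
+        >>> in_trie(trie, "GROUP")[0]
+        1
+        >>> in_trie(trie, "GROUP BY")[0]
+        2
+        >>> in_trie(trie, "HAVING")[0]
+        0
+    """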
+ if not key:
+ return (0, trie)
+
+ current = trie
+
+ for char in key:
+ if char not in current:
+ return (0, current)
+ current = current[char]
+
+ if 0 in current:
+ return (2, current)
+ return (1, current)