| field | value | date |
|---|---|---|
| author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-09-15 16:46:17 +0000 |
| committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2022-09-15 16:46:17 +0000 |
| commit | 28cc22419e32a65fea2d1678400265b8cabc3aff (patch) | |
| tree | ff9ac1991fd48490b21ef6aa9015a347a165e2d9 | |
| parent | Initial commit. (diff) | |
| download | sqlglot-28cc22419e32a65fea2d1678400265b8cabc3aff.tar.xz, sqlglot-28cc22419e32a65fea2d1678400265b8cabc3aff.zip | |
Adding upstream version 6.0.4. (upstream/6.0.4)
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat
122 files changed, 23162 insertions, 0 deletions
diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml new file mode 100644 index 0000000..18ce8b6 --- /dev/null +++ b/.github/workflows/python-package.yml @@ -0,0 +1,26 @@ +name: Test and Lint Python Package + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.7", "3.8", "3.9", "3.10"] + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install -r requirements.txt + - name: Run checks (linter, code style, tests) + run: | + ./run_checks.sh diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml new file mode 100644 index 0000000..061d863 --- /dev/null +++ b/.github/workflows/python-publish.yml @@ -0,0 +1,27 @@ +name: Publish Python Release to PyPI + +on: + push: + tags: + - "v*" + +jobs: + deploy: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.9' + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install setuptools wheel twine + - name: Build and publish + env: + TWINE_USERNAME: __token__ + TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} + run: | + python setup.py sdist bdist_wheel + twine upload dist/* diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..bd6ad26 --- /dev/null +++ b/.gitignore @@ -0,0 +1,132 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# PyCharm +.idea/ diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..500bc70 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "python.linting.pylintEnabled": true +}
\ No newline at end of file @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 Toby Mao + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..5ab4507 --- /dev/null +++ b/README.md @@ -0,0 +1,330 @@ +# SQLGlot + +SQLGlot is a no dependency Python SQL parser, transpiler, and optimizer. It can be used to format SQL or translate between different dialects like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), and [BigQuery](https://cloud.google.com/bigquery/). It aims to read a wide variety of SQL inputs and output syntactically correct SQL in the targeted dialects. + +It is a very comprehensive generic SQL parser with a robust [test suite](tests). It is also quite [performant](#benchmarks) while being written purely in Python. + +You can easily [customize](#custom-dialects) the parser, [analyze](#metadata) queries, traverse expression trees, and programmatically [build](#build-and-modify-sql) SQL. + +Syntax [errors](#parser-errors) are highlighted and dialect incompatibilities can warn or raise depending on configurations. + +## Install +From PyPI + +``` +pip3 install sqlglot +``` + +Or with a local checkout + +``` +pip3 install -e . +``` + +## Examples +Easily translate from one dialect to another. For example, date/time functions vary from dialects and can be hard to deal with. + +```python +import sqlglot +sqlglot.transpile("SELECT EPOCH_MS(1618088028295)", read='duckdb', write='hive') +``` + +```sql +SELECT TO_UTC_TIMESTAMP(FROM_UNIXTIME(1618088028295 / 1000, 'yyyy-MM-dd HH:mm:ss'), 'UTC') +``` + +SQLGlot can even translate custom time formats. +```python +import sqlglot +sqlglot.transpile("SELECT STRFTIME(x, '%y-%-m-%S')", read='duckdb', write='hive') +``` + +```sql +SELECT DATE_FORMAT(x, 'yy-M-ss')" +``` + +## Formatting and Transpiling +Read in a SQL statement with a CTE and CASTING to a REAL and then transpiling to Spark. + +Spark uses backticks as identifiers and the REAL type is transpiled to FLOAT. 
+ +```python +import sqlglot + +sql = """WITH baz AS (SELECT a, c FROM foo WHERE a = 1) SELECT f.a, b.b, baz.c, CAST("b"."a" AS REAL) d FROM foo f JOIN bar b ON f.a = b.a LEFT JOIN baz ON f.a = baz.a""" +sqlglot.transpile(sql, write='spark', identify=True, pretty=True)[0] +``` + +```sql +WITH `baz` AS ( + SELECT + `a`, + `c` + FROM `foo` + WHERE + `a` = 1 +) +SELECT + `f`.`a`, + `b`.`b`, + `baz`.`c`, + CAST(`b`.`a` AS FLOAT) AS `d` +FROM `foo` AS `f` +JOIN `bar` AS `b` + ON `f`.`a` = `b`.`a` +LEFT JOIN `baz` + ON `f`.`a` = `baz`.`a` +``` + +## Metadata + +You can explore SQL with expression helpers to do things like find columns and tables. + +```python +from sqlglot import parse_one, exp + +# print all column references (a and b) +for column in parse_one("SELECT a, b + 1 AS c FROM d").find_all(exp.Column): + print(column.alias_or_name) + +# find all projections in select statements (a and c) +for select in parse_one("SELECT a, b + 1 AS c FROM d").find_all(exp.Select): + for projection in select.expressions: + print(projection.alias_or_name) + +# find all tables (x, y, z) +for table in parse_one("SELECT * FROM x JOIN y JOIN z").find_all(exp.Table): + print(table.name) +``` + +## Parser Errors +A syntax error will result in a parser error. +```python +transpile("SELECT foo( FROM bar") +``` + +sqlglot.errors.ParseError: Expecting ). Line 1, Col: 13. + select foo( __FROM__ bar + +## Unsupported Errors +Presto APPROX_DISTINCT supports the accuracy argument which is not supported in Spark. + +```python +transpile( + 'SELECT APPROX_DISTINCT(a, 0.1) FROM foo', + read='presto', + write='spark', +) +``` + +```sql +WARNING:root:APPROX_COUNT_DISTINCT does not support accuracy + +SELECT APPROX_COUNT_DISTINCT(a) FROM foo +``` + +## Build and Modify SQL +SQLGlot supports incrementally building sql expressions. + +```python +from sqlglot import select, condition + +where = condition("x=1").and_("y=1") +select("*").from_("y").where(where).sql() +``` +Which outputs: +```sql +SELECT * FROM y WHERE x = 1 AND y = 1 +``` + +You can also modify a parsed tree: + +```python +from sqlglot import parse_one + +parse_one("SELECT x FROM y").from_("z").sql() +``` +Which outputs: +```sql +SELECT x FROM y, z +``` + +There is also a way to recursively transform the parsed tree by applying a mapping function to each tree node: + +```python +from sqlglot import exp, parse_one + +expression_tree = parse_one("SELECT a FROM x") + +def transformer(node): + if isinstance(node, exp.Column) and node.name == "a": + return parse_one("FUN(a)") + return node + +transformed_tree = expression_tree.transform(transformer) +transformed_tree.sql() +``` +Which outputs: +```sql +SELECT FUN(a) FROM x +``` + +## SQL Optimizer + +SQLGlot can rewrite queries into an "optimized" form. It performs a variety of [techniques](sqlglot/optimizer/optimizer.py) to create a new canonical AST. This AST can be used to standardize queries or provide the foundations for implementing an actual engine. 
+ +```python +import sqlglot +from sqlglot.optimizer import optimize + +>>> +optimize( + sqlglot.parse_one(""" + SELECT A OR (B OR (C AND D)) + FROM x + WHERE Z = date '2021-01-01' + INTERVAL '1' month OR 1 = 0 + """), + schema={"x": {"A": "INT", "B": "INT", "C": "INT", "D": "INT", "Z": "STRING"}} +).sql(pretty=True) + +""" +SELECT + ( + "x"."A" + OR "x"."B" + OR "x"."C" + ) + AND ( + "x"."A" + OR "x"."B" + OR "x"."D" + ) AS "_col_0" +FROM "x" AS "x" +WHERE + "x"."Z" = CAST('2021-02-01' AS DATE) +""" +``` + +## SQL Annotations + +SQLGlot supports annotations in the sql expression. This is an experimental feature that is not part of any of the SQL standards but it can be useful when needing to annotate what a selected field is supposed to be. Below is an example: + +```sql +SELECT + user #primary_key, + country +FROM users +``` + +SQL annotations are currently incompatible with MySQL, which uses the `#` character to introduce comments. + +## AST Introspection + +You can see the AST version of the sql by calling repr. + +```python +from sqlglot import parse_one +repr(parse_one("SELECT a + 1 AS z")) + +(SELECT expressions: + (ALIAS this: + (ADD this: + (COLUMN this: + (IDENTIFIER this: a, quoted: False)), expression: + (LITERAL this: 1, is_string: False)), alias: + (IDENTIFIER this: z, quoted: False))) +``` + +## AST Diff + +SQLGlot can calculate the difference between two expressions and output changes in a form of a sequence of actions needed to transform a source expression into a target one. + +```python +from sqlglot import diff, parse_one +diff(parse_one("SELECT a + b, c, d"), parse_one("SELECT c, a - b, d")) + +[ + Remove(expression=(ADD this: + (COLUMN this: + (IDENTIFIER this: a, quoted: False)), expression: + (COLUMN this: + (IDENTIFIER this: b, quoted: False)))), + Insert(expression=(SUB this: + (COLUMN this: + (IDENTIFIER this: a, quoted: False)), expression: + (COLUMN this: + (IDENTIFIER this: b, quoted: False)))), + Move(expression=(COLUMN this: + (IDENTIFIER this: c, quoted: False))), + Keep(source=(IDENTIFIER this: b, quoted: False), target=(IDENTIFIER this: b, quoted: False)), + ... +] +``` + +## Custom Dialects + +[Dialects](sqlglot/dialects) can be added by subclassing Dialect. + +```python +from sqlglot import exp +from sqlglot.dialects.dialect import Dialect +from sqlglot.generator import Generator +from sqlglot.tokens import Tokenizer, TokenType + + +class Custom(Dialect): + class Tokenizer(Tokenizer): + QUOTES = ["'", '"'] + IDENTIFIERS = ["`"] + + KEYWORDS = { + **Tokenizer.KEYWORDS, + "INT64": TokenType.BIGINT, + "FLOAT64": TokenType.DOUBLE, + } + + class Generator(Generator): + TRANSFORMS = {exp.Array: lambda self, e: f"[{self.expressions(e)}]"} + + TYPE_MAPPING = { + exp.DataType.Type.TINYINT: "INT64", + exp.DataType.Type.SMALLINT: "INT64", + exp.DataType.Type.INT: "INT64", + exp.DataType.Type.BIGINT: "INT64", + exp.DataType.Type.DECIMAL: "NUMERIC", + exp.DataType.Type.FLOAT: "FLOAT64", + exp.DataType.Type.DOUBLE: "FLOAT64", + exp.DataType.Type.BOOLEAN: "BOOL", + exp.DataType.Type.TEXT: "STRING", + } + + +Dialects["custom"] +``` + +## Benchmarks + +[Benchmarks](benchmarks) run on Python 3.10.5 in seconds. 
+ +| Query | sqlglot | sqltree | sqlparse | moz_sql_parser | sqloxide | +| --------------- | --------------- | --------------- | --------------- | --------------- | --------------- | +| tpch | 0.01178 (1.0) | 0.01173 (0.995) | 0.04676 (3.966) | 0.06800 (5.768) | 0.00094 (0.080) | +| short | 0.00084 (1.0) | 0.00079 (0.948) | 0.00296 (3.524) | 0.00443 (5.266) | 0.00006 (0.072) | +| long | 0.01102 (1.0) | 0.01044 (0.947) | 0.04349 (3.945) | 0.05998 (5.440) | 0.00084 (0.077) | +| crazy | 0.03751 (1.0) | 0.03471 (0.925) | 11.0796 (295.3) | 1.03355 (27.55) | 0.00529 (0.141) | + + +## Run Tests and Lint +``` +pip install -r requirements.txt +./run_checks.sh +``` + +## Optional Dependencies +SQLGlot uses [dateutil](https://github.com/dateutil/dateutil) to simplify literal timedelta expressions. The optimizer will not simplify expressions like + +```sql +x + interval '1' month +``` + +if the module cannot be found. diff --git a/benchmarks/bench.py b/benchmarks/bench.py new file mode 100644 index 0000000..cef62a8 --- /dev/null +++ b/benchmarks/bench.py @@ -0,0 +1,225 @@ +import collections.abc + +# moz_sql_parser 3.10 compatibility +collections.Iterable = collections.abc.Iterable +import gc +import timeit + +import moz_sql_parser +import numpy as np +import sqloxide +import sqlparse +import sqltree + +import sqlglot + +long = """ +SELECT + "e"."employee_id" AS "Employee #", + "e"."first_name" || ' ' || "e"."last_name" AS "Name", + "e"."email" AS "Email", + "e"."phone_number" AS "Phone", + TO_CHAR("e"."hire_date", 'MM/DD/YYYY') AS "Hire Date", + TO_CHAR("e"."salary", 'L99G999D99', 'NLS_NUMERIC_CHARACTERS = ''.,'' NLS_CURRENCY = ''$''') AS "Salary", + "e"."commission_pct" AS "Comission %", + 'works as ' || "j"."job_title" || ' in ' || "d"."department_name" || ' department (manager: ' || "dm"."first_name" || ' ' || "dm"."last_name" || ') and immediate supervisor: ' || "m"."first_name" || ' ' || "m"."last_name" AS "Current Job", + TO_CHAR("j"."min_salary", 'L99G999D99', 'NLS_NUMERIC_CHARACTERS = ''.,'' NLS_CURRENCY = ''$''') || ' - ' || TO_CHAR("j"."max_salary", 'L99G999D99', 'NLS_NUMERIC_CHARACTERS = ''.,'' NLS_CURRENCY = ''$''') AS "Current Salary", + "l"."street_address" || ', ' || "l"."postal_code" || ', ' || "l"."city" || ', ' || "l"."state_province" || ', ' || "c"."country_name" || ' (' || "r"."region_name" || ')' AS "Location", + "jh"."job_id" AS "History Job ID", + 'worked from ' || TO_CHAR("jh"."start_date", 'MM/DD/YYYY') || ' to ' || TO_CHAR("jh"."end_date", 'MM/DD/YYYY') || ' as ' || "jj"."job_title" || ' in ' || "dd"."department_name" || ' department' AS "History Job Title", + case when 1 then 1 when 2 then 2 when 3 then 3 when 4 then 4 when 5 then 5 else a(b(c + 1 * 3 % 4)) end +FROM "employees" AS e +JOIN "jobs" AS j + ON "e"."job_id" = "j"."job_id" +LEFT JOIN "employees" AS m + ON "e"."manager_id" = "m"."employee_id" +LEFT JOIN "departments" AS d + ON "d"."department_id" = "e"."department_id" +LEFT JOIN "employees" AS dm + ON "d"."manager_id" = "dm"."employee_id" +LEFT JOIN "locations" AS l + ON "d"."location_id" = "l"."location_id" +LEFT JOIN "countries" AS c + ON "l"."country_id" = "c"."country_id" +LEFT JOIN "regions" AS r + ON "c"."region_id" = "r"."region_id" +LEFT JOIN "job_history" AS jh + ON "e"."employee_id" = "jh"."employee_id" +LEFT JOIN "jobs" AS jj + ON "jj"."job_id" = "jh"."job_id" +LEFT JOIN "departments" AS dd + ON "dd"."department_id" = "jh"."department_id" +ORDER BY + "e"."employee_id" +""" + +short = "select 1 as a, case when 1 then 1 when 2 then 2 else 3 end as b, c 
from x" + +crazy = "SELECT 1+" +crazy += "+".join(str(i) for i in range(500)) +crazy += " AS a, 2*" +crazy += "*".join(str(i) for i in range(500)) +crazy += " AS b FROM x" + +tpch = """ +WITH "_e_0" AS ( + SELECT + "partsupp"."ps_partkey" AS "ps_partkey", + "partsupp"."ps_suppkey" AS "ps_suppkey", + "partsupp"."ps_supplycost" AS "ps_supplycost" + FROM "partsupp" AS "partsupp" +), "_e_1" AS ( + SELECT + "region"."r_regionkey" AS "r_regionkey", + "region"."r_name" AS "r_name" + FROM "region" AS "region" + WHERE + "region"."r_name" = 'EUROPE' +) +SELECT + "supplier"."s_acctbal" AS "s_acctbal", + "supplier"."s_name" AS "s_name", + "nation"."n_name" AS "n_name", + "part"."p_partkey" AS "p_partkey", + "part"."p_mfgr" AS "p_mfgr", + "supplier"."s_address" AS "s_address", + "supplier"."s_phone" AS "s_phone", + "supplier"."s_comment" AS "s_comment" +FROM ( + SELECT + "part"."p_partkey" AS "p_partkey", + "part"."p_mfgr" AS "p_mfgr", + "part"."p_type" AS "p_type", + "part"."p_size" AS "p_size" + FROM "part" AS "part" + WHERE + "part"."p_size" = 15 + AND "part"."p_type" LIKE '%BRASS' +) AS "part" +LEFT JOIN ( + SELECT + MIN("partsupp"."ps_supplycost") AS "_col_0", + "partsupp"."ps_partkey" AS "_u_1" + FROM "_e_0" AS "partsupp" + CROSS JOIN "_e_1" AS "region" + JOIN ( + SELECT + "nation"."n_nationkey" AS "n_nationkey", + "nation"."n_regionkey" AS "n_regionkey" + FROM "nation" AS "nation" + ) AS "nation" + ON "nation"."n_regionkey" = "region"."r_regionkey" + JOIN ( + SELECT + "supplier"."s_suppkey" AS "s_suppkey", + "supplier"."s_nationkey" AS "s_nationkey" + FROM "supplier" AS "supplier" + ) AS "supplier" + ON "supplier"."s_nationkey" = "nation"."n_nationkey" + AND "supplier"."s_suppkey" = "partsupp"."ps_suppkey" + GROUP BY + "partsupp"."ps_partkey" +) AS "_u_0" + ON "part"."p_partkey" = "_u_0"."_u_1" +CROSS JOIN "_e_1" AS "region" +JOIN ( + SELECT + "nation"."n_nationkey" AS "n_nationkey", + "nation"."n_name" AS "n_name", + "nation"."n_regionkey" AS "n_regionkey" + FROM "nation" AS "nation" +) AS "nation" + ON "nation"."n_regionkey" = "region"."r_regionkey" +JOIN "_e_0" AS "partsupp" + ON "part"."p_partkey" = "partsupp"."ps_partkey" +JOIN ( + SELECT + "supplier"."s_suppkey" AS "s_suppkey", + "supplier"."s_name" AS "s_name", + "supplier"."s_address" AS "s_address", + "supplier"."s_nationkey" AS "s_nationkey", + "supplier"."s_phone" AS "s_phone", + "supplier"."s_acctbal" AS "s_acctbal", + "supplier"."s_comment" AS "s_comment" + FROM "supplier" AS "supplier" +) AS "supplier" + ON "supplier"."s_nationkey" = "nation"."n_nationkey" + AND "supplier"."s_suppkey" = "partsupp"."ps_suppkey" +WHERE + "partsupp"."ps_supplycost" = "_u_0"."_col_0" + AND NOT "_u_0"."_u_1" IS NULL +ORDER BY + "supplier"."s_acctbal" DESC, + "nation"."n_name", + "supplier"."s_name", + "part"."p_partkey" +LIMIT 100 +""" + + +def sqlglot_parse(sql): + sqlglot.parse(sql, error_level=sqlglot.ErrorLevel.IGNORE) + + +def sqltree_parse(sql): + sqltree.api.sqltree(sql.replace('"', '`').replace("''", '"')) + + +def sqlparse_parse(sql): + sqlparse.parse(sql) + + +def moz_sql_parser_parse(sql): + moz_sql_parser.parse(sql) + + +def sqloxide_parse(sql): + sqloxide.parse_sql(sql, dialect="ansi") + + +def border(columns): + columns = " | ".join(columns) + return f"| {columns} |" + + +def diff(row, column): + if column == "Query": + return "" + column = row[column] + if isinstance(column, str): + return " (N/A)" + return f" ({str(column / row['sqlglot'])[0:5]})" + + +libs = [ + "sqlglot", + "sqltree", + "sqlparse", + "moz_sql_parser", + "sqloxide", +] 
+table = [] + +for name, sql in {"tpch": tpch, "short": short, "long": long, "crazy": crazy}.items(): + row = {"Query": name} + table.append(row) + for lib in libs: + try: + row[lib] = np.mean(timeit.repeat(lambda: globals()[lib + "_parse"](sql), number=3)) + except: + row[lib] = "error" + +columns = ["Query"] + libs +widths = {column: max(len(column), 15) for column in columns} + +lines = [border(column.rjust(width) for column, width in widths.items())] +lines.append(border(str("-" * width) for width in widths.values())) + +for i, row in enumerate(table): + lines.append(border( + (str(row[column])[0:7] + diff(row, column)).rjust(width)[0 : width] + for column, width in widths.items() + )) + +for line in lines: + print(line) diff --git a/posts/sql_diff.md b/posts/sql_diff.md new file mode 100644 index 0000000..4d07d7f --- /dev/null +++ b/posts/sql_diff.md @@ -0,0 +1,389 @@ +# Semantic Diff for SQL +*by [Iaroslav Zeigerman](https://github.com/izeigerman)* + +## Motivation + +Software is constantly changing and evolving, and identifying what has changed and reviewing those changes is an integral part of the development process. SQL code is no exception to this. + +Text-based diff tools such as `git diff`, when applied to a code base, have certain limitations. First, they can only detect insertions and deletions, not movements or updates of individual pieces of code. Second, such tools can only detect changes between lines of text, which is too coarse for something as granular and detailed as source code. Additionally, the outcome of such a diff is dependent on the underlying code formatting, and yields different results if the formatting should change. + +Consider the following diff generated by Git: + +![Git diff output](sql_diff_images/git_diff_output.png) + +Semantically the query hasn’t changed. The two arguments `b` and `c` have been swapped (moved), posing no impact on the output of the query. Yet Git replaced the whole affected expression alongside a bulk of unrelated elements. + +The alternative to text-based diffing is to compare Abstract Syntax Trees (AST) instead. The main advantage of ASTs are that they are a direct product of code parsing, which represents the underlying code structure at any desired level of granularity. Comparing ASTs may yield extremely precise diffs; changes such as code movements and updates can also be detected. Even more importantly, this approach facilitates additional use cases beyond eyeballing two versions of source code side by side. + +The use cases I had in mind for SQL when I decided to embark on this journey of semantic diffing were the following: + +* **Query similarity score.** Identifying which parts the two queries have in common to automatically suggest opportunities for consolidation, creation of intermediate/staging tables, and so on. +* **Differentiating between cosmetic / structural changes and functional ones.** For example when a nested query is refactored into a common table expression (CTE), this kind of change doesn’t have any functional impact on either a query or its outcome. +* **Automatic suggestions about the need to retroactively backfill data.** This is especially important for pipelines that populate very large tables for which restatement is a runtime-intensive procedure. The ability to discern between simple code movements and actual modifications can help assess the impact of a change and make suggestions accordingly. 
+ +The implementation discussed in this post is now a part of the [SQLGlot](https://github.com/tobymao/sqlglot/) library. You can find a complete source code in the [diff.py](https://github.com/tobymao/sqlglot/blob/main/sqlglot/diff.py) module. The choice of SQLglot was an obvious one due to its simple but powerful API, lack of external dependencies and, more importantly, extensive list of supported SQL dialects. + +## The Search for a Solution + +When it comes to any diffing tool (not just a semantic one), the primary challenge is to match as many elements of compared entities as possible. Once such a set of matching elements is available, deriving a sequence of changes becomes an easy task. + +If our elements have unique identifiers associated with them (for example, an element’s ID in DOM), the matching problem is trivial. However, the SQL syntax trees that we are comparing have neither unique keys nor object identifiers that can be used for the purposes of matching. So, how do we suppose to find pairs of nodes that are related? + +To better illustrate the problem, consider comparing the following SQL expressions: `SELECT a + b + c, d, e` and `SELECT a - b + c, e, f`. Matching individual nodes from respective syntax trees can be visualized as follows: + +![Figure 1: Example of node matching for two SQL expression trees](sql_diff_images/figure_1.png) +*Figure 1: Example of node matching for two SQL expression trees.* + +By looking at the figure of node matching for two SQL expression trees above, we conclude that the following changes should be captured by our solution: + +* Inserted nodes: `Sub` and `f`. These are the nodes from the target AST which do not have a matching node in the source AST. +* Removed nodes: `Add` and `d`. These are the nodes from the source AST which do not have a counterpart in the target AST. +* Remaining nodes must be identified as unchanged. + +It should be clear at this point that if we manage to match nodes in the source tree with their counterparts in the target tree, then computing the diff becomes a trivial matter. + +### Naïve Brute-Force + +The naïve solution would be to try all different permutations of node pair combinations, and see which set of pairs performs the best based on some type of heuristics. The runtime cost of such a solution quickly reaches the escape velocity; if both trees had only 10 nodes each, the number of such sets would approximately be 10! ^ 2 = 3.6M ^ 2 ~= 13 * 10^12. This is a very bad case of factorial complexity (to be precise, it’s actually much worse - O(n! ^ 2) - but I couldn’t come up with a name for it), so there is little need to explore this approach any further. + +### Myers Algorithm + +After the naïve approach was proven to be infeasible, the next question I asked myself was “how does git diff work?”. This question led me to discover the Myers diff algorithm [1]. This algorithm has been designed to compare sequences of strings. At its core, it’s looking for the shortest path on a graph of possible edits that transform the first sequence into the second one, while heavily rewarding those paths that lead to longest subsequences of unchanged elements. There’s a lot of material out there describing this algorithm in greater detail. I found James Coglan’s series of [blog posts](https://blog.jcoglan.com/2017/02/12/the-myers-diff-algorithm-part-1/) to be the most comprehensive. 
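For readers who want a feel for the algorithm without reading the whole series, below is a compressed sketch of its greedy forward pass. It only computes the length of the shortest edit script (the "D" in the O(ND) name) rather than the script itself, the function name and parameters are mine, and it is purely illustrative; it is not code used by SQLGlot.

```python
def myers_shortest_edit_length(a, b):
    """Length of the shortest edit script (insertions plus deletions)
    that turns sequence `a` into sequence `b`."""
    n, m = len(a), len(b)
    max_d = n + m
    # v[k] stores the furthest x reached so far on diagonal k (k = x - y).
    v = {1: 0}
    for d in range(max_d + 1):
        for k in range(-d, d + 1, 2):
            # Decide whether this diagonal is reached by an insertion
            # (move down from diagonal k + 1) or a deletion (move right
            # from diagonal k - 1).
            if k == -d or (k != d and v[k - 1] < v[k + 1]):
                x = v[k + 1]
            else:
                x = v[k - 1] + 1
            y = x - k
            # Follow the "snake": free diagonal moves over equal elements.
            while x < n and y < m and a[x] == b[y]:
                x, y = x + 1, y + 1
            v[k] = x
            if x >= n and y >= m:
                return d
    return max_d


myers_shortest_edit_length("ABCABBA", "CBABAC")  # -> 5, the worked example from the paper
```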
+ +Therefore, I had this “brilliant” (actually not) idea to transform trees into sequences by traversing them in topological order, and then applying the Myers algorithm on resulting sequences while using a custom heuristics when checking the equality of two nodes. Unsurprisingly, comparing sequences of strings is quite different from comparing hierarchical tree structures, and by flattening trees into sequences, we lose a lot of relevant context. This resulted in a terrible performance of this algorithm on ASTs. It often matched completely unrelated nodes, even when the two trees were mostly the same, and produced extremely inaccurate lists of changes overall. After playing around with it a little and tweaking my equality heuristics to improve accuracy, I ultimately scrapped the whole implementation and went back to the drawing board. + +## Change Distiller + +The algorithm I settled on at the end was Change Distiller, created by Fluri et al. [2], which in turn is an improvement over the core idea described by Chawathe et al. [3]. + +The algorithm consists of two high-level steps: + +1. **Finding appropriate matchings between pairs of nodes that are part of compared ASTs.** Identifying what is meant by “appropriate” matching is also a part of this step. +2. **Generating the so-called “edit script” from the matching set built in the 1st step.** The edit script is a sequence of edit operations (for example, insert, remove, update, etc.) on individual tree nodes, such that when applied as transformations on the source AST, it eventually becomes the target AST. In general, the shorter the sequence, the better. The length of the edit script can be used to compare the performance of different algorithms, though this is not the only metric that matters. + +The rest of this section is dedicated to the Python implementation of the steps above using the AST implementation provided by the SQLGlot library. + +### Building the Matching Set +#### Matching Leaves + +We begin composing the matching set by matching the leaf nodes. Leaf nodes are the nodes that do not have any children nodes (such as literals, identifiers, etc.). In order to match them, we gather all the leaf nodes from the source tree and generate a cartesian product with all the leaves from the target tree, while comparing pairs created this way and assigning them a similarity score. During this stage, we also exclude pairs that don’t pass basic matching criteria. Then, we pick pairs that scored the highest while making sure that each node is matched no more than once. + +Using the example provided at the beginning of the post, the process of building an initial set of candidate matchings can be seen on Figure 2. + +![Figure 2: Building a set of candidate matchings between leaf nodes. The third item in each triplet represents a similarity score between two nodes.](sql_diff_images/figure_2.gif) +*Figure 2: Building a set of candidate matchings between leaf nodes. The third item in each triplet represents a similarity score between two nodes.* + +First, let’s analyze the similarity score. Then, we’ll discuss matching criteria. + +The similarity score proposed by Fluri et al. [2] is a [dice coefficient ](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient)applied to [bigrams](https://en.wikipedia.org/wiki/Bigram) of respective node values. 
A bigram is a sequence of two adjacent elements from a string computed in a sliding window fashion: + +```python +def bigram(string): + count = max(0, len(string) - 1) + return [string[i : i + 2] for i in range(count)] +``` + +For reasons that will become clear shortly, we actually need to compute bigram histograms rather than just sequences: + +```python +from collections import defaultdict + +def bigram_histo(string): + count = max(0, len(string) - 1) + bigram_histo = defaultdict(int) + for i in range(count): + bigram_histo[string[i : i + 2]] += 1 + return bigram_histo +``` + +The dice coefficient formula looks like following: + +![Dice Coefficient](sql_diff_images/dice_coef.png) + +Where X is a bigram of the source node and Y is a bigram of the second one. What this essentially does is count the number of bigram elements the two nodes have in common, multiply it by 2, and then divide by the total number of elements in both bigrams. This is where bigram histograms come in handy: + +```python +def dice_coefficient(source, target): + source_histo = bigram_histo(source.sql()) + target_histo = bigram_histo(target.sql()) + + total_grams = ( + sum(source_histo.values()) + sum(target_histo.values()) + ) + if not total_grams: + return 1.0 if source == target else 0.0 + + overlap_len = 0 + overlapping_grams = set(source_histo) & set(target_histo) + for g in overlapping_grams: + overlap_len += min(source_histo[g], target_histo[g]) + + return 2 * overlap_len / total_grams +``` + +To compute a bigram given a tree node, we first transform the node into its canonical SQL representation,so that the `Literal(123)` node becomes just “123” and the `Identifier(“a”)` node becomes just “a”. We also handle a scenario when strings are too short to derive bigrams. In this case, we fallback to checking the two nodes for equality. + +Now when we know how to compute the similarity score, we can take care of the matching criteria for leaf nodes. In the original paper [2], the matching criteria is formalized as follows: + +![Matching criteria for leaf nodes](sql_diff_images/matching_criteria_1.png) + +The two nodes are matched if two conditions are met: + +1. The node labels match (in our case labels are just node types). +2. The similarity score for node values is greater than or equal to some threshold “f”. The authors of the paper recommend setting the value of “f” to 0.6. + +With building blocks in place, we can now build a matching set for leaf nodes. First, we generate a list of candidates for matching: + +```python +from heapq import heappush, heappop + +candidate_matchings = [] +source_leaves = _get_leaves(self._source) +target_leaves = _get_leaves(self._target) +for source_leaf in source_leaves: + for target_leaf in target_leaves: + if _is_same_type(source_leaf, target_leaf): + similarity_score = dice_coefficient( + source_leaf, target_leaf + ) + if similarity_score >= 0.6: + heappush( + candidate_matchings, + ( + -similarity_score, + len(candidate_matchings), + source_leaf, + target_leaf, + ), + ) +``` + +In the implementation above, we push each matching pair onto the heap to automatically maintain the correct order based on the assigned similarity score. 
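One detail that is easy to miss: `heapq` maintains a min-heap, which is why the similarity score is negated before being pushed, and the current length of the list is included as a tie-breaker so that two pairs with equal scores never force a comparison of the expression nodes themselves. A tiny, self-contained illustration of that ordering (the names are mine, nothing here is SQLGlot-specific):

```python
from heapq import heappush, heappop

heap = []
for i, score in enumerate([0.7, 0.95, 0.7]):
    # Negate the score (min-heap turned max-heap) and use the insertion
    # index as a tie-breaker so the payload itself is never compared.
    heappush(heap, (-score, i, f"pair_{i}"))

while heap:
    neg_score, _, payload = heappop(heap)
    print(-neg_score, payload)

# 0.95 pair_1
# 0.7 pair_0
# 0.7 pair_2
```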
+ +Finally, we build the initial matching set by picking leaf pairs with the highest score: + +```python +matching_set = set() +while candidate_matchings: + _, _, source_leaf, target_leaf = heappop(candidate_matchings) + if ( + source_leaf in unmatched_source_nodes + and target_leaf in unmatched_target_nodes + ): + matching_set.add((source_leaf, target_leaf)) + unmatched_source_nodes.remove(source_leaf) + unmatched_target_nodes.remove(target_leaf) +``` + +To finalize the matching set, we should now proceed with matching inner nodes. + +#### Matching Inner Nodes + +Matching inner nodes is quite similar to matching leaf nodes, with the following two distinctions: + +* Rather than ranking a set of possible candidates, we pick the first node pair that passes the matching criteria. +* The matching criteria itself has been extended to account for the number of leaf nodes the pair of inner nodes have in common. + +![Figure 3: Matching inner nodes based on their type as well as how many of their leaf nodes have been previously matched.](sql_diff_images/figure_3.gif) +*Figure 3: Matching inner nodes based on their type as well as how many of their leaf nodes have been previously matched.* + +Let’s start with the matching criteria. The criteria is formalized as follows: + +![Matching criteria for inner nodes](sql_diff_images/matching_criteria_2.png) + +Alongside already familiar similarity score and node type criteria, there is a new one in the middle: the ratio of leaf nodes that the two nodes have in common must exceed some threshold “t”. The recommended value for “t” is also 0.6. Counting the number of common leaf nodes is pretty straightforward, since we already have the complete matching set for leaves. All we need to do is count how many matching pairs do leaf nodes from the two compared inner nodes form. + +There are two additional heuristics associated with this matching criteria: + +* Inner node similarity weighting: if the similarity score between the node values doesn’t pass the threshold “f” but the ratio of common leaf nodes (“t”) is greater than or equal to 0.8, then the matching is considered successful. +* The threshold “t” is reduced to 0.4 for inner nodes with the number of leaf nodes equal to 4 or less, in order to decrease the false negative rate for small subtrees. + +We now only have to iterate through the remaining unmatched nodes and form matching pairs based on the outlined criteria: + +```python +leaves_matching_set = matching_set.copy() + +for source_node in unmatched_source_nodes.copy(): + for target_node in unmatched_target_nodes: + if _is_same_type(source_node, target_node): + source_leaves = set(_get_leaves(source_node)) + target_leaves = set(_get_leaves(target_node)) + + max_leaves_num = max(len(source_leaves), len(target_leaves)) + if max_leaves_num: + common_leaves_num = sum( + 1 if s in source_leaves and t in target_leaves else 0 + for s, t in leaves_matching_set + ) + leaf_similarity_score = common_leaves_num / max_leaves_num + else: + leaf_similarity_score = 0.0 + + adjusted_t = ( + 0.6 + if min(len(source_leaves), len(target_leaves)) > 4 + else 0.4 + ) + + if leaf_similarity_score >= 0.8 or ( + leaf_similarity_score >= adjusted_t + and dice_coefficient(source_node, target_node) >= 0.6 + ): + matching_set.add((source_node, target_node)) + unmatched_source_nodes.remove(source_node) + unmatched_target_nodes.remove(target_node) + break +``` + +After the matching set is formed, we can proceed with generation of the edit script, which will be the algorithm’s output. 
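A quick aside before moving on: the snippets above relied on two helpers, `_get_leaves` and `_is_same_type`, whose definitions I omitted. Against the SQLGlot AST they could look roughly like this; treat it as a simplified sketch rather than the library's actual implementation.

```python
from sqlglot import exp


def _get_leaves(expression):
    # A leaf is an expression node with no child expressions of its own.
    has_child_expressions = False
    for value in expression.args.values():
        values = value if isinstance(value, list) else [value]
        for v in values:
            if isinstance(v, exp.Expression):
                has_child_expressions = True
                yield from _get_leaves(v)
    if not has_child_expressions:
        yield expression


def _is_same_type(source, target):
    # Node "labels" in this context are simply the expression classes.
    return type(source) is type(target)
```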
+ +### Generating the Edit Script + +At this point, we should have the following 3 sets at our disposal: + +* The set of matched node pairs. +* The set of remaining unmatched nodes from the source tree. +* The set of remaining unmatched nodes from the target tree. + +We can derive 3 kinds of edits from the matching set: either the node’s value was updated (**Update**), the node was moved to a different position within the tree (**Move**), or the node remained unchanged (**Keep**). Note that the **Move** case is not mutually exclusive with the other two. The node could have been updated or could have remained the same while at the same time its position within its parent node or the parent node itself could have changed. All unmatched nodes from the source tree are the ones that were removed (**Remove**), while unmatched nodes from the target tree are the ones that were inserted (**Insert**). + +The latter two cases are pretty straightforward to implement: + +```python +edit_script = [] + +for removed_node in unmatched_source_nodes: + edit_script.append(Remove(removed_node)) +for inserted_node in unmatched_target_nodes: + edit_script.append(Insert(inserted_node)) +``` + +Traversing the matching set requires a little more thought: + +```python +for source_node, target_node in matching_set: + if ( + not isinstance(source_node, LEAF_EXPRESSION_TYPES) + or source_node == target_node + ): + move_edits = generate_move_edits( + source_node, target_node, matching_set + ) + edit_script.extend(move_edits) + edit_script.append(Keep(source_node, target_node)) + else: + edit_script.append(Update(source_node, target_node)) +``` + +If a matching pair represents a pair of leaf nodes, we check if they are the same to decide whether an update took place. For inner node pairs, we also need to compare the positions of their respective children to detect node movements. Chawathe et al. [3] suggest applying the [longest common subsequence ](https://en.wikipedia.org/wiki/Longest_common_subsequence_problem)(LCS) algorithm which, no surprise here, was described by Myers himself [1]. There is a small catch, however: instead of checking the equality of two children nodes, we need to check whether the two nodes form a pair that is a part of our matching set. + +Now with this knowledge, the implementation becomes straightforward: + +```python +def generate_move_edits(source, target, matching_set): + source_children = _get_child_nodes(source) + target_children = _get_child_nodes(target) + + lcs = set( + _longest_common_subsequence( + source_children, + target_children, + lambda l, r: (l, r) in matching_set + ) + ) + + move_edits = [] + for node in source_children: + if node not in lcs and node not in unmatched_source_nodes: + move_edits.append(Move(node)) + + return move_edits +``` + +I left out the implementation of the LCS algorithm itself here, but there are plenty of implementation choices out there that can be easily looked up. 
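For the curious, one possible shape for that helper is a textbook dynamic-programming LCS parameterized by the pairwise predicate used above. The version below is a sketch written for this post rather than the module's actual code; it assumes the helper should return the source-side nodes, which is what the `node not in lcs` check in `generate_move_edits` expects.

```python
def _longest_common_subsequence(seq_a, seq_b, equal):
    len_a, len_b = len(seq_a), len(seq_b)

    # dp[i][j] holds the LCS length of seq_a[:i] and seq_b[:j].
    dp = [[0] * (len_b + 1) for _ in range(len_a + 1)]
    for i in range(1, len_a + 1):
        for j in range(1, len_b + 1):
            if equal(seq_a[i - 1], seq_b[j - 1]):
                dp[i][j] = dp[i - 1][j - 1] + 1
            else:
                dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])

    # Walk the table backwards to recover the matched source-side nodes.
    result = []
    i, j = len_a, len_b
    while i > 0 and j > 0:
        if equal(seq_a[i - 1], seq_b[j - 1]):
            result.append(seq_a[i - 1])
            i, j = i - 1, j - 1
        elif dp[i - 1][j] >= dp[i][j - 1]:
            i -= 1
        else:
            j -= 1

    return result[::-1]
```

Because `equal` here is just membership in the matching set, "equality" really means "these two nodes were matched earlier", which is exactly the twist on LCS that Chawathe et al. describe.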
+ +### Output + +The implemented algorithm produces the output that resembles the following: + +```python +>>> from sqlglot import parse_one, diff +>>> diff(parse_one("SELECT a + b + c, d, e"), parse_one("SELECT a - b + c, e, f")) + +Remove(Add) +Remove(Column(d)) +Remove(Identifier(d)) +Insert(Sub) +Insert(Column(f)) +Insert(Identifier(f)) +Keep(Select, Select) +Keep(Add, Add) +Keep(Column(a), Column(a)) +Keep(Identifier(a), Identifier(a)) +Keep(Column(b), Column(b)) +Keep(Identifier(b), Identifier(b)) +Keep(Column(c), Column(c)) +Keep(Identifier(c), Identifier(c)) +Keep(Column(e), Column(e)) +Keep(Identifier(e), Identifier(e)) +``` +Note that the output above is abbreviated. The string representation of actual AST nodes is significantly more verbose. + +The implementation works especially well when coupled with the SQLGlot’s query optimizer which can be used to produce canonical representations of compared queries: + +```python +>>> schema={"t": {"a": "INT", "b": "INT", "c": "INT", "d": "INT"}} +>>> source = """ +... SELECT 1 + 1 + a +... FROM t +... WHERE b = 1 OR (c = 2 AND d = 3) +... """ +>>> target = """ +... SELECT 2 + a +... FROM t +... WHERE (b = 1 OR c = 2) AND (b = 1 OR d = 3) +... """ +>>> optimized_source = optimize(parse_one(source), schema=schema) +>>> optimized_target = optimize(parse_one(target), schema=schema) +>>> edit_script = diff(optimized_source, optimized_target) +>>> sum(0 if isinstance(e, Keep) else 1 for e in edit_script) +0 +``` + +### Optimizations + +The worst case runtime complexity of this algorithm is not exactly stellar: O(n^2 * log n^2). This is because of the leaf matching process, which involves ranking a cartesian product between all leaf nodes of compared trees. Unsurprisingly, the algorithm takes a considerable time to finish for bigger queries. + +There are still a few basic things we can do in our implementation to help improve performance: + +* Refer to individual node objects using their identifiers (Python’s [id()](https://docs.python.org/3/library/functions.html#id)) instead of direct references in sets. This helps avoid costly recursive hash calculations and equality checks. +* Cache bigram histograms to avoid computing them more than once for the same node. +* Compute the canonical SQL string representation for each tree once while caching string representations of all inner nodes. This prevents redundant tree traversals when bigrams are computed. + +At the time of writing only the first two optimizations have been implemented, so there is an opportunity to contribute for anyone who’s interested. + +## Alternative Solutions + +This section is dedicated to solutions that I’ve investigated, but haven’t tried. + +First, this section wouldn’t be complete without Tristan Hume’s [blog post](https://thume.ca/2017/06/17/tree-diffing/). Tristan’s solution has a lot in common with the Myers algorithm plus heuristics that is much more clever than what I came up with. The implementation relies on a combination of [dynamic programming](https://en.wikipedia.org/wiki/Dynamic_programming) and [A* search algorithm](https://en.wikipedia.org/wiki/A*_search_algorithm) to explore the space of possible matchings and pick the best ones. It seemed to have worked well for Tistan’s specific use case, but after my negative experience with the Myers algorithm, I decided to try something different. + +Another notable approach is the Gumtree algorithm by Falleri et al. [4]. 
I discovered this paper after I’d already implemented the algorithm that is the main focus of this post. In sections 5.2 and 5.3 of their paper, the authors compare the two algorithms side by side and claim that Gumtree is significantly better in terms of both runtime performance and accuracy when evaluated on 12 792 pairs of Java source files. This doesn’t surprise me, as the algorithm takes the height of subtrees into account. In my tests, I definitely saw scenarios in which this context would have helped. On top of that, the authors promise O(n^2) runtime complexity in the worst case which, given the Change Distiller's O(n^2 * log n^2), looks particularly tempting. I hope to try this algorithm out at some point, and there is a good chance you see me writing about it in my future posts. + +## Conclusion + +The Change Distiller algorithm yielded quite satisfactory results in most of my tests. The scenarios in which it fell short mostly concerned identical (or very similar) subtrees located in different parts of the AST. In those cases, node mismatches were frequent and, as a result, edit scripts were somewhat suboptimal. + +Additionally, the runtime performance of the algorithm leaves a lot to be desired. On trees with 1000 leaf nodes each, the algorithm takes a little under 2 seconds to complete. My implementation still has room for improvement, but this should give you a rough idea of what to expect. It appears that the Gumtree algorithm [4] can help address both of these points. I hope to find bandwidth to work on it soon and then compare the two algorithms side-by-side to find out which one performs better on SQL specifically. In the meantime, Change Distiller definitely gets the job done, and I can now proceed with applying it to some of the use cases I mentioned at the beginning of this post. + +I’m also curious to learn whether other folks in the industry faced a similar problem, and how they approached it. If you did something similar, I’m interested to hear about your experience. + +## References + +[1] Eugene W. Myers. [An O(ND) Difference Algorithm and Its Variations](http://www.xmailserver.org/diff2.pdf). Algorithmica 1(2): 251-266 (1986) + +[2] B. Fluri, M. Wursch, M. Pinzger, and H. Gall. [Change Distilling: Tree differencing for fine-grained source code change extraction](https://www.researchgate.net/publication/3189787_Change_DistillingTree_Differencing_for_Fine-Grained_Source_Code_Change_Extraction). IEEE Trans. Software Eng., 33(11):725–743, 2007. + +[3] S.S. Chawathe, A. Rajaraman, H. Garcia-Molina, and J. Widom. [Change Detection in Hierarchically Structured Information](http://ilpubs.stanford.edu:8090/115/1/1995-46.pdf). Proc. ACM Sigmod Int’l Conf. Management of Data, pp. 493-504, June 1996 + +[4] Jean-Rémy Falleri, Floréal Morandat, Xavier Blanc, Matias Martinez, Martin Monperrus. [Fine-grained and Accurate Source Code Differencing](https://hal.archives-ouvertes.fr/hal-01054552/document). Proceedings of the International Conference on Automated Software Engineering, 2014, Västeras, Sweden. pp.313-324, 10.1145/2642937.2642982. 
hal-01054552 diff --git a/posts/sql_diff_images/dice_coef.png b/posts/sql_diff_images/dice_coef.png Binary files differnew file mode 100644 index 0000000..e0a91f6 --- /dev/null +++ b/posts/sql_diff_images/dice_coef.png diff --git a/posts/sql_diff_images/figure_1.png b/posts/sql_diff_images/figure_1.png Binary files differnew file mode 100644 index 0000000..578109b --- /dev/null +++ b/posts/sql_diff_images/figure_1.png diff --git a/posts/sql_diff_images/figure_2.gif b/posts/sql_diff_images/figure_2.gif Binary files differnew file mode 100644 index 0000000..9a3f3c0 --- /dev/null +++ b/posts/sql_diff_images/figure_2.gif diff --git a/posts/sql_diff_images/figure_3.gif b/posts/sql_diff_images/figure_3.gif Binary files differnew file mode 100644 index 0000000..4116154 --- /dev/null +++ b/posts/sql_diff_images/figure_3.gif diff --git a/posts/sql_diff_images/git_diff_output.png b/posts/sql_diff_images/git_diff_output.png Binary files differnew file mode 100644 index 0000000..8ba155c --- /dev/null +++ b/posts/sql_diff_images/git_diff_output.png diff --git a/posts/sql_diff_images/matching_criteria_1.png b/posts/sql_diff_images/matching_criteria_1.png Binary files differnew file mode 100644 index 0000000..9a321f2 --- /dev/null +++ b/posts/sql_diff_images/matching_criteria_1.png diff --git a/posts/sql_diff_images/matching_criteria_2.png b/posts/sql_diff_images/matching_criteria_2.png Binary files differnew file mode 100644 index 0000000..2a1c7f2 --- /dev/null +++ b/posts/sql_diff_images/matching_criteria_2.png diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..b2308e5 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,6 @@ +autoflake +black +duckdb +isort +pandas +python-dateutil diff --git a/run_checks.sh b/run_checks.sh new file mode 100755 index 0000000..a7dddf4 --- /dev/null +++ b/run_checks.sh @@ -0,0 +1,12 @@ +#!/bin/bash -e + +python -m autoflake -i -r \ + --expand-star-imports \ + --remove-all-unused-imports \ + --ignore-init-module-imports \ + --remove-duplicate-keys \ + --remove-unused-variables \ + sqlglot/ tests/ +python -m isort --profile black sqlglot/ tests/ +python -m black sqlglot/ tests/ +python -m unittest diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..4803b7e --- /dev/null +++ b/setup.py @@ -0,0 +1,33 @@ +from setuptools import find_packages, setup + +version = ( + open("sqlglot/__init__.py") + .read() + .split("__version__ = ")[-1] + .split("\n")[0] + .strip("") + .strip("'") + .strip('"') +) + +setup( + name="sqlglot", + version=version, + description="An easily customizable SQL parser and transpiler", + long_description=open("README.md").read(), + long_description_content_type="text/markdown", + url="https://github.com/tobymao/sqlglot", + author="Toby Mao", + author_email="toby.mao@gmail.com", + license="MIT", + packages=find_packages(include=["sqlglot", "sqlglot.*"]), + classifiers=[ + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + "Programming Language :: SQL", + "Programming Language :: Python :: 3 :: Only", + ], +) diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py new file mode 100644 index 0000000..0007e34 --- /dev/null +++ b/sqlglot/__init__.py @@ -0,0 +1,96 @@ +from sqlglot import expressions as exp +from sqlglot.dialects import Dialect, Dialects +from sqlglot.diff import diff +from sqlglot.errors import ErrorLevel, ParseError, TokenError, 
UnsupportedError +from sqlglot.expressions import Expression +from sqlglot.expressions import alias_ as alias +from sqlglot.expressions import ( + and_, + column, + condition, + from_, + maybe_parse, + not_, + or_, + select, + subquery, +) +from sqlglot.expressions import table_ as table +from sqlglot.generator import Generator +from sqlglot.parser import Parser +from sqlglot.tokens import Tokenizer, TokenType + +__version__ = "6.0.4" + +pretty = False + + +def parse(sql, read=None, **opts): + """ + Parses the given SQL string into a collection of syntax trees, one per + parsed SQL statement. + + Args: + sql (str): the SQL code string to parse. + read (str): the SQL dialect to apply during parsing + (eg. "spark", "hive", "presto", "mysql"). + **opts: other options. + + Returns: + typing.List[Expression]: the list of parsed syntax trees. + """ + dialect = Dialect.get_or_raise(read)() + return dialect.parse(sql, **opts) + + +def parse_one(sql, read=None, into=None, **opts): + """ + Parses the given SQL string and returns a syntax tree for the first + parsed SQL statement. + + Args: + sql (str): the SQL code string to parse. + read (str): the SQL dialect to apply during parsing + (eg. "spark", "hive", "presto", "mysql"). + into (Expression): the SQLGlot Expression to parse into + **opts: other options. + + Returns: + Expression: the syntax tree for the first parsed statement. + """ + + dialect = Dialect.get_or_raise(read)() + + if into: + result = dialect.parse_into(into, sql, **opts) + else: + result = dialect.parse(sql, **opts) + + return result[0] if result else None + + +def transpile(sql, read=None, write=None, identity=True, error_level=None, **opts): + """ + Parses the given SQL string using the source dialect and returns a list of SQL strings + transformed to conform to the target dialect. Each string in the returned list represents + a single transformed SQL statement. + + Args: + sql (str): the SQL code string to transpile. + read (str): the source dialect used to parse the input string + (eg. "spark", "hive", "presto", "mysql"). + write (str): the target dialect into which the input should be transformed + (eg. "spark", "hive", "presto", "mysql"). + identity (bool): if set to True and if the target dialect is not specified + the source dialect will be used as both: the source and the target dialect. + error_level (ErrorLevel): the desired error level of the parser. + **opts: other options. + + Returns: + typing.List[str]: the list of transpiled SQL statements / expressions. 
+ """ + write = write or read if identity else write + return [ + Dialect.get_or_raise(write)().generate(expression, **opts) + for expression in parse(sql, read, error_level=error_level) + ] diff --git a/sqlglot/__main__.py b/sqlglot/__main__.py new file mode 100644 index 0000000..25200c4 --- /dev/null +++ b/sqlglot/__main__.py @@ -0,0 +1,69 @@ +import argparse + +import sqlglot + +parser = argparse.ArgumentParser(description="Transpile SQL") +parser.add_argument("sql", metavar="sql", type=str, help="SQL string to transpile") +parser.add_argument( + "--read", + dest="read", + type=str, + default=None, + help="Dialect to read default is generic", +) +parser.add_argument( + "--write", + dest="write", + type=str, + default=None, + help="Dialect to write default is generic", +) +parser.add_argument( + "--no-identify", + dest="identify", + action="store_false", + help="Don't auto identify fields", +) +parser.add_argument( + "--no-pretty", + dest="pretty", + action="store_false", + help="Compress sql", +) +parser.add_argument( + "--parse", + dest="parse", + action="store_true", + help="Parse and return the expression tree", +) +parser.add_argument( + "--error-level", + dest="error_level", + type=str, + default="RAISE", + help="IGNORE, WARN, RAISE (default)", +) + + +args = parser.parse_args() +error_level = sqlglot.ErrorLevel[args.error_level.upper()] + +if args.parse: + sqls = [ + repr(expression) + for expression in sqlglot.parse( + args.sql, read=args.read, error_level=error_level + ) + ] +else: + sqls = sqlglot.transpile( + args.sql, + read=args.read, + write=args.write, + identify=args.identify, + pretty=args.pretty, + error_level=error_level, + ) + +for sql in sqls: + print(sql) diff --git a/sqlglot/dialects/__init__.py b/sqlglot/dialects/__init__.py new file mode 100644 index 0000000..5aa7d77 --- /dev/null +++ b/sqlglot/dialects/__init__.py @@ -0,0 +1,15 @@ +from sqlglot.dialects.bigquery import BigQuery +from sqlglot.dialects.clickhouse import ClickHouse +from sqlglot.dialects.dialect import Dialect, Dialects +from sqlglot.dialects.duckdb import DuckDB +from sqlglot.dialects.hive import Hive +from sqlglot.dialects.mysql import MySQL +from sqlglot.dialects.oracle import Oracle +from sqlglot.dialects.postgres import Postgres +from sqlglot.dialects.presto import Presto +from sqlglot.dialects.snowflake import Snowflake +from sqlglot.dialects.spark import Spark +from sqlglot.dialects.sqlite import SQLite +from sqlglot.dialects.starrocks import StarRocks +from sqlglot.dialects.tableau import Tableau +from sqlglot.dialects.trino import Trino diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py new file mode 100644 index 0000000..f4e87c3 --- /dev/null +++ b/sqlglot/dialects/bigquery.py @@ -0,0 +1,128 @@ +from sqlglot import exp +from sqlglot.dialects.dialect import ( + Dialect, + inline_array_sql, + no_ilike_sql, + rename_func, +) +from sqlglot.generator import Generator +from sqlglot.helper import list_get +from sqlglot.parser import Parser +from sqlglot.tokens import Tokenizer, TokenType + + +def _date_add(expression_class): + def func(args): + interval = list_get(args, 1) + return expression_class( + this=list_get(args, 0), + expression=interval.this, + unit=interval.args.get("unit"), + ) + + return func + + +def _date_add_sql(data_type, kind): + def func(self, expression): + this = self.sql(expression, "this") + unit = self.sql(expression, "unit") or "'day'" + expression = self.sql(expression, "expression") + return f"{data_type}_{kind}({this}, INTERVAL {expression} 
{unit})" + + return func + + +class BigQuery(Dialect): + unnest_column_only = True + + class Tokenizer(Tokenizer): + QUOTES = [ + (prefix + quote, quote) if prefix else quote + for quote in ["'", '"', '"""', "'''"] + for prefix in ["", "r", "R"] + ] + IDENTIFIERS = ["`"] + ESCAPE = "\\" + + KEYWORDS = { + **Tokenizer.KEYWORDS, + "CURRENT_DATETIME": TokenType.CURRENT_DATETIME, + "CURRENT_TIME": TokenType.CURRENT_TIME, + "GEOGRAPHY": TokenType.GEOGRAPHY, + "INT64": TokenType.BIGINT, + "FLOAT64": TokenType.DOUBLE, + "QUALIFY": TokenType.QUALIFY, + "UNKNOWN": TokenType.NULL, + "WINDOW": TokenType.WINDOW, + } + + class Parser(Parser): + FUNCTIONS = { + **Parser.FUNCTIONS, + "DATE_ADD": _date_add(exp.DateAdd), + "DATETIME_ADD": _date_add(exp.DatetimeAdd), + "TIME_ADD": _date_add(exp.TimeAdd), + "TIMESTAMP_ADD": _date_add(exp.TimestampAdd), + "DATE_SUB": _date_add(exp.DateSub), + "DATETIME_SUB": _date_add(exp.DatetimeSub), + "TIME_SUB": _date_add(exp.TimeSub), + "TIMESTAMP_SUB": _date_add(exp.TimestampSub), + } + + NO_PAREN_FUNCTIONS = { + **Parser.NO_PAREN_FUNCTIONS, + TokenType.CURRENT_DATETIME: exp.CurrentDatetime, + TokenType.CURRENT_TIME: exp.CurrentTime, + } + + class Generator(Generator): + TRANSFORMS = { + exp.Array: inline_array_sql, + exp.ArraySize: rename_func("ARRAY_LENGTH"), + exp.DateAdd: _date_add_sql("DATE", "ADD"), + exp.DateSub: _date_add_sql("DATE", "SUB"), + exp.DatetimeAdd: _date_add_sql("DATETIME", "ADD"), + exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"), + exp.ILike: no_ilike_sql, + exp.TimeAdd: _date_add_sql("TIME", "ADD"), + exp.TimeSub: _date_add_sql("TIME", "SUB"), + exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"), + exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"), + exp.VariancePop: rename_func("VAR_POP"), + } + + TYPE_MAPPING = { + **Generator.TYPE_MAPPING, + exp.DataType.Type.TINYINT: "INT64", + exp.DataType.Type.SMALLINT: "INT64", + exp.DataType.Type.INT: "INT64", + exp.DataType.Type.BIGINT: "INT64", + exp.DataType.Type.DECIMAL: "NUMERIC", + exp.DataType.Type.FLOAT: "FLOAT64", + exp.DataType.Type.DOUBLE: "FLOAT64", + exp.DataType.Type.BOOLEAN: "BOOL", + exp.DataType.Type.TEXT: "STRING", + exp.DataType.Type.VARCHAR: "STRING", + exp.DataType.Type.NVARCHAR: "STRING", + } + + def in_unnest_op(self, unnest): + return self.sql(unnest) + + def union_op(self, expression): + return f"UNION{' DISTINCT' if expression.args.get('distinct') else ' ALL'}" + + def except_op(self, expression): + if not expression.args.get("distinct", False): + self.unsupported("EXCEPT without DISTINCT is not supported in BigQuery") + return f"EXCEPT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}" + + def intersect_op(self, expression): + if not expression.args.get("distinct", False): + self.unsupported( + "INTERSECT without DISTINCT is not supported in BigQuery" + ) + return ( + f"INTERSECT{' DISTINCT' if expression.args.get('distinct') else ' ALL'}" + ) diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py new file mode 100644 index 0000000..55dad7a --- /dev/null +++ b/sqlglot/dialects/clickhouse.py @@ -0,0 +1,48 @@ +from sqlglot import exp +from sqlglot.dialects.dialect import Dialect, inline_array_sql +from sqlglot.generator import Generator +from sqlglot.parser import Parser +from sqlglot.tokens import Tokenizer, TokenType + + +class ClickHouse(Dialect): + normalize_functions = None + null_ordering = "nulls_are_last" + + class Tokenizer(Tokenizer): + IDENTIFIERS = ['"', "`"] + + KEYWORDS = { + **Tokenizer.KEYWORDS, + "NULLABLE": TokenType.NULLABLE, 
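            # Added commentary (not in the upstream source): tokenizing FINAL lets the
            # parser override further down wrap a table in exp.Final, and the generator
            # prints it back as "<table> FINAL", so SELECT * FROM t FINAL round-trips.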
+ "FINAL": TokenType.FINAL, + "INT8": TokenType.TINYINT, + "INT16": TokenType.SMALLINT, + "INT32": TokenType.INT, + "INT64": TokenType.BIGINT, + "FLOAT32": TokenType.FLOAT, + "FLOAT64": TokenType.DOUBLE, + } + + class Parser(Parser): + def _parse_table(self, schema=False): + this = super()._parse_table(schema) + + if self._match(TokenType.FINAL): + this = self.expression(exp.Final, this=this) + + return this + + class Generator(Generator): + STRUCT_DELIMITER = ("(", ")") + + TYPE_MAPPING = { + **Generator.TYPE_MAPPING, + exp.DataType.Type.NULLABLE: "Nullable", + } + + TRANSFORMS = { + **Generator.TRANSFORMS, + exp.Array: inline_array_sql, + exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL", + } diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py new file mode 100644 index 0000000..8045f7a --- /dev/null +++ b/sqlglot/dialects/dialect.py @@ -0,0 +1,268 @@ +from enum import Enum + +from sqlglot import exp +from sqlglot.generator import Generator +from sqlglot.helper import csv, list_get +from sqlglot.parser import Parser +from sqlglot.time import format_time +from sqlglot.tokens import Tokenizer +from sqlglot.trie import new_trie + + +class Dialects(str, Enum): + DIALECT = "" + + BIGQUERY = "bigquery" + CLICKHOUSE = "clickhouse" + DUCKDB = "duckdb" + HIVE = "hive" + MYSQL = "mysql" + ORACLE = "oracle" + POSTGRES = "postgres" + PRESTO = "presto" + SNOWFLAKE = "snowflake" + SPARK = "spark" + SQLITE = "sqlite" + STARROCKS = "starrocks" + TABLEAU = "tableau" + TRINO = "trino" + + +class _Dialect(type): + classes = {} + + @classmethod + def __getitem__(cls, key): + return cls.classes[key] + + @classmethod + def get(cls, key, default=None): + return cls.classes.get(key, default) + + def __new__(cls, clsname, bases, attrs): + klass = super().__new__(cls, clsname, bases, attrs) + enum = Dialects.__members__.get(clsname.upper()) + cls.classes[enum.value if enum is not None else clsname.lower()] = klass + + klass.time_trie = new_trie(klass.time_mapping) + klass.inverse_time_mapping = {v: k for k, v in klass.time_mapping.items()} + klass.inverse_time_trie = new_trie(klass.inverse_time_mapping) + + klass.tokenizer_class = getattr(klass, "Tokenizer", Tokenizer) + klass.parser_class = getattr(klass, "Parser", Parser) + klass.generator_class = getattr(klass, "Generator", Generator) + + klass.tokenizer = klass.tokenizer_class() + klass.quote_start, klass.quote_end = list(klass.tokenizer_class.QUOTES.items())[ + 0 + ] + klass.identifier_start, klass.identifier_end = list( + klass.tokenizer_class.IDENTIFIERS.items() + )[0] + + return klass + + +class Dialect(metaclass=_Dialect): + index_offset = 0 + unnest_column_only = False + alias_post_tablesample = False + normalize_functions = "upper" + null_ordering = "nulls_are_small" + + date_format = "'%Y-%m-%d'" + dateint_format = "'%Y%m%d'" + time_format = "'%Y-%m-%d %H:%M:%S'" + time_mapping = {} + + # autofilled + quote_start = None + quote_end = None + identifier_start = None + identifier_end = None + + time_trie = None + inverse_time_mapping = None + inverse_time_trie = None + tokenizer_class = None + parser_class = None + generator_class = None + tokenizer = None + + @classmethod + def get_or_raise(cls, dialect): + if not dialect: + return cls + result = cls.get(dialect) + if not result: + raise ValueError(f"Unknown dialect '{dialect}'") + return result + + @classmethod + def format_time(cls, expression): + if isinstance(expression, str): + return exp.Literal.string( + format_time( + expression[1:-1], # the time formats are quoted + 
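                    # Added commentary (not in the upstream source): format_time walks the
                    # dialect's time_mapping trie and rewrites the literal token by token,
                    # so for Hive a format such as 'yyyy-MM-dd' comes back as '%Y-%m-%d'.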
cls.time_mapping, + cls.time_trie, + ) + ) + if expression and expression.is_string: + return exp.Literal.string( + format_time( + expression.this, + cls.time_mapping, + cls.time_trie, + ) + ) + return expression + + def parse(self, sql, **opts): + return self.parser(**opts).parse(self.tokenizer.tokenize(sql), sql) + + def parse_into(self, expression_type, sql, **opts): + return self.parser(**opts).parse_into( + expression_type, self.tokenizer.tokenize(sql), sql + ) + + def generate(self, expression, **opts): + return self.generator(**opts).generate(expression) + + def transpile(self, code, **opts): + return self.generate(self.parse(code), **opts) + + def parser(self, **opts): + return self.parser_class( + **{ + "index_offset": self.index_offset, + "unnest_column_only": self.unnest_column_only, + "alias_post_tablesample": self.alias_post_tablesample, + "null_ordering": self.null_ordering, + **opts, + }, + ) + + def generator(self, **opts): + return self.generator_class( + **{ + "quote_start": self.quote_start, + "quote_end": self.quote_end, + "identifier_start": self.identifier_start, + "identifier_end": self.identifier_end, + "escape": self.tokenizer_class.ESCAPE, + "index_offset": self.index_offset, + "time_mapping": self.inverse_time_mapping, + "time_trie": self.inverse_time_trie, + "unnest_column_only": self.unnest_column_only, + "alias_post_tablesample": self.alias_post_tablesample, + "normalize_functions": self.normalize_functions, + "null_ordering": self.null_ordering, + **opts, + } + ) + + +def rename_func(name): + return ( + lambda self, expression: f"{name}({csv(*[self.sql(e) for e in expression.args.values()])})" + ) + + +def approx_count_distinct_sql(self, expression): + if expression.args.get("accuracy"): + self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy") + return f"APPROX_COUNT_DISTINCT({self.sql(expression, 'this')})" + + +def if_sql(self, expression): + expressions = csv( + self.sql(expression, "this"), + self.sql(expression, "true"), + self.sql(expression, "false"), + ) + return f"IF({expressions})" + + +def arrow_json_extract_sql(self, expression): + return f"{self.sql(expression, 'this')}->{self.sql(expression, 'path')}" + + +def arrow_json_extract_scalar_sql(self, expression): + return f"{self.sql(expression, 'this')}->>{self.sql(expression, 'path')}" + + +def inline_array_sql(self, expression): + return f"[{self.expressions(expression)}]" + + +def no_ilike_sql(self, expression): + return self.like_sql( + exp.Like( + this=exp.Lower(this=expression.this), + expression=expression.args["expression"], + ) + ) + + +def no_paren_current_date_sql(self, expression): + zone = self.sql(expression, "this") + return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE" + + +def no_recursive_cte_sql(self, expression): + if expression.args.get("recursive"): + self.unsupported("Recursive CTEs are unsupported") + expression.args["recursive"] = False + return self.with_sql(expression) + + +def no_safe_divide_sql(self, expression): + n = self.sql(expression, "this") + d = self.sql(expression, "expression") + return f"IF({d} <> 0, {n} / {d}, NULL)" + + +def no_tablesample_sql(self, expression): + self.unsupported("TABLESAMPLE unsupported") + return self.sql(expression.this) + + +def no_trycast_sql(self, expression): + return self.cast_sql(expression) + + +def str_position_sql(self, expression): + this = self.sql(expression, "this") + substr = self.sql(expression, "substr") + position = self.sql(expression, "position") + if position: + return 
f"STRPOS(SUBSTR({this}, {position}), {substr}) + {position} - 1" + return f"STRPOS({this}, {substr})" + + +def struct_extract_sql(self, expression): + this = self.sql(expression, "this") + struct_key = self.sql(exp.Identifier(this=expression.expression, quoted=True)) + return f"{this}.{struct_key}" + + +def format_time_lambda(exp_class, dialect, default=None): + """Helper used for time expressions. + + Args + exp_class (Class): the expression class to instantiate + dialect (string): sql dialect + default (Option[bool | str]): the default format, True being time + """ + + def _format_time(args): + return exp_class( + this=list_get(args, 0), + format=Dialect[dialect].format_time( + list_get(args, 1) + or (Dialect[dialect].time_format if default is True else default) + ), + ) + + return _format_time diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py new file mode 100644 index 0000000..d83a620 --- /dev/null +++ b/sqlglot/dialects/duckdb.py @@ -0,0 +1,156 @@ +from sqlglot import exp +from sqlglot.dialects.dialect import ( + Dialect, + approx_count_distinct_sql, + arrow_json_extract_scalar_sql, + arrow_json_extract_sql, + format_time_lambda, + no_safe_divide_sql, + no_tablesample_sql, + rename_func, + str_position_sql, +) +from sqlglot.generator import Generator +from sqlglot.helper import list_get +from sqlglot.parser import Parser +from sqlglot.tokens import Tokenizer, TokenType + + +def _unix_to_time(self, expression): + return f"TO_TIMESTAMP(CAST({self.sql(expression, 'this')} AS BIGINT))" + + +def _str_to_time_sql(self, expression): + return f"STRPTIME({self.sql(expression, 'this')}, {self.format_time(expression)})" + + +def _ts_or_ds_add(self, expression): + this = self.sql(expression, "this") + e = self.sql(expression, "expression") + unit = self.sql(expression, "unit").strip("'") or "DAY" + return f"CAST({this} AS DATE) + INTERVAL {e} {unit}" + + +def _ts_or_ds_to_date_sql(self, expression): + time_format = self.format_time(expression) + if time_format and time_format not in (DuckDB.time_format, DuckDB.date_format): + return f"CAST({_str_to_time_sql(self, expression)} AS DATE)" + return f"CAST({self.sql(expression, 'this')} AS DATE)" + + +def _date_add(self, expression): + this = self.sql(expression, "this") + e = self.sql(expression, "expression") + unit = self.sql(expression, "unit").strip("'") or "DAY" + return f"{this} + INTERVAL {e} {unit}" + + +def _array_sort_sql(self, expression): + if expression.expression: + self.unsupported("DUCKDB ARRAY_SORT does not support a comparator") + return f"ARRAY_SORT({self.sql(expression, 'this')})" + + +def _sort_array_sql(self, expression): + this = self.sql(expression, "this") + if expression.args.get("asc") == exp.FALSE: + return f"ARRAY_REVERSE_SORT({this})" + return f"ARRAY_SORT({this})" + + +def _sort_array_reverse(args): + return exp.SortArray(this=list_get(args, 0), asc=exp.FALSE) + + +def _struct_pack_sql(self, expression): + args = [ + self.binary(e, ":=") if isinstance(e, exp.EQ) else self.sql(e) + for e in expression.expressions + ] + return f"STRUCT_PACK({', '.join(args)})" + + +class DuckDB(Dialect): + class Tokenizer(Tokenizer): + KEYWORDS = { + **Tokenizer.KEYWORDS, + ":=": TokenType.EQ, + } + + class Parser(Parser): + FUNCTIONS = { + **Parser.FUNCTIONS, + "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list, + "ARRAY_LENGTH": exp.ArraySize.from_arg_list, + "ARRAY_SORT": exp.SortArray.from_arg_list, + "ARRAY_REVERSE_SORT": _sort_array_reverse, + "EPOCH": exp.TimeToUnix.from_arg_list, + "EPOCH_MS": lambda 
args: exp.UnixToTime( + this=exp.Div( + this=list_get(args, 0), + expression=exp.Literal.number(1000), + ) + ), + "LIST_SORT": exp.SortArray.from_arg_list, + "LIST_REVERSE_SORT": _sort_array_reverse, + "LIST_VALUE": exp.Array.from_arg_list, + "REGEXP_MATCHES": exp.RegexpLike.from_arg_list, + "STRFTIME": format_time_lambda(exp.TimeToStr, "duckdb"), + "STRPTIME": format_time_lambda(exp.StrToTime, "duckdb"), + "STR_SPLIT": exp.Split.from_arg_list, + "STRING_SPLIT": exp.Split.from_arg_list, + "STRING_TO_ARRAY": exp.Split.from_arg_list, + "STR_SPLIT_REGEX": exp.RegexpSplit.from_arg_list, + "STRING_SPLIT_REGEX": exp.RegexpSplit.from_arg_list, + "STRUCT_PACK": exp.Struct.from_arg_list, + "TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list, + "UNNEST": exp.Explode.from_arg_list, + } + + class Generator(Generator): + TRANSFORMS = { + **Generator.TRANSFORMS, + exp.ApproxDistinct: approx_count_distinct_sql, + exp.Array: lambda self, e: f"LIST_VALUE({self.expressions(e, flat=True)})", + exp.ArraySize: rename_func("ARRAY_LENGTH"), + exp.ArraySort: _array_sort_sql, + exp.ArraySum: rename_func("LIST_SUM"), + exp.DateAdd: _date_add, + exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""", + exp.DateStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)", + exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.dateint_format}) AS INT)", + exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.dateint_format}) AS DATE)", + exp.Explode: rename_func("UNNEST"), + exp.JSONExtract: arrow_json_extract_sql, + exp.JSONExtractScalar: arrow_json_extract_scalar_sql, + exp.JSONBExtract: arrow_json_extract_sql, + exp.JSONBExtractScalar: arrow_json_extract_scalar_sql, + exp.RegexpLike: rename_func("REGEXP_MATCHES"), + exp.RegexpSplit: rename_func("STR_SPLIT_REGEX"), + exp.SafeDivide: no_safe_divide_sql, + exp.Split: rename_func("STR_SPLIT"), + exp.SortArray: _sort_array_sql, + exp.StrPosition: str_position_sql, + exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)", + exp.StrToTime: _str_to_time_sql, + exp.StrToUnix: lambda self, e: f"EPOCH(STRPTIME({self.sql(e, 'this')}, {self.format_time(e)}))", + exp.Struct: _struct_pack_sql, + exp.TableSample: no_tablesample_sql, + exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)", + exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)", + exp.TimeStrToUnix: lambda self, e: f"EPOCH(CAST({self.sql(e, 'this')} AS TIMESTAMP))", + exp.TimeToStr: lambda self, e: f"STRFTIME({self.sql(e, 'this')}, {self.format_time(e)})", + exp.TimeToUnix: rename_func("EPOCH"), + exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS TEXT), '-', ''), 1, 8) AS INT)", + exp.TsOrDsAdd: _ts_or_ds_add, + exp.TsOrDsToDate: _ts_or_ds_to_date_sql, + exp.UnixToStr: lambda self, e: f"STRFTIME({_unix_to_time(self, e)}, {self.format_time(e)})", + exp.UnixToTime: _unix_to_time, + exp.UnixToTimeStr: lambda self, e: f"CAST({_unix_to_time(self, e)} AS TEXT)", + } + + TYPE_MAPPING = { + **Generator.TYPE_MAPPING, + exp.DataType.Type.VARCHAR: "TEXT", + exp.DataType.Type.NVARCHAR: "TEXT", + } diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py new file mode 100644 index 0000000..e3f3f39 --- /dev/null +++ b/sqlglot/dialects/hive.py @@ -0,0 +1,304 @@ +from sqlglot import exp, transforms +from sqlglot.dialects.dialect import ( + Dialect, + approx_count_distinct_sql, + 
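    # Added commentary (not in the upstream source): these names are dialect-agnostic
    # helpers from dialect.py; format_time_lambda in particular is a factory that builds
    # a parser callback which converts a format-string argument into the dialect's
    # internal strftime form via Dialect.format_time.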
format_time_lambda, + if_sql, + no_ilike_sql, + no_recursive_cte_sql, + no_safe_divide_sql, + no_trycast_sql, + rename_func, + struct_extract_sql, +) +from sqlglot.generator import Generator +from sqlglot.helper import csv, list_get +from sqlglot.parser import Parser +from sqlglot.tokens import Tokenizer + + +def _parse_map(args): + keys = [] + values = [] + for i in range(0, len(args), 2): + keys.append(args[i]) + values.append(args[i + 1]) + return HiveMap( + keys=exp.Array(expressions=keys), + values=exp.Array(expressions=values), + ) + + +def _map_sql(self, expression): + keys = expression.args["keys"] + values = expression.args["values"] + + if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): + self.unsupported("Cannot convert array columns into map use SparkSQL instead.") + return f"MAP({self.sql(keys)}, {self.sql(values)})" + + args = [] + for key, value in zip(keys.expressions, values.expressions): + args.append(self.sql(key)) + args.append(self.sql(value)) + return f"MAP({csv(*args)})" + + +def _array_sort(self, expression): + if expression.expression: + self.unsupported("Hive SORT_ARRAY does not support a comparator") + return f"SORT_ARRAY({self.sql(expression, 'this')})" + + +def _property_sql(self, expression): + key = expression.name + value = self.sql(expression, "value") + return f"'{key}' = {value}" + + +def _str_to_unix(self, expression): + return f"UNIX_TIMESTAMP({csv(self.sql(expression, 'this'), _time_format(self, expression))})" + + +def _str_to_date(self, expression): + this = self.sql(expression, "this") + time_format = self.format_time(expression) + if time_format not in (Hive.time_format, Hive.date_format): + this = f"FROM_UNIXTIME(UNIX_TIMESTAMP({this}, {time_format}))" + return f"CAST({this} AS DATE)" + + +def _str_to_time(self, expression): + this = self.sql(expression, "this") + time_format = self.format_time(expression) + if time_format not in (Hive.time_format, Hive.date_format): + this = f"FROM_UNIXTIME(UNIX_TIMESTAMP({this}, {time_format}))" + return f"CAST({this} AS TIMESTAMP)" + + +def _time_format(self, expression): + time_format = self.format_time(expression) + if time_format == Hive.time_format: + return None + return time_format + + +def _time_to_str(self, expression): + this = self.sql(expression, "this") + time_format = self.format_time(expression) + return f"DATE_FORMAT({this}, {time_format})" + + +def _to_date_sql(self, expression): + this = self.sql(expression, "this") + time_format = self.format_time(expression) + if time_format and time_format not in (Hive.time_format, Hive.date_format): + return f"TO_DATE({this}, {time_format})" + return f"TO_DATE({this})" + + +def _unnest_to_explode_sql(self, expression): + unnest = expression.this + if isinstance(unnest, exp.Unnest): + alias = unnest.args.get("alias") + udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode + return "".join( + self.sql( + exp.Lateral( + this=udtf(this=expression), + alias=exp.TableAlias(this=alias.this, columns=[column]), + ) + ) + for expression, column in zip( + unnest.expressions, alias.columns if alias else [] + ) + ) + return self.join_sql(expression) + + +def _index_sql(self, expression): + this = self.sql(expression, "this") + table = self.sql(expression, "table") + columns = self.sql(expression, "columns") + return f"{this} ON TABLE {table} {columns}" + + +class HiveMap(exp.Map): + is_var_len_args = True + + +class Hive(Dialect): + alias_post_tablesample = True + + time_mapping = { + "y": "%Y", + "Y": "%Y", + "YYYY": "%Y", + 
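        # Added commentary (not in the upstream source): these entries translate Hive's
        # Java SimpleDateFormat-style tokens into Python strftime directives, e.g.
        # 'yyyy-MM-dd HH:mm:ss' corresponds to '%Y-%m-%d %H:%M:%S'.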
"yyyy": "%Y", + "YY": "%y", + "yy": "%y", + "MMMM": "%B", + "MMM": "%b", + "MM": "%m", + "M": "%-m", + "dd": "%d", + "d": "%-d", + "HH": "%H", + "H": "%-H", + "hh": "%I", + "h": "%-I", + "mm": "%M", + "m": "%-M", + "ss": "%S", + "s": "%-S", + "S": "%f", + } + + date_format = "'yyyy-MM-dd'" + dateint_format = "'yyyyMMdd'" + time_format = "'yyyy-MM-dd HH:mm:ss'" + + class Tokenizer(Tokenizer): + QUOTES = ["'", '"'] + IDENTIFIERS = ["`"] + ESCAPE = "\\" + ENCODE = "utf-8" + + NUMERIC_LITERALS = { + "L": "BIGINT", + "S": "SMALLINT", + "Y": "TINYINT", + "D": "DOUBLE", + "F": "FLOAT", + "BD": "DECIMAL", + } + + class Parser(Parser): + STRICT_CAST = False + + FUNCTIONS = { + **Parser.FUNCTIONS, + "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list, + "COLLECT_LIST": exp.ArrayAgg.from_arg_list, + "DATE_ADD": lambda args: exp.TsOrDsAdd( + this=list_get(args, 0), + expression=list_get(args, 1), + unit=exp.Literal.string("DAY"), + ), + "DATEDIFF": lambda args: exp.DateDiff( + this=exp.TsOrDsToDate(this=list_get(args, 0)), + expression=exp.TsOrDsToDate(this=list_get(args, 1)), + ), + "DATE_SUB": lambda args: exp.TsOrDsAdd( + this=list_get(args, 0), + expression=exp.Mul( + this=list_get(args, 1), + expression=exp.Literal.number(-1), + ), + unit=exp.Literal.string("DAY"), + ), + "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "hive"), + "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=list_get(args, 0))), + "FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True), + "GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list, + "LOCATE": lambda args: exp.StrPosition( + this=list_get(args, 1), + substr=list_get(args, 0), + position=list_get(args, 2), + ), + "LOG": ( + lambda args: exp.Log.from_arg_list(args) + if len(args) > 1 + else exp.Ln.from_arg_list(args) + ), + "MAP": _parse_map, + "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)), + "PERCENTILE": exp.Quantile.from_arg_list, + "COLLECT_SET": exp.SetAgg.from_arg_list, + "SIZE": exp.ArraySize.from_arg_list, + "SPLIT": exp.RegexpSplit.from_arg_list, + "TO_DATE": format_time_lambda(exp.TsOrDsToDate, "hive"), + "UNIX_TIMESTAMP": format_time_lambda(exp.StrToUnix, "hive", True), + "YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)), + } + + class Generator(Generator): + ROOT_PROPERTIES = [ + exp.PartitionedByProperty, + exp.FileFormatProperty, + exp.SchemaCommentProperty, + exp.LocationProperty, + exp.TableFormatProperty, + ] + WITH_PROPERTIES = [exp.AnonymousProperty] + + TYPE_MAPPING = { + **Generator.TYPE_MAPPING, + exp.DataType.Type.TEXT: "STRING", + } + + TRANSFORMS = { + **Generator.TRANSFORMS, + **transforms.UNALIAS_GROUP, + exp.AnonymousProperty: _property_sql, + exp.ApproxDistinct: approx_count_distinct_sql, + exp.ArrayAgg: rename_func("COLLECT_LIST"), + exp.ArraySize: rename_func("SIZE"), + exp.ArraySort: _array_sort, + exp.With: no_recursive_cte_sql, + exp.DateAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})", + exp.DateDiff: lambda self, e: f"DATEDIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')})", + exp.DateStrToDate: rename_func("TO_DATE"), + exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.dateint_format}) AS INT)", + exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.dateint_format})", + exp.FileFormatProperty: lambda self, e: f"STORED AS {e.text('value').upper()}", + exp.If: if_sql, + exp.Index: _index_sql, + exp.ILike: no_ilike_sql, + exp.Join: _unnest_to_explode_sql, + 
exp.JSONExtract: rename_func("GET_JSON_OBJECT"), + exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"), + exp.Map: _map_sql, + HiveMap: _map_sql, + exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e.args['value'])}", + exp.Quantile: rename_func("PERCENTILE"), + exp.RegexpLike: lambda self, e: self.binary(e, "RLIKE"), + exp.RegexpSplit: rename_func("SPLIT"), + exp.SafeDivide: no_safe_divide_sql, + exp.SchemaCommentProperty: lambda self, e: f"COMMENT {self.sql(e.args['value'])}", + exp.SetAgg: rename_func("COLLECT_SET"), + exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))", + exp.StrPosition: lambda self, e: f"LOCATE({csv(self.sql(e, 'substr'), self.sql(e, 'this'), self.sql(e, 'position'))})", + exp.StrToDate: _str_to_date, + exp.StrToTime: _str_to_time, + exp.StrToUnix: _str_to_unix, + exp.StructExtract: struct_extract_sql, + exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'value')}", + exp.TimeStrToDate: rename_func("TO_DATE"), + exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)", + exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"), + exp.TimeToStr: _time_to_str, + exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"), + exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS STRING), '-', ''), 1, 8) AS INT)", + exp.TsOrDsAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')})", + exp.TsOrDsToDate: _to_date_sql, + exp.TryCast: no_trycast_sql, + exp.UnixToStr: lambda self, e: f"FROM_UNIXTIME({csv(self.sql(e, 'this'), _time_format(self, e))})", + exp.UnixToTime: rename_func("FROM_UNIXTIME"), + exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"), + } + + def with_properties(self, properties): + return self.properties( + properties, + prefix="TBLPROPERTIES", + ) + + def datatype_sql(self, expression): + if ( + expression.this + in (exp.DataType.Type.VARCHAR, exp.DataType.Type.NVARCHAR) + and not expression.expressions + ): + expression = exp.DataType.build("text") + return super().datatype_sql(expression) diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py new file mode 100644 index 0000000..93800a6 --- /dev/null +++ b/sqlglot/dialects/mysql.py @@ -0,0 +1,163 @@ +from sqlglot import exp +from sqlglot.dialects.dialect import ( + Dialect, + no_ilike_sql, + no_paren_current_date_sql, + no_tablesample_sql, + no_trycast_sql, +) +from sqlglot.generator import Generator +from sqlglot.helper import list_get +from sqlglot.parser import Parser +from sqlglot.tokens import Tokenizer, TokenType + + +def _date_trunc_sql(self, expression): + unit = expression.text("unit").lower() + + this = self.sql(expression.this) + + if unit == "day": + return f"DATE({this})" + + if unit == "week": + concat = f"CONCAT(YEAR({this}), ' ', WEEK({this}, 1), ' 1')" + date_format = "%Y %u %w" + elif unit == "month": + concat = f"CONCAT(YEAR({this}), ' ', MONTH({this}), ' 1')" + date_format = "%Y %c %e" + elif unit == "quarter": + concat = f"CONCAT(YEAR({this}), ' ', QUARTER({this}) * 3 - 2, ' 1')" + date_format = "%Y %c %e" + elif unit == "year": + concat = f"CONCAT(YEAR({this}), ' 1 1')" + date_format = "%Y %c %e" + else: + self.unsupported("Unexpected interval unit: {unit}") + return f"DATE({this})" + + return f"STR_TO_DATE({concat}, '{date_format}')" + + +def _str_to_date(args): + date_format = MySQL.format_time(list_get(args, 1)) + return exp.StrToDate(this=list_get(args, 0), format=date_format) + + +def _str_to_date_sql(self, expression): + 
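    # Added commentary (not in the upstream source): renders exp.StrToDate and
    # exp.StrToTime back into MySQL's STR_TO_DATE(<expr>, <format>), with the format
    # literal converted through the generator's inverse time mapping by format_time.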
date_format = self.format_time(expression) + return f"STR_TO_DATE({self.sql(expression.this)}, {date_format})" + + +def _date_add(expression_class): + def func(args): + interval = list_get(args, 1) + return expression_class( + this=list_get(args, 0), + expression=interval.this, + unit=exp.Literal.string(interval.text("unit").lower()), + ) + + return func + + +def _date_add_sql(kind): + def func(self, expression): + this = self.sql(expression, "this") + unit = expression.text("unit").upper() or "DAY" + expression = self.sql(expression, "expression") + return f"DATE_{kind}({this}, INTERVAL {expression} {unit})" + + return func + + +class MySQL(Dialect): + # https://prestodb.io/docs/current/functions/datetime.html#mysql-date-functions + time_mapping = { + "%M": "%B", + "%c": "%-m", + "%e": "%-d", + "%h": "%I", + "%i": "%M", + "%s": "%S", + "%S": "%S", + "%u": "%W", + } + + class Tokenizer(Tokenizer): + QUOTES = ["'", '"'] + COMMENTS = ["--", "#", ("/*", "*/")] + IDENTIFIERS = ["`"] + + KEYWORDS = { + **Tokenizer.KEYWORDS, + "_ARMSCII8": TokenType.INTRODUCER, + "_ASCII": TokenType.INTRODUCER, + "_BIG5": TokenType.INTRODUCER, + "_BINARY": TokenType.INTRODUCER, + "_CP1250": TokenType.INTRODUCER, + "_CP1251": TokenType.INTRODUCER, + "_CP1256": TokenType.INTRODUCER, + "_CP1257": TokenType.INTRODUCER, + "_CP850": TokenType.INTRODUCER, + "_CP852": TokenType.INTRODUCER, + "_CP866": TokenType.INTRODUCER, + "_CP932": TokenType.INTRODUCER, + "_DEC8": TokenType.INTRODUCER, + "_EUCJPMS": TokenType.INTRODUCER, + "_EUCKR": TokenType.INTRODUCER, + "_GB18030": TokenType.INTRODUCER, + "_GB2312": TokenType.INTRODUCER, + "_GBK": TokenType.INTRODUCER, + "_GEOSTD8": TokenType.INTRODUCER, + "_GREEK": TokenType.INTRODUCER, + "_HEBREW": TokenType.INTRODUCER, + "_HP8": TokenType.INTRODUCER, + "_KEYBCS2": TokenType.INTRODUCER, + "_KOI8R": TokenType.INTRODUCER, + "_KOI8U": TokenType.INTRODUCER, + "_LATIN1": TokenType.INTRODUCER, + "_LATIN2": TokenType.INTRODUCER, + "_LATIN5": TokenType.INTRODUCER, + "_LATIN7": TokenType.INTRODUCER, + "_MACCE": TokenType.INTRODUCER, + "_MACROMAN": TokenType.INTRODUCER, + "_SJIS": TokenType.INTRODUCER, + "_SWE7": TokenType.INTRODUCER, + "_TIS620": TokenType.INTRODUCER, + "_UCS2": TokenType.INTRODUCER, + "_UJIS": TokenType.INTRODUCER, + "_UTF8": TokenType.INTRODUCER, + "_UTF16": TokenType.INTRODUCER, + "_UTF16LE": TokenType.INTRODUCER, + "_UTF32": TokenType.INTRODUCER, + "_UTF8MB3": TokenType.INTRODUCER, + "_UTF8MB4": TokenType.INTRODUCER, + } + + class Parser(Parser): + STRICT_CAST = False + + FUNCTIONS = { + **Parser.FUNCTIONS, + "DATE_ADD": _date_add(exp.DateAdd), + "DATE_SUB": _date_add(exp.DateSub), + "STR_TO_DATE": _str_to_date, + } + + class Generator(Generator): + NULL_ORDERING_SUPPORTED = False + + TRANSFORMS = { + **Generator.TRANSFORMS, + exp.CurrentDate: no_paren_current_date_sql, + exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", + exp.ILike: no_ilike_sql, + exp.TableSample: no_tablesample_sql, + exp.TryCast: no_trycast_sql, + exp.DateAdd: _date_add_sql("ADD"), + exp.DateSub: _date_add_sql("SUB"), + exp.DateTrunc: _date_trunc_sql, + exp.StrToDate: _str_to_date_sql, + exp.StrToTime: _str_to_date_sql, + } diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py new file mode 100644 index 0000000..9c8b6f2 --- /dev/null +++ b/sqlglot/dialects/oracle.py @@ -0,0 +1,63 @@ +from sqlglot import exp, transforms +from sqlglot.dialects.dialect import Dialect, no_ilike_sql +from sqlglot.generator import Generator +from sqlglot.helper import csv +from sqlglot.tokens import 
Tokenizer, TokenType + + +def _limit_sql(self, expression): + return self.fetch_sql(exp.Fetch(direction="FIRST", count=expression.expression)) + + +class Oracle(Dialect): + class Generator(Generator): + TYPE_MAPPING = { + **Generator.TYPE_MAPPING, + exp.DataType.Type.TINYINT: "NUMBER", + exp.DataType.Type.SMALLINT: "NUMBER", + exp.DataType.Type.INT: "NUMBER", + exp.DataType.Type.BIGINT: "NUMBER", + exp.DataType.Type.DECIMAL: "NUMBER", + exp.DataType.Type.DOUBLE: "DOUBLE PRECISION", + exp.DataType.Type.VARCHAR: "VARCHAR2", + exp.DataType.Type.NVARCHAR: "NVARCHAR2", + exp.DataType.Type.TEXT: "CLOB", + exp.DataType.Type.BINARY: "BLOB", + } + + TRANSFORMS = { + **Generator.TRANSFORMS, + **transforms.UNALIAS_GROUP, + exp.ILike: no_ilike_sql, + exp.Limit: _limit_sql, + } + + def query_modifiers(self, expression, *sqls): + return csv( + *sqls, + *[self.sql(sql) for sql in expression.args.get("laterals", [])], + *[self.sql(sql) for sql in expression.args.get("joins", [])], + self.sql(expression, "where"), + self.sql(expression, "group"), + self.sql(expression, "having"), + self.sql(expression, "qualify"), + self.sql(expression, "window"), + self.sql(expression, "distribute"), + self.sql(expression, "sort"), + self.sql(expression, "cluster"), + self.sql(expression, "order"), + self.sql(expression, "offset"), # offset before limit in oracle + self.sql(expression, "limit"), + sep="", + ) + + def offset_sql(self, expression): + return f"{super().offset_sql(expression)} ROWS" + + class Tokenizer(Tokenizer): + KEYWORDS = { + **Tokenizer.KEYWORDS, + "TOP": TokenType.TOP, + "VARCHAR2": TokenType.VARCHAR, + "NVARCHAR2": TokenType.NVARCHAR, + } diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py new file mode 100644 index 0000000..61dff86 --- /dev/null +++ b/sqlglot/dialects/postgres.py @@ -0,0 +1,109 @@ +from sqlglot import exp +from sqlglot.dialects.dialect import ( + Dialect, + arrow_json_extract_scalar_sql, + arrow_json_extract_sql, + format_time_lambda, + no_paren_current_date_sql, + no_tablesample_sql, + no_trycast_sql, +) +from sqlglot.generator import Generator +from sqlglot.parser import Parser +from sqlglot.tokens import Tokenizer, TokenType + + +def _date_add_sql(kind): + def func(self, expression): + from sqlglot.optimizer.simplify import simplify + + this = self.sql(expression, "this") + unit = self.sql(expression, "unit") + expression = simplify(expression.args["expression"]) + + if not isinstance(expression, exp.Literal): + self.unsupported("Cannot add non literal") + + expression = expression.copy() + expression.args["is_string"] = True + expression = self.sql(expression) + return f"{this} {kind} INTERVAL {expression} {unit}" + + return func + + +class Postgres(Dialect): + null_ordering = "nulls_are_large" + time_format = "'YYYY-MM-DD HH24:MI:SS'" + time_mapping = { + "AM": "%p", # AM or PM + "D": "%w", # 1-based day of week + "DD": "%d", # day of month + "DDD": "%j", # zero padded day of year + "FMDD": "%-d", # - is no leading zero for Python; same for FM in postgres + "FMDDD": "%-j", # day of year + "FMHH12": "%-I", # 9 + "FMHH24": "%-H", # 9 + "FMMI": "%-M", # Minute + "FMMM": "%-m", # 1 + "FMSS": "%-S", # Second + "HH12": "%I", # 09 + "HH24": "%H", # 09 + "MI": "%M", # zero padded minute + "MM": "%m", # 01 + "OF": "%z", # utc offset + "SS": "%S", # zero padded second + "TMDay": "%A", # TM is locale dependent + "TMDy": "%a", + "TMMon": "%b", # Sep + "TMMonth": "%B", # September + "TZ": "%Z", # uppercase timezone name + "US": "%f", # zero padded microsecond + "WW": 
"%U", # 1-based week of year + "YY": "%y", # 15 + "YYYY": "%Y", # 2015 + } + + class Tokenizer(Tokenizer): + KEYWORDS = { + **Tokenizer.KEYWORDS, + "SERIAL": TokenType.AUTO_INCREMENT, + "UUID": TokenType.UUID, + } + + class Parser(Parser): + STRICT_CAST = False + FUNCTIONS = { + **Parser.FUNCTIONS, + "TO_TIMESTAMP": format_time_lambda(exp.StrToTime, "postgres"), + "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"), + } + + class Generator(Generator): + TYPE_MAPPING = { + **Generator.TYPE_MAPPING, + exp.DataType.Type.TINYINT: "SMALLINT", + exp.DataType.Type.FLOAT: "REAL", + exp.DataType.Type.DOUBLE: "DOUBLE PRECISION", + exp.DataType.Type.BINARY: "BYTEA", + } + + TOKEN_MAPPING = { + TokenType.AUTO_INCREMENT: "SERIAL", + } + + TRANSFORMS = { + **Generator.TRANSFORMS, + exp.JSONExtract: arrow_json_extract_sql, + exp.JSONExtractScalar: arrow_json_extract_scalar_sql, + exp.JSONBExtract: lambda self, e: f"{self.sql(e, 'this')}#>{self.sql(e, 'path')}", + exp.JSONBExtractScalar: lambda self, e: f"{self.sql(e, 'this')}#>>{self.sql(e, 'path')}", + exp.CurrentDate: no_paren_current_date_sql, + exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", + exp.DateAdd: _date_add_sql("+"), + exp.DateSub: _date_add_sql("-"), + exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", + exp.TableSample: no_tablesample_sql, + exp.TryCast: no_trycast_sql, + } diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py new file mode 100644 index 0000000..ca913e4 --- /dev/null +++ b/sqlglot/dialects/presto.py @@ -0,0 +1,216 @@ +from sqlglot import exp, transforms +from sqlglot.dialects.dialect import ( + Dialect, + format_time_lambda, + if_sql, + no_ilike_sql, + no_safe_divide_sql, + rename_func, + str_position_sql, + struct_extract_sql, +) +from sqlglot.dialects.mysql import MySQL +from sqlglot.generator import Generator +from sqlglot.helper import csv, list_get +from sqlglot.parser import Parser +from sqlglot.tokens import Tokenizer, TokenType + + +def _approx_distinct_sql(self, expression): + accuracy = expression.args.get("accuracy") + accuracy = ", " + self.sql(accuracy) if accuracy else "" + return f"APPROX_DISTINCT({self.sql(expression, 'this')}{accuracy})" + + +def _concat_ws_sql(self, expression): + sep, *args = expression.expressions + sep = self.sql(sep) + if len(args) > 1: + return f"ARRAY_JOIN(ARRAY[{csv(*(self.sql(e) for e in args))}], {sep})" + return f"ARRAY_JOIN({self.sql(args[0])}, {sep})" + + +def _datatype_sql(self, expression): + sql = self.datatype_sql(expression) + if expression.this == exp.DataType.Type.TIMESTAMPTZ: + sql = f"{sql} WITH TIME ZONE" + return sql + + +def _date_parse_sql(self, expression): + return f"DATE_PARSE({self.sql(expression, 'this')}, '%Y-%m-%d %H:%i:%s')" + + +def _explode_to_unnest_sql(self, expression): + if isinstance(expression.this, (exp.Explode, exp.Posexplode)): + return self.sql( + exp.Join( + this=exp.Unnest( + expressions=[expression.this.this], + alias=expression.args.get("alias"), + ordinality=isinstance(expression.this, exp.Posexplode), + ), + kind="cross", + ) + ) + return self.lateral_sql(expression) + + +def _initcap_sql(self, expression): + regex = r"(\w)(\w*)" + return f"REGEXP_REPLACE({self.sql(expression, 'this')}, '{regex}', x -> UPPER(x[1]) || LOWER(x[2]))" + + +def _no_sort_array(self, expression): + if expression.args.get("asc") == exp.FALSE: + comparator = "(a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN 
-1 ELSE 0 END" + else: + comparator = None + args = csv(self.sql(expression, "this"), comparator) + return f"ARRAY_SORT({args})" + + +def _schema_sql(self, expression): + if isinstance(expression.parent, exp.Property): + columns = ", ".join(f"'{c.text('this')}'" for c in expression.expressions) + return f"ARRAY[{columns}]" + + for schema in expression.parent.find_all(exp.Schema): + if isinstance(schema.parent, exp.Property): + expression = expression.copy() + expression.expressions.extend(schema.expressions) + + return self.schema_sql(expression) + + +def _quantile_sql(self, expression): + self.unsupported("Presto does not support exact quantiles") + return f"APPROX_PERCENTILE({self.sql(expression, 'this')}, {self.sql(expression, 'quantile')})" + + +def _str_to_time_sql(self, expression): + return f"DATE_PARSE({self.sql(expression, 'this')}, {self.format_time(expression)})" + + +def _ts_or_ds_to_date_sql(self, expression): + time_format = self.format_time(expression) + if time_format and time_format not in (Presto.time_format, Presto.date_format): + return f"CAST({_str_to_time_sql(self, expression)} AS DATE)" + return ( + f"CAST(SUBSTR(CAST({self.sql(expression, 'this')} AS VARCHAR), 1, 10) AS DATE)" + ) + + +def _ts_or_ds_add_sql(self, expression): + this = self.sql(expression, "this") + e = self.sql(expression, "expression") + unit = self.sql(expression, "unit") or "'day'" + return f"DATE_ADD({unit}, {e}, DATE_PARSE(SUBSTR({this}, 1, 10), {Presto.date_format}))" + + +class Presto(Dialect): + index_offset = 1 + null_ordering = "nulls_are_last" + time_format = "'%Y-%m-%d %H:%i:%S'" + time_mapping = MySQL.time_mapping + + class Tokenizer(Tokenizer): + KEYWORDS = { + **Tokenizer.KEYWORDS, + "ROW": TokenType.STRUCT, + } + + class Parser(Parser): + FUNCTIONS = { + **Parser.FUNCTIONS, + "APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list, + "CARDINALITY": exp.ArraySize.from_arg_list, + "CONTAINS": exp.ArrayContains.from_arg_list, + "DATE_ADD": lambda args: exp.DateAdd( + this=list_get(args, 2), + expression=list_get(args, 1), + unit=list_get(args, 0), + ), + "DATE_DIFF": lambda args: exp.DateDiff( + this=list_get(args, 2), + expression=list_get(args, 1), + unit=list_get(args, 0), + ), + "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"), + "DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"), + "FROM_UNIXTIME": exp.UnixToTime.from_arg_list, + "STRPOS": exp.StrPosition.from_arg_list, + "TO_UNIXTIME": exp.TimeToUnix.from_arg_list, + } + + class Generator(Generator): + + STRUCT_DELIMITER = ("(", ")") + + WITH_PROPERTIES = [ + exp.PartitionedByProperty, + exp.FileFormatProperty, + exp.SchemaCommentProperty, + exp.AnonymousProperty, + exp.TableFormatProperty, + ] + + TYPE_MAPPING = { + **Generator.TYPE_MAPPING, + exp.DataType.Type.INT: "INTEGER", + exp.DataType.Type.FLOAT: "REAL", + exp.DataType.Type.BINARY: "VARBINARY", + exp.DataType.Type.TEXT: "VARCHAR", + exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP", + exp.DataType.Type.STRUCT: "ROW", + } + + TRANSFORMS = { + **Generator.TRANSFORMS, + **transforms.UNALIAS_GROUP, + exp.ApproxDistinct: _approx_distinct_sql, + exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]", + exp.ArrayContains: rename_func("CONTAINS"), + exp.ArraySize: rename_func("CARDINALITY"), + exp.BitwiseAnd: lambda self, e: f"BITWISE_AND({self.sql(e, 'this')}, {self.sql(e, 'expression')})", + exp.BitwiseLeftShift: lambda self, e: f"BITWISE_ARITHMETIC_SHIFT_LEFT({self.sql(e, 'this')}, {self.sql(e, 'expression')})", + exp.BitwiseNot: lambda self, e: 
f"BITWISE_NOT({self.sql(e, 'this')})", + exp.BitwiseOr: lambda self, e: f"BITWISE_OR({self.sql(e, 'this')}, {self.sql(e, 'expression')})", + exp.BitwiseRightShift: lambda self, e: f"BITWISE_ARITHMETIC_SHIFT_RIGHT({self.sql(e, 'this')}, {self.sql(e, 'expression')})", + exp.BitwiseXor: lambda self, e: f"BITWISE_XOR({self.sql(e, 'this')}, {self.sql(e, 'expression')})", + exp.ConcatWs: _concat_ws_sql, + exp.DataType: _datatype_sql, + exp.DateAdd: lambda self, e: f"""DATE_ADD({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""", + exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""", + exp.DateStrToDate: lambda self, e: f"CAST(DATE_PARSE({self.sql(e, 'this')}, {Presto.date_format}) AS DATE)", + exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Presto.dateint_format}) AS INT)", + exp.DiToDate: lambda self, e: f"CAST(DATE_PARSE(CAST({self.sql(e, 'this')} AS VARCHAR), {Presto.dateint_format}) AS DATE)", + exp.FileFormatProperty: lambda self, e: self.property_sql(e), + exp.If: if_sql, + exp.ILike: no_ilike_sql, + exp.Initcap: _initcap_sql, + exp.Lateral: _explode_to_unnest_sql, + exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"), + exp.PartitionedByProperty: lambda self, e: f"PARTITIONED_BY = {self.sql(e.args['value'])}", + exp.Quantile: _quantile_sql, + exp.SafeDivide: no_safe_divide_sql, + exp.Schema: _schema_sql, + exp.SortArray: _no_sort_array, + exp.StrPosition: str_position_sql, + exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)", + exp.StrToTime: _str_to_time_sql, + exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))", + exp.StructExtract: struct_extract_sql, + exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT = '{e.text('value').upper()}'", + exp.TimeStrToDate: _date_parse_sql, + exp.TimeStrToTime: _date_parse_sql, + exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.time_format}))", + exp.TimeToStr: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})", + exp.TimeToUnix: rename_func("TO_UNIXTIME"), + exp.TsOrDiToDi: lambda self, e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)", + exp.TsOrDsAdd: _ts_or_ds_add_sql, + exp.TsOrDsToDate: _ts_or_ds_to_date_sql, + exp.UnixToStr: lambda self, e: f"DATE_FORMAT(FROM_UNIXTIME({self.sql(e, 'this')}), {self.format_time(e)})", + exp.UnixToTime: rename_func("FROM_UNIXTIME"), + exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)", + } diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py new file mode 100644 index 0000000..148dfb5 --- /dev/null +++ b/sqlglot/dialects/snowflake.py @@ -0,0 +1,145 @@ +from sqlglot import exp +from sqlglot.dialects.dialect import Dialect, format_time_lambda, rename_func +from sqlglot.expressions import Literal +from sqlglot.generator import Generator +from sqlglot.helper import list_get +from sqlglot.parser import Parser +from sqlglot.tokens import Tokenizer, TokenType + + +def _check_int(s): + if s[0] in ("-", "+"): + return s[1:].isdigit() + return s.isdigit() + + +# from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html +def _snowflake_to_timestamp(args): + if len(args) == 2: + first_arg, second_arg = args + if second_arg.is_string: + # case: <string_expr> [ , <format> ] + return format_time_lambda(exp.StrToTime, 
"snowflake")(args) + + # case: <numeric_expr> [ , <scale> ] + if second_arg.name not in ["0", "3", "9"]: + raise ValueError( + f"Scale for snowflake numeric timestamp is {second_arg}, but should be 0, 3, or 9" + ) + + if second_arg.name == "0": + timescale = exp.UnixToTime.SECONDS + elif second_arg.name == "3": + timescale = exp.UnixToTime.MILLIS + elif second_arg.name == "9": + timescale = exp.UnixToTime.MICROS + + return exp.UnixToTime(this=first_arg, scale=timescale) + + first_arg = list_get(args, 0) + if not isinstance(first_arg, Literal): + # case: <variant_expr> + return format_time_lambda(exp.StrToTime, "snowflake", default=True)(args) + + if first_arg.is_string: + if _check_int(first_arg.this): + # case: <integer> + return exp.UnixToTime.from_arg_list(args) + + # case: <date_expr> + return format_time_lambda(exp.StrToTime, "snowflake", default=True)(args) + + # case: <numeric_expr> + return exp.UnixToTime.from_arg_list(args) + + +def _unix_to_time(self, expression): + scale = expression.args.get("scale") + timestamp = self.sql(expression, "this") + if scale in [None, exp.UnixToTime.SECONDS]: + return f"TO_TIMESTAMP({timestamp})" + if scale == exp.UnixToTime.MILLIS: + return f"TO_TIMESTAMP({timestamp}, 3)" + if scale == exp.UnixToTime.MICROS: + return f"TO_TIMESTAMP({timestamp}, 9)" + + raise ValueError("Improper scale for timestamp") + + +class Snowflake(Dialect): + null_ordering = "nulls_are_large" + time_format = "'yyyy-mm-dd hh24:mi:ss'" + + time_mapping = { + "YYYY": "%Y", + "yyyy": "%Y", + "YY": "%y", + "yy": "%y", + "MMMM": "%B", + "mmmm": "%B", + "MON": "%b", + "mon": "%b", + "MM": "%m", + "mm": "%m", + "DD": "%d", + "dd": "%d", + "d": "%-d", + "DY": "%w", + "dy": "%w", + "HH24": "%H", + "hh24": "%H", + "HH12": "%I", + "hh12": "%I", + "MI": "%M", + "mi": "%M", + "SS": "%S", + "ss": "%S", + "FF": "%f", + "ff": "%f", + "FF6": "%f", + "ff6": "%f", + } + + class Parser(Parser): + FUNCTIONS = { + **Parser.FUNCTIONS, + "ARRAYAGG": exp.ArrayAgg.from_arg_list, + "IFF": exp.If.from_arg_list, + "TO_TIMESTAMP": _snowflake_to_timestamp, + } + + COLUMN_OPERATORS = { + **Parser.COLUMN_OPERATORS, + TokenType.COLON: lambda self, this, path: self.expression( + exp.Bracket, + this=this, + expressions=[path], + ), + } + + class Tokenizer(Tokenizer): + QUOTES = ["'", "$$"] + ESCAPE = "\\" + KEYWORDS = { + **Tokenizer.KEYWORDS, + "QUALIFY": TokenType.QUALIFY, + "DOUBLE PRECISION": TokenType.DOUBLE, + } + + class Generator(Generator): + TRANSFORMS = { + **Generator.TRANSFORMS, + exp.If: rename_func("IFF"), + exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.UnixToTime: _unix_to_time, + } + + def except_op(self, expression): + if not expression.args.get("distinct", False): + self.unsupported("EXCEPT with All is not supported in Snowflake") + return super().except_op(expression) + + def intersect_op(self, expression): + if not expression.args.get("distinct", False): + self.unsupported("INTERSECT with All is not supported in Snowflake") + return super().intersect_op(expression) diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py new file mode 100644 index 0000000..89c7ed5 --- /dev/null +++ b/sqlglot/dialects/spark.py @@ -0,0 +1,106 @@ +from sqlglot import exp +from sqlglot.dialects.dialect import no_ilike_sql, rename_func +from sqlglot.dialects.hive import Hive, HiveMap +from sqlglot.helper import list_get + + +def _create_sql(self, e): + kind = e.args.get("kind") + temporary = e.args.get("temporary") + + if kind.upper() == "TABLE" and 
temporary is True: + return f"CREATE TEMPORARY VIEW {self.sql(e, 'this')} AS {self.sql(e, 'expression')}" + return self.create_sql(e) + + +def _map_sql(self, expression): + keys = self.sql(expression.args["keys"]) + values = self.sql(expression.args["values"]) + return f"MAP_FROM_ARRAYS({keys}, {values})" + + +def _str_to_date(self, expression): + this = self.sql(expression, "this") + time_format = self.format_time(expression) + if time_format == Hive.date_format: + return f"TO_DATE({this})" + return f"TO_DATE({this}, {time_format})" + + +def _unix_to_time(self, expression): + scale = expression.args.get("scale") + timestamp = self.sql(expression, "this") + if scale is None: + return f"FROM_UNIXTIME({timestamp})" + if scale == exp.UnixToTime.SECONDS: + return f"TIMESTAMP_SECONDS({timestamp})" + if scale == exp.UnixToTime.MILLIS: + return f"TIMESTAMP_MILLIS({timestamp})" + if scale == exp.UnixToTime.MICROS: + return f"TIMESTAMP_MICROS({timestamp})" + + raise ValueError("Improper scale for timestamp") + + +class Spark(Hive): + class Parser(Hive.Parser): + FUNCTIONS = { + **Hive.Parser.FUNCTIONS, + "MAP_FROM_ARRAYS": exp.Map.from_arg_list, + "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list, + "LEFT": lambda args: exp.Substring( + this=list_get(args, 0), + start=exp.Literal.number(1), + length=list_get(args, 1), + ), + "SHIFTLEFT": lambda args: exp.BitwiseLeftShift( + this=list_get(args, 0), + expression=list_get(args, 1), + ), + "SHIFTRIGHT": lambda args: exp.BitwiseRightShift( + this=list_get(args, 0), + expression=list_get(args, 1), + ), + "RIGHT": lambda args: exp.Substring( + this=list_get(args, 0), + start=exp.Sub( + this=exp.Length(this=list_get(args, 0)), + expression=exp.Add( + this=list_get(args, 1), expression=exp.Literal.number(1) + ), + ), + length=list_get(args, 1), + ), + } + + class Generator(Hive.Generator): + TYPE_MAPPING = { + **Hive.Generator.TYPE_MAPPING, + exp.DataType.Type.TINYINT: "BYTE", + exp.DataType.Type.SMALLINT: "SHORT", + exp.DataType.Type.BIGINT: "LONG", + } + + TRANSFORMS = { + **{ + k: v + for k, v in Hive.Generator.TRANSFORMS.items() + if k not in {exp.ArraySort} + }, + exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)", + exp.BitwiseLeftShift: rename_func("SHIFTLEFT"), + exp.BitwiseRightShift: rename_func("SHIFTRIGHT"), + exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */", + exp.ILike: no_ilike_sql, + exp.StrToDate: _str_to_date, + exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.UnixToTime: _unix_to_time, + exp.Create: _create_sql, + exp.Map: _map_sql, + exp.Reduce: rename_func("AGGREGATE"), + exp.StructKwarg: lambda self, e: f"{self.sql(e, 'this')}: {self.sql(e, 'expression')}", + HiveMap: _map_sql, + } + + def bitstring_sql(self, expression): + return f"X'{self.sql(expression, 'this')}'" diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py new file mode 100644 index 0000000..6cf5022 --- /dev/null +++ b/sqlglot/dialects/sqlite.py @@ -0,0 +1,63 @@ +from sqlglot import exp +from sqlglot.dialects.dialect import ( + Dialect, + arrow_json_extract_scalar_sql, + arrow_json_extract_sql, + no_ilike_sql, + no_tablesample_sql, + no_trycast_sql, + rename_func, +) +from sqlglot.generator import Generator +from sqlglot.parser import Parser +from sqlglot.tokens import Tokenizer, TokenType + + +class SQLite(Dialect): + class Tokenizer(Tokenizer): + IDENTIFIERS = ['"', ("[", "]"), "`"] + + KEYWORDS = { + **Tokenizer.KEYWORDS, + 
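            # Added commentary (not in the upstream source): SQLite spells the keyword as
            # a single word; mapping it onto the generic AUTO_INCREMENT token lets the
            # generator's TOKEN_MAPPING below emit AUTOINCREMENT again on the way out.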
"AUTOINCREMENT": TokenType.AUTO_INCREMENT, + } + + class Parser(Parser): + FUNCTIONS = { + **Parser.FUNCTIONS, + "EDITDIST3": exp.Levenshtein.from_arg_list, + } + + class Generator(Generator): + TYPE_MAPPING = { + **Generator.TYPE_MAPPING, + exp.DataType.Type.BOOLEAN: "INTEGER", + exp.DataType.Type.TINYINT: "INTEGER", + exp.DataType.Type.SMALLINT: "INTEGER", + exp.DataType.Type.INT: "INTEGER", + exp.DataType.Type.BIGINT: "INTEGER", + exp.DataType.Type.FLOAT: "REAL", + exp.DataType.Type.DOUBLE: "REAL", + exp.DataType.Type.DECIMAL: "REAL", + exp.DataType.Type.CHAR: "TEXT", + exp.DataType.Type.NCHAR: "TEXT", + exp.DataType.Type.VARCHAR: "TEXT", + exp.DataType.Type.NVARCHAR: "TEXT", + exp.DataType.Type.BINARY: "BLOB", + } + + TOKEN_MAPPING = { + TokenType.AUTO_INCREMENT: "AUTOINCREMENT", + } + + TRANSFORMS = { + **Generator.TRANSFORMS, + exp.ILike: no_ilike_sql, + exp.JSONExtract: arrow_json_extract_sql, + exp.JSONExtractScalar: arrow_json_extract_scalar_sql, + exp.JSONBExtract: arrow_json_extract_sql, + exp.JSONBExtractScalar: arrow_json_extract_scalar_sql, + exp.Levenshtein: rename_func("EDITDIST3"), + exp.TableSample: no_tablesample_sql, + exp.TryCast: no_trycast_sql, + } diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py new file mode 100644 index 0000000..b9cd584 --- /dev/null +++ b/sqlglot/dialects/starrocks.py @@ -0,0 +1,12 @@ +from sqlglot import exp +from sqlglot.dialects.mysql import MySQL + + +class StarRocks(MySQL): + class Generator(MySQL.Generator): + TYPE_MAPPING = { + **MySQL.Generator.TYPE_MAPPING, + exp.DataType.Type.TEXT: "STRING", + exp.DataType.Type.TIMESTAMP: "DATETIME", + exp.DataType.Type.TIMESTAMPTZ: "DATETIME", + } diff --git a/sqlglot/dialects/tableau.py b/sqlglot/dialects/tableau.py new file mode 100644 index 0000000..e571749 --- /dev/null +++ b/sqlglot/dialects/tableau.py @@ -0,0 +1,37 @@ +from sqlglot import exp +from sqlglot.dialects.dialect import Dialect +from sqlglot.generator import Generator +from sqlglot.helper import list_get +from sqlglot.parser import Parser + + +def _if_sql(self, expression): + return f"IF {self.sql(expression, 'this')} THEN {self.sql(expression, 'true')} ELSE {self.sql(expression, 'false')} END" + + +def _coalesce_sql(self, expression): + return f"IFNULL({self.sql(expression, 'this')}, {self.expressions(expression)})" + + +def _count_sql(self, expression): + this = expression.this + if isinstance(this, exp.Distinct): + return f"COUNTD({self.sql(this, 'this')})" + return f"COUNT({self.sql(expression, 'this')})" + + +class Tableau(Dialect): + class Generator(Generator): + TRANSFORMS = { + **Generator.TRANSFORMS, + exp.If: _if_sql, + exp.Coalesce: _coalesce_sql, + exp.Count: _count_sql, + } + + class Parser(Parser): + FUNCTIONS = { + **Parser.FUNCTIONS, + "IFNULL": exp.Coalesce.from_arg_list, + "COUNTD": lambda args: exp.Count(this=exp.Distinct(this=list_get(args, 0))), + } diff --git a/sqlglot/dialects/trino.py b/sqlglot/dialects/trino.py new file mode 100644 index 0000000..805106c --- /dev/null +++ b/sqlglot/dialects/trino.py @@ -0,0 +1,10 @@ +from sqlglot import exp +from sqlglot.dialects.presto import Presto + + +class Trino(Presto): + class Generator(Presto.Generator): + TRANSFORMS = { + **Presto.Generator.TRANSFORMS, + exp.ArraySum: lambda self, e: f"REDUCE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)", + } diff --git a/sqlglot/diff.py b/sqlglot/diff.py new file mode 100644 index 0000000..8eeb4e9 --- /dev/null +++ b/sqlglot/diff.py @@ -0,0 +1,314 @@ +from collections import defaultdict +from 
dataclasses import dataclass +from heapq import heappop, heappush + +from sqlglot import Dialect +from sqlglot import expressions as exp +from sqlglot.helper import ensure_list + + +@dataclass(frozen=True) +class Insert: + """Indicates that a new node has been inserted""" + + expression: exp.Expression + + +@dataclass(frozen=True) +class Remove: + """Indicates that an existing node has been removed""" + + expression: exp.Expression + + +@dataclass(frozen=True) +class Move: + """Indicates that an existing node's position within the tree has changed""" + + expression: exp.Expression + + +@dataclass(frozen=True) +class Update: + """Indicates that an existing node has been updated""" + + source: exp.Expression + target: exp.Expression + + +@dataclass(frozen=True) +class Keep: + """Indicates that an existing node hasn't been changed""" + + source: exp.Expression + target: exp.Expression + + +def diff(source, target): + """ + Returns the list of changes between the source and the target expressions. + + Examples: + >>> diff(parse_one("a + b"), parse_one("a + c")) + [ + Remove(expression=(COLUMN this: (IDENTIFIER this: b, quoted: False))), + Insert(expression=(COLUMN this: (IDENTIFIER this: c, quoted: False))), + Keep( + source=(ADD this: ...), + target=(ADD this: ...) + ), + Keep( + source=(COLUMN this: (IDENTIFIER this: a, quoted: False)), + target=(COLUMN this: (IDENTIFIER this: a, quoted: False)) + ), + ] + + Args: + source (sqlglot.Expression): the source expression. + target (sqlglot.Expression): the target expression against which the diff should be calculated. + + Returns: + the list of Insert, Remove, Move, Update and Keep objects for each node in the source and the target expression trees. + This list represents a sequence of steps needed to transform the source expression tree into the target one. + """ + return ChangeDistiller().diff(source.copy(), target.copy()) + + +LEAF_EXPRESSION_TYPES = ( + exp.Boolean, + exp.DataType, + exp.Identifier, + exp.Literal, +) + + +class ChangeDistiller: + """ + The implementation of the Change Distiller algorithm described by Beat Fluri and Martin Pinzger in + their paper https://ieeexplore.ieee.org/document/4339230, which in turn is based on the algorithm by + Chawathe et al. described in http://ilpubs.stanford.edu:8090/115/1/1995-46.pdf. 
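    A minimal usage sketch (added commentary, not part of the upstream docstring),
    assuming only the module-level diff() helper defined above and sqlglot.parse_one:

        from sqlglot import parse_one
        from sqlglot.diff import diff

        # edit_script is a list of Insert/Remove/Move/Update/Keep steps that
        # transform the first expression tree into the second one
        edit_script = diff(parse_one("SELECT a + b"), parse_one("SELECT a + c"))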
+ """ + + def __init__(self, f=0.6, t=0.6): + self.f = f + self.t = t + self._sql_generator = Dialect().generator() + + def diff(self, source, target): + self._source = source + self._target = target + self._source_index = {id(n[0]): n[0] for n in source.bfs()} + self._target_index = {id(n[0]): n[0] for n in target.bfs()} + self._unmatched_source_nodes = set(self._source_index) + self._unmatched_target_nodes = set(self._target_index) + self._bigram_histo_cache = {} + + matching_set = self._compute_matching_set() + return self._generate_edit_script(matching_set) + + def _generate_edit_script(self, matching_set): + edit_script = [] + for removed_node_id in self._unmatched_source_nodes: + edit_script.append(Remove(self._source_index[removed_node_id])) + for inserted_node_id in self._unmatched_target_nodes: + edit_script.append(Insert(self._target_index[inserted_node_id])) + for kept_source_node_id, kept_target_node_id in matching_set: + source_node = self._source_index[kept_source_node_id] + target_node = self._target_index[kept_target_node_id] + if ( + not isinstance(source_node, LEAF_EXPRESSION_TYPES) + or source_node == target_node + ): + edit_script.extend( + self._generate_move_edits(source_node, target_node, matching_set) + ) + edit_script.append(Keep(source_node, target_node)) + else: + edit_script.append(Update(source_node, target_node)) + + return edit_script + + def _generate_move_edits(self, source, target, matching_set): + source_args = [id(e) for e in _expression_only_args(source)] + target_args = [id(e) for e in _expression_only_args(target)] + + args_lcs = set( + _lcs(source_args, target_args, lambda l, r: (l, r) in matching_set) + ) + + move_edits = [] + for a in source_args: + if a not in args_lcs and a not in self._unmatched_source_nodes: + move_edits.append(Move(self._source_index[a])) + + return move_edits + + def _compute_matching_set(self): + leaves_matching_set = self._compute_leaf_matching_set() + matching_set = leaves_matching_set.copy() + + ordered_unmatched_source_nodes = { + id(n[0]): None + for n in self._source.bfs() + if id(n[0]) in self._unmatched_source_nodes + } + ordered_unmatched_target_nodes = { + id(n[0]): None + for n in self._target.bfs() + if id(n[0]) in self._unmatched_target_nodes + } + + for source_node_id in ordered_unmatched_source_nodes: + for target_node_id in ordered_unmatched_target_nodes: + source_node = self._source_index[source_node_id] + target_node = self._target_index[target_node_id] + if _is_same_type(source_node, target_node): + source_leaf_ids = {id(l) for l in _get_leaves(source_node)} + target_leaf_ids = {id(l) for l in _get_leaves(target_node)} + + max_leaves_num = max(len(source_leaf_ids), len(target_leaf_ids)) + if max_leaves_num: + common_leaves_num = sum( + 1 if s in source_leaf_ids and t in target_leaf_ids else 0 + for s, t in leaves_matching_set + ) + leaf_similarity_score = common_leaves_num / max_leaves_num + else: + leaf_similarity_score = 0.0 + + adjusted_t = ( + self.t + if min(len(source_leaf_ids), len(target_leaf_ids)) > 4 + else 0.4 + ) + + if leaf_similarity_score >= 0.8 or ( + leaf_similarity_score >= adjusted_t + and self._dice_coefficient(source_node, target_node) >= self.f + ): + matching_set.add((source_node_id, target_node_id)) + self._unmatched_source_nodes.remove(source_node_id) + self._unmatched_target_nodes.remove(target_node_id) + ordered_unmatched_target_nodes.pop(target_node_id, None) + break + + return matching_set + + def _compute_leaf_matching_set(self): + candidate_matchings = [] + source_leaves = 
list(_get_leaves(self._source)) + target_leaves = list(_get_leaves(self._target)) + for source_leaf in source_leaves: + for target_leaf in target_leaves: + if _is_same_type(source_leaf, target_leaf): + similarity_score = self._dice_coefficient(source_leaf, target_leaf) + if similarity_score >= self.f: + heappush( + candidate_matchings, + ( + -similarity_score, + len(candidate_matchings), + source_leaf, + target_leaf, + ), + ) + + # Pick best matchings based on the highest score + matching_set = set() + while candidate_matchings: + _, _, source_leaf, target_leaf = heappop(candidate_matchings) + if ( + id(source_leaf) in self._unmatched_source_nodes + and id(target_leaf) in self._unmatched_target_nodes + ): + matching_set.add((id(source_leaf), id(target_leaf))) + self._unmatched_source_nodes.remove(id(source_leaf)) + self._unmatched_target_nodes.remove(id(target_leaf)) + + return matching_set + + def _dice_coefficient(self, source, target): + source_histo = self._bigram_histo(source) + target_histo = self._bigram_histo(target) + + total_grams = sum(source_histo.values()) + sum(target_histo.values()) + if not total_grams: + return 1.0 if source == target else 0.0 + + overlap_len = 0 + overlapping_grams = set(source_histo) & set(target_histo) + for g in overlapping_grams: + overlap_len += min(source_histo[g], target_histo[g]) + + return 2 * overlap_len / total_grams + + def _bigram_histo(self, expression): + if id(expression) in self._bigram_histo_cache: + return self._bigram_histo_cache[id(expression)] + + expression_str = self._sql_generator.generate(expression) + count = max(0, len(expression_str) - 1) + bigram_histo = defaultdict(int) + for i in range(count): + bigram_histo[expression_str[i : i + 2]] += 1 + + self._bigram_histo_cache[id(expression)] = bigram_histo + return bigram_histo + + +def _get_leaves(expression): + has_child_exprs = False + + for a in expression.args.values(): + nodes = ensure_list(a) + for node in nodes: + if isinstance(node, exp.Expression): + has_child_exprs = True + yield from _get_leaves(node) + + if not has_child_exprs: + yield expression + + +def _is_same_type(source, target): + if type(source) is type(target): + if isinstance(source, exp.Join): + return source.args.get("side") == target.args.get("side") + + if isinstance(source, exp.Anonymous): + return source.this == target.this + + return True + + return False + + +def _expression_only_args(expression): + args = [] + if expression: + for a in expression.args.values(): + args.extend(ensure_list(a)) + return [a for a in args if isinstance(a, exp.Expression)] + + +def _lcs(seq_a, seq_b, equal): + """Calculates the longest common subsequence""" + + len_a = len(seq_a) + len_b = len(seq_b) + lcs_result = [[None] * (len_b + 1) for i in range(len_a + 1)] + + for i in range(len_a + 1): + for j in range(len_b + 1): + if i == 0 or j == 0: + lcs_result[i][j] = [] + elif equal(seq_a[i - 1], seq_b[j - 1]): + lcs_result[i][j] = lcs_result[i - 1][j - 1] + [seq_a[i - 1]] + else: + lcs_result[i][j] = ( + lcs_result[i - 1][j] + if len(lcs_result[i - 1][j]) > len(lcs_result[i][j - 1]) + else lcs_result[i][j - 1] + ) + + return lcs_result[len_a][len_b] diff --git a/sqlglot/errors.py b/sqlglot/errors.py new file mode 100644 index 0000000..89aa935 --- /dev/null +++ b/sqlglot/errors.py @@ -0,0 +1,38 @@ +from enum import auto + +from sqlglot.helper import AutoName + + +class ErrorLevel(AutoName): + IGNORE = auto() # Ignore any parser errors + WARN = auto() # Log any parser errors with ERROR level + RAISE = auto() # Collect all 
parser errors and raise a single exception + IMMEDIATE = auto() # Immediately raise an exception on the first parser error + + +class SqlglotError(Exception): + pass + + +class UnsupportedError(SqlglotError): + pass + + +class ParseError(SqlglotError): + pass + + +class TokenError(SqlglotError): + pass + + +class OptimizeError(SqlglotError): + pass + + +def concat_errors(errors, maximum): + msg = [str(e) for e in errors[:maximum]] + remaining = len(errors) - maximum + if remaining > 0: + msg.append(f"... and {remaining} more") + return "\n\n".join(msg) diff --git a/sqlglot/executor/__init__.py b/sqlglot/executor/__init__.py new file mode 100644 index 0000000..a437431 --- /dev/null +++ b/sqlglot/executor/__init__.py @@ -0,0 +1,39 @@ +import logging +import time + +from sqlglot import parse_one +from sqlglot.executor.python import PythonExecutor +from sqlglot.optimizer import optimize +from sqlglot.planner import Plan + +logger = logging.getLogger("sqlglot") + + +def execute(sql, schema, read=None): + """ + Run a sql query against data. + + Args: + sql (str): a sql statement + schema (dict|sqlglot.optimizer.Schema): database schema. + This can either be an instance of `sqlglot.optimizer.Schema` or a mapping in one of + the following forms: + 1. {table: {col: type}} + 2. {db: {table: {col: type}}} + 3. {catalog: {db: {table: {col: type}}}} + read (str): the SQL dialect to apply during parsing + (eg. "spark", "hive", "presto", "mysql"). + Returns: + sqlglot.executor.Table: Simple columnar data structure. + """ + expression = parse_one(sql, read=read) + now = time.time() + expression = optimize(expression, schema) + logger.debug("Optimization finished: %f", time.time() - now) + logger.debug("Optimized SQL: %s", expression.sql(pretty=True)) + plan = Plan(expression) + logger.debug("Logical Plan: %s", plan) + now = time.time() + result = PythonExecutor().execute(plan) + logger.debug("Query finished: %f", time.time() - now) + return result diff --git a/sqlglot/executor/context.py b/sqlglot/executor/context.py new file mode 100644 index 0000000..457bea7 --- /dev/null +++ b/sqlglot/executor/context.py @@ -0,0 +1,68 @@ +from sqlglot.executor.env import ENV + + +class Context: + """ + Execution context for sql expressions. + + Context is used to hold relevant data tables which can then be queried on with eval. + + References to columns can either be scalar or vectors. When set_row is used, column references + evaluate to scalars while set_range evaluates to vectors. This allows convenient and efficient + evaluation of aggregation functions. 
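+
+    Example (a minimal sketch; the table name "t" and its contents are placeholders):
+        >>> from sqlglot.executor.table import Table
+        >>> table = Table("a", rows=[(1,), (2,)])
+        >>> context = Context({"t": table})
+        >>> context.set_row("t", table.rows[0])
+        >>> context.eval("scope['t']['a']")
+        1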
+ """ + + def __init__(self, tables, env=None): + """ + Args + tables (dict): table_name -> Table, representing the scope of the current execution context + env (Optional[dict]): dictionary of functions within the execution context + """ + self.tables = tables + self.range_readers = { + name: table.range_reader for name, table in self.tables.items() + } + self.row_readers = {name: table.reader for name, table in tables.items()} + self.env = {**(env or {}), "scope": self.row_readers} + + def eval(self, code): + return eval(code, ENV, self.env) + + def eval_tuple(self, codes): + return tuple(self.eval(code) for code in codes) + + def __iter__(self): + return self.table_iter(list(self.tables)[0]) + + def table_iter(self, table): + self.env["scope"] = self.row_readers + + for reader in self.tables[table]: + yield reader, self + + def sort(self, table, key): + table = self.tables[table] + + def sort_key(row): + table.reader.row = row + return self.eval_tuple(key) + + table.rows.sort(key=sort_key) + + def set_row(self, table, row): + self.row_readers[table].row = row + self.env["scope"] = self.row_readers + + def set_index(self, table, index): + self.row_readers[table].row = self.tables[table].rows[index] + self.env["scope"] = self.row_readers + + def set_range(self, table, start, end): + self.range_readers[table].range = range(start, end) + self.env["scope"] = self.range_readers + + def __getitem__(self, table): + return self.env["scope"][table] + + def __contains__(self, table): + return table in self.tables diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py new file mode 100644 index 0000000..72b0558 --- /dev/null +++ b/sqlglot/executor/env.py @@ -0,0 +1,32 @@ +import datetime +import re +import statistics + + +class reverse_key: + def __init__(self, obj): + self.obj = obj + + def __eq__(self, other): + return other.obj == self.obj + + def __lt__(self, other): + return other.obj < self.obj + + +ENV = { + "__builtins__": {}, + "datetime": datetime, + "locals": locals, + "re": re, + "float": float, + "int": int, + "str": str, + "desc": reverse_key, + "SUM": sum, + "AVG": statistics.fmean if hasattr(statistics, "fmean") else statistics.mean, + "COUNT": lambda acc: sum(1 for e in acc if e is not None), + "MAX": max, + "MIN": min, + "POW": pow, +} diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py new file mode 100644 index 0000000..388a419 --- /dev/null +++ b/sqlglot/executor/python.py @@ -0,0 +1,360 @@ +import ast +import collections +import itertools + +from sqlglot import exp, planner +from sqlglot.dialects.dialect import Dialect, inline_array_sql +from sqlglot.executor.context import Context +from sqlglot.executor.env import ENV +from sqlglot.executor.table import Table +from sqlglot.generator import Generator +from sqlglot.helper import csv_reader +from sqlglot.tokens import Tokenizer + + +class PythonExecutor: + def __init__(self, env=None): + self.generator = Python().generator(identify=True) + self.env = {**ENV, **(env or {})} + + def execute(self, plan): + running = set() + finished = set() + queue = set(plan.leaves) + contexts = {} + + while queue: + node = queue.pop() + context = self.context( + { + name: table + for dep in node.dependencies + for name, table in contexts[dep].tables.items() + } + ) + running.add(node) + + if isinstance(node, planner.Scan): + contexts[node] = self.scan(node, context) + elif isinstance(node, planner.Aggregate): + contexts[node] = self.aggregate(node, context) + elif isinstance(node, planner.Join): + contexts[node] = 
self.join(node, context) + elif isinstance(node, planner.Sort): + contexts[node] = self.sort(node, context) + else: + raise NotImplementedError + + running.remove(node) + finished.add(node) + + for dep in node.dependents: + if dep not in running and all(d in contexts for d in dep.dependencies): + queue.add(dep) + + for dep in node.dependencies: + if all(d in finished for d in dep.dependents): + contexts.pop(dep) + + root = plan.root + return contexts[root].tables[root.name] + + def generate(self, expression): + """Convert a SQL expression into literal Python code and compile it into bytecode.""" + if not expression: + return None + + sql = self.generator.generate(expression) + return compile(sql, sql, "eval", optimize=2) + + def generate_tuple(self, expressions): + """Convert an array of SQL expressions into tuple of Python byte code.""" + if not expressions: + return tuple() + return tuple(self.generate(expression) for expression in expressions) + + def context(self, tables): + return Context(tables, env=self.env) + + def table(self, expressions): + return Table(expression.alias_or_name for expression in expressions) + + def scan(self, step, context): + if hasattr(step, "source"): + source = step.source + + if isinstance(source, exp.Expression): + source = source.this.name or source.alias + else: + source = step.name + condition = self.generate(step.condition) + projections = self.generate_tuple(step.projections) + + if source in context: + if not projections and not condition: + return self.context({step.name: context.tables[source]}) + table_iter = context.table_iter(source) + else: + table_iter = self.scan_csv(step) + + if projections: + sink = self.table(step.projections) + elif source in context: + sink = Table(context[source].columns) + else: + sink = None + + for reader, ctx in table_iter: + if sink is None: + sink = Table(ctx[source].columns) + + if condition and not ctx.eval(condition): + continue + + if projections: + sink.append(ctx.eval_tuple(projections)) + else: + sink.append(reader.row) + + if len(sink) >= step.limit: + break + + return self.context({step.name: sink}) + + def scan_csv(self, step): + source = step.source + alias = source.alias + + with csv_reader(source.this) as reader: + columns = next(reader) + table = Table(columns) + context = self.context({alias: table}) + types = [] + + for row in reader: + if not types: + for v in row: + try: + types.append(type(ast.literal_eval(v))) + except (ValueError, SyntaxError): + types.append(str) + context.set_row(alias, tuple(t(v) for t, v in zip(types, row))) + yield context[alias], context + + def join(self, step, context): + source = step.name + + join_context = self.context({source: context.tables[source]}) + + def merge_context(ctx, table): + # create a new context where all existing tables are mapped to a new one + return self.context({name: table for name in ctx.tables}) + + for name, join in step.joins.items(): + join_context = self.context( + {**join_context.tables, name: context.tables[name]} + ) + + if join.get("source_key"): + table = self.hash_join(join, source, name, join_context) + else: + table = self.nested_loop_join(join, source, name, join_context) + + join_context = merge_context(join_context, table) + + # apply projections or conditions + context = self.scan(step, join_context) + + # use the scan context since it returns a single table + # otherwise there are no projections so all other tables are still in scope + if step.projections: + return context + + return merge_context(join_context, 
context.tables[source]) + + def nested_loop_join(self, _join, a, b, context): + table = Table(context.tables[a].columns + context.tables[b].columns) + + for reader_a, _ in context.table_iter(a): + for reader_b, _ in context.table_iter(b): + table.append(reader_a.row + reader_b.row) + + return table + + def hash_join(self, join, a, b, context): + a_key = self.generate_tuple(join["source_key"]) + b_key = self.generate_tuple(join["join_key"]) + + results = collections.defaultdict(lambda: ([], [])) + + for reader, ctx in context.table_iter(a): + results[ctx.eval_tuple(a_key)][0].append(reader.row) + for reader, ctx in context.table_iter(b): + results[ctx.eval_tuple(b_key)][1].append(reader.row) + + table = Table(context.tables[a].columns + context.tables[b].columns) + for a_group, b_group in results.values(): + for a_row, b_row in itertools.product(a_group, b_group): + table.append(a_row + b_row) + + return table + + def sort_merge_join(self, join, a, b, context): + a_key = self.generate_tuple(join["source_key"]) + b_key = self.generate_tuple(join["join_key"]) + + context.sort(a, a_key) + context.sort(b, b_key) + + a_i = 0 + b_i = 0 + a_n = len(context.tables[a]) + b_n = len(context.tables[b]) + + table = Table(context.tables[a].columns + context.tables[b].columns) + + def get_key(source, key, i): + context.set_index(source, i) + return context.eval_tuple(key) + + while a_i < a_n and b_i < b_n: + key = min(get_key(a, a_key, a_i), get_key(b, b_key, b_i)) + + a_group = [] + + while a_i < a_n and key == get_key(a, a_key, a_i): + a_group.append(context[a].row) + a_i += 1 + + b_group = [] + + while b_i < b_n and key == get_key(b, b_key, b_i): + b_group.append(context[b].row) + b_i += 1 + + for a_row, b_row in itertools.product(a_group, b_group): + table.append(a_row + b_row) + + return table + + def aggregate(self, step, context): + source = step.source + group_by = self.generate_tuple(step.group) + aggregations = self.generate_tuple(step.aggregations) + operands = self.generate_tuple(step.operands) + + context.sort(source, group_by) + + if step.operands: + source_table = context.tables[source] + operand_table = Table( + source_table.columns + self.table(step.operands).columns + ) + + for reader, ctx in context: + operand_table.append(reader.row + ctx.eval_tuple(operands)) + + context = self.context({source: operand_table}) + + group = None + start = 0 + end = 1 + length = len(context.tables[source]) + table = self.table(step.group + step.aggregations) + + for i in range(length): + context.set_index(source, i) + key = context.eval_tuple(group_by) + group = key if group is None else group + end += 1 + + if i == length - 1: + context.set_range(source, start, end - 1) + elif key != group: + context.set_range(source, start, end - 2) + else: + continue + + table.append(group + context.eval_tuple(aggregations)) + group = key + start = end - 2 + + return self.scan(step, self.context({source: table})) + + def sort(self, step, context): + table = list(context.tables)[0] + key = self.generate_tuple(step.key) + context.sort(table, key) + return self.scan(step, context) + + +def _cast_py(self, expression): + to = expression.args["to"].this + this = self.sql(expression, "this") + + if to == exp.DataType.Type.DATE: + return f"datetime.date.fromisoformat({this})" + if to == exp.DataType.Type.TEXT: + return f"str({this})" + raise NotImplementedError + + +def _column_py(self, expression): + table = self.sql(expression, "table") + this = self.sql(expression, "this") + return f"scope[{table}][{this}]" + + +def 
_interval_py(self, expression): + this = self.sql(expression, "this") + unit = expression.text("unit").upper() + if unit == "DAY": + return f"datetime.timedelta(days=float({this}))" + raise NotImplementedError + + +def _like_py(self, expression): + this = self.sql(expression, "this") + expression = self.sql(expression, "expression") + return f"""re.match({expression}.replace("_", ".").replace("%", ".*"), {this})""" + + +def _ordered_py(self, expression): + this = self.sql(expression, "this") + desc = expression.args.get("desc") + return f"desc({this})" if desc else this + + +class Python(Dialect): + class Tokenizer(Tokenizer): + ESCAPE = "\\" + + class Generator(Generator): + TRANSFORMS = { + exp.Alias: lambda self, e: self.sql(e.this), + exp.Array: inline_array_sql, + exp.And: lambda self, e: self.binary(e, "and"), + exp.Cast: _cast_py, + exp.Column: _column_py, + exp.EQ: lambda self, e: self.binary(e, "=="), + exp.Interval: _interval_py, + exp.Is: lambda self, e: self.binary(e, "is"), + exp.Like: _like_py, + exp.Not: lambda self, e: f"not {self.sql(e.this)}", + exp.Null: lambda *_: "None", + exp.Or: lambda self, e: self.binary(e, "or"), + exp.Ordered: _ordered_py, + exp.Star: lambda *_: "1", + } + + def case_sql(self, expression): + this = self.sql(expression, "this") + chain = self.sql(expression, "default") or "None" + + for e in reversed(expression.args["ifs"]): + true = self.sql(e, "true") + condition = self.sql(e, "this") + condition = f"{this} = ({condition})" if this else condition + chain = f"{true} if {condition} else ({chain})" + + return chain diff --git a/sqlglot/executor/table.py b/sqlglot/executor/table.py new file mode 100644 index 0000000..6df49f7 --- /dev/null +++ b/sqlglot/executor/table.py @@ -0,0 +1,81 @@ +class Table: + def __init__(self, *columns, rows=None): + self.columns = tuple(columns if isinstance(columns[0], str) else columns[0]) + self.rows = rows or [] + if rows: + assert len(rows[0]) == len(self.columns) + self.reader = RowReader(self.columns) + self.range_reader = RangeReader(self) + + def append(self, row): + assert len(row) == len(self.columns) + self.rows.append(row) + + def pop(self): + self.rows.pop() + + @property + def width(self): + return len(self.columns) + + def __len__(self): + return len(self.rows) + + def __iter__(self): + return TableIter(self) + + def __getitem__(self, index): + self.reader.row = self.rows[index] + return self.reader + + def __repr__(self): + widths = {column: len(column) for column in self.columns} + lines = [" ".join(column for column in self.columns)] + + for i, row in enumerate(self): + if i > 10: + break + + lines.append( + " ".join( + str(row[column]).rjust(widths[column])[0 : widths[column]] + for column in self.columns + ) + ) + return "\n".join(lines) + + +class TableIter: + def __init__(self, table): + self.table = table + self.index = -1 + + def __iter__(self): + return self + + def __next__(self): + self.index += 1 + if self.index < len(self.table): + return self.table[self.index] + raise StopIteration + + +class RangeReader: + def __init__(self, table): + self.table = table + self.range = range(0) + + def __len__(self): + return len(self.range) + + def __getitem__(self, column): + return (self.table[i][column] for i in self.range) + + +class RowReader: + def __init__(self, columns): + self.columns = {column: i for i, column in enumerate(columns)} + self.row = None + + def __getitem__(self, column): + return self.row[self.columns[column]] diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py new file 
mode 100644 index 0000000..7acc63d --- /dev/null +++ b/sqlglot/expressions.py @@ -0,0 +1,2945 @@ +import inspect +import re +import sys +from collections import deque +from copy import deepcopy +from enum import auto + +from sqlglot.errors import ParseError +from sqlglot.helper import AutoName, camel_to_snake_case, ensure_list + + +class _Expression(type): + def __new__(cls, clsname, bases, attrs): + klass = super().__new__(cls, clsname, bases, attrs) + klass.key = clsname.lower() + return klass + + +class Expression(metaclass=_Expression): + """ + The base class for all expressions in a syntax tree. + + Attributes: + arg_types (dict): determines arguments supported by this expression. + The key in a dictionary defines a unique key of an argument using + which the argument's value can be retrieved. The value is a boolean + flag which indicates whether the argument's value is required (True) + or optional (False). + """ + + key = None + arg_types = {"this": True} + __slots__ = ("args", "parent", "arg_key") + + def __init__(self, **args): + self.args = args + self.parent = None + self.arg_key = None + + for arg_key, value in self.args.items(): + self._set_parent(arg_key, value) + + def __eq__(self, other): + return type(self) is type(other) and _norm_args(self) == _norm_args(other) + + def __hash__(self): + return hash( + ( + self.key, + tuple( + (k, tuple(v) if isinstance(v, list) else v) + for k, v in _norm_args(self).items() + ), + ) + ) + + @property + def this(self): + return self.args.get("this") + + @property + def expression(self): + return self.args.get("expression") + + @property + def expressions(self): + return self.args.get("expressions") or [] + + def text(self, key): + field = self.args.get(key) + if isinstance(field, str): + return field + if isinstance(field, (Identifier, Literal, Var)): + return field.this + return "" + + @property + def is_string(self): + return isinstance(self, Literal) and self.args["is_string"] + + @property + def is_number(self): + return isinstance(self, Literal) and not self.args["is_string"] + + @property + def is_int(self): + if self.is_number: + try: + int(self.name) + return True + except ValueError: + pass + return False + + @property + def alias(self): + if isinstance(self.args.get("alias"), TableAlias): + return self.args["alias"].name + return self.text("alias") + + @property + def name(self): + return self.text("this") + + @property + def alias_or_name(self): + return self.alias or self.name + + def __deepcopy__(self, memo): + return self.__class__(**deepcopy(self.args)) + + def copy(self): + new = deepcopy(self) + for item, parent, _ in new.bfs(): + if isinstance(item, Expression) and parent: + item.parent = parent + return new + + def set(self, arg_key, value): + """ + Sets `arg` to `value`. + + Args: + arg_key (str): name of the expression arg + value: value to set the arg to. + """ + self.args[arg_key] = value + self._set_parent(arg_key, value) + + def _set_parent(self, arg_key, value): + if isinstance(value, Expression): + value.parent = self + value.arg_key = arg_key + elif isinstance(value, list): + for v in value: + if isinstance(v, Expression): + v.parent = self + v.arg_key = arg_key + + @property + def depth(self): + """ + Returns the depth of this tree. + """ + if self.parent: + return self.parent.depth + 1 + return 0 + + def find(self, *expression_types, bfs=True): + """ + Returns the first node in this tree which matches at least one of + the specified types. + + Args: + expression_types (type): the expression type(s) to match. 
+ + Returns: + the node which matches the criteria or None if no node matching + the criteria was found. + """ + return next(self.find_all(*expression_types, bfs=bfs), None) + + def find_all(self, *expression_types, bfs=True): + """ + Returns a generator object which visits all nodes in this tree and only + yields those that match at least one of the specified expression types. + + Args: + expression_types (type): the expression type(s) to match. + + Returns: + the generator object. + """ + for expression, _, _ in self.walk(bfs=bfs): + if isinstance(expression, expression_types): + yield expression + + def find_ancestor(self, *expression_types): + """ + Returns a nearest parent matching expression_types. + + Args: + expression_types (type): the expression type(s) to match. + + Returns: + the parent node + """ + ancestor = self.parent + while ancestor and not isinstance(ancestor, expression_types): + ancestor = ancestor.parent + return ancestor + + @property + def parent_select(self): + """ + Returns the parent select statement. + """ + return self.find_ancestor(Select) + + def walk(self, bfs=True): + """ + Returns a generator object which visits all nodes in this tree. + + Args: + bfs (bool): if set to True the BFS traversal order will be applied, + otherwise the DFS traversal will be used instead. + + Returns: + the generator object. + """ + if bfs: + yield from self.bfs() + else: + yield from self.dfs() + + def dfs(self, parent=None, key=None, prune=None): + """ + Returns a generator object which visits all nodes in this tree in + the DFS (Depth-first) order. + + Returns: + the generator object. + """ + parent = parent or self.parent + yield self, parent, key + if prune and prune(self, parent, key): + return + + for k, v in self.args.items(): + nodes = ensure_list(v) + + for node in nodes: + if isinstance(node, Expression): + yield from node.dfs(self, k, prune) + + def bfs(self, prune=None): + """ + Returns a generator object which visits all nodes in this tree in + the BFS (Breadth-first) order. + + Returns: + the generator object. + """ + queue = deque([(self, self.parent, None)]) + + while queue: + item, parent, key = queue.popleft() + + yield item, parent, key + if prune and prune(item, parent, key): + continue + + if isinstance(item, Expression): + for k, v in item.args.items(): + nodes = ensure_list(v) + + for node in nodes: + if isinstance(node, Expression): + queue.append((node, item, k)) + + def unnest(self): + """ + Returns the first non parenthesis child or self. + """ + expression = self + while isinstance(expression, Paren): + expression = expression.this + return expression + + def unnest_operands(self): + """ + Returns unnested operands as a tuple. + """ + return tuple(arg.unnest() for arg in self.args.values() if arg) + + def flatten(self, unnest=True): + """ + Returns a generator which yields child nodes who's parents are the same class. + + A AND B AND C -> [A, B, C] + """ + for node, _, _ in self.dfs( + prune=lambda n, p, *_: p and not isinstance(n, self.__class__) + ): + if not isinstance(node, self.__class__): + yield node.unnest() if unnest else node + + def __str__(self): + return self.sql() + + def __repr__(self): + return self.to_s() + + def sql(self, dialect=None, **opts): + """ + Returns SQL string representation of this tree. + + Args + dialect (str): the dialect of the output SQL string + (eg. "spark", "hive", "presto", "mysql"). + opts (dict): other :class:`~sqlglot.generator.Generator` options. + + Returns + the SQL string. 
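+
+        Example (a minimal sketch; `tbl` is a placeholder table name):
+            >>> Select().select("x").from_("tbl").sql()
+            'SELECT x FROM tbl'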
+ """ + from sqlglot.dialects import Dialect + + return Dialect.get_or_raise(dialect)().generate(self, **opts) + + def to_s(self, hide_missing=True, level=0): + indent = "" if not level else "\n" + indent += "".join([" "] * level) + left = f"({self.key.upper()} " + + args = { + k: ", ".join( + v.to_s(hide_missing=hide_missing, level=level + 1) + if hasattr(v, "to_s") + else str(v) + for v in ensure_list(vs) + if v is not None + ) + for k, vs in self.args.items() + } + args = {k: v for k, v in args.items() if v or not hide_missing} + + right = ", ".join(f"{k}: {v}" for k, v in args.items()) + right += ")" + + return indent + left + right + + def transform(self, fun, *args, copy=True, **kwargs): + """ + Recursively visits all tree nodes (excluding already transformed ones) + and applies the given transformation function to each node. + + Args: + fun (function): a function which takes a node as an argument and returns a + new transformed node or the same node without modifications. + copy (bool): if set to True a new tree instance is constructed, otherwise the tree is + modified in place. + + Returns: + the transformed tree. + """ + node = self.copy() if copy else self + new_node = fun(node, *args, **kwargs) + + if new_node is None: + raise ValueError("A transformed node cannot be None") + if not isinstance(new_node, Expression): + return new_node + if new_node is not node: + new_node.parent = node.parent + return new_node + + replace_children( + new_node, lambda child: child.transform(fun, *args, copy=False, **kwargs) + ) + return new_node + + def replace(self, expression): + """ + Swap out this expression with a new expression. + + For example:: + + >>> tree = Select().select("x").from_("tbl") + >>> tree.find(Column).replace(Column(this="y")) + (COLUMN this: y) + >>> tree.sql() + 'SELECT y FROM tbl' + + Args: + expression (Expression): new node + + Returns : + the new expression or expressions + """ + if not self.parent: + return expression + + parent = self.parent + self.parent = None + + replace_children(parent, lambda child: expression if child is self else child) + return expression + + def assert_is(self, type_): + """ + Assert that this `Expression` is an instance of `type_`. + + If it is NOT an instance of `type_`, this raises an assertion error. + Otherwise, this returns this expression. + + Examples: + This is useful for type security in chained expressions: + + >>> import sqlglot + >>> sqlglot.parse_one("SELECT x from y").assert_is(Select).select("z").sql() + 'SELECT x, z FROM y' + """ + assert isinstance(self, type_) + return self + + +class Condition(Expression): + def and_(self, *expressions, dialect=None, **opts): + """ + AND this condition with one or multiple expressions. + + Example: + >>> condition("x=1").and_("y=1").sql() + 'x = 1 AND y = 1' + + Args: + *expressions (str or Expression): the SQL code strings to parse. + If an `Expression` instance is passed, it will be used as-is. + dialect (str): the dialect used to parse the input expression. + opts (kwargs): other options to use to parse the input expressions. + + Returns: + And: the new condition. + """ + return and_(self, *expressions, dialect=dialect, **opts) + + def or_(self, *expressions, dialect=None, **opts): + """ + OR this condition with one or multiple expressions. + + Example: + >>> condition("x=1").or_("y=1").sql() + 'x = 1 OR y = 1' + + Args: + *expressions (str or Expression): the SQL code strings to parse. + If an `Expression` instance is passed, it will be used as-is. 
+ dialect (str): the dialect used to parse the input expression. + opts (kwargs): other options to use to parse the input expressions. + + Returns: + Or: the new condition. + """ + return or_(self, *expressions, dialect=dialect, **opts) + + def not_(self): + """ + Wrap this condition with NOT. + + Example: + >>> condition("x=1").not_().sql() + 'NOT x = 1' + + Returns: + Not: the new condition. + """ + return not_(self) + + +class Predicate(Condition): + """Relationships like x = y, x > 1, x >= y.""" + + +class DerivedTable(Expression): + @property + def alias_column_names(self): + table_alias = self.args.get("alias") + if not table_alias: + return [] + column_list = table_alias.assert_is(TableAlias).args.get("columns") or [] + return [c.name for c in column_list] + + @property + def selects(self): + alias = self.args.get("alias") + + if alias: + return alias.columns + return [] + + @property + def named_selects(self): + return [select.alias_or_name for select in self.selects] + + +class Annotation(Expression): + arg_types = { + "this": True, + "expression": True, + } + + +class Cache(Expression): + arg_types = { + "with": False, + "this": True, + "lazy": False, + "options": False, + "expression": False, + } + + +class Uncache(Expression): + arg_types = {"this": True, "exists": False} + + +class Create(Expression): + arg_types = { + "with": False, + "this": True, + "kind": True, + "expression": False, + "exists": False, + "properties": False, + "temporary": False, + "replace": False, + "unique": False, + } + + +class CharacterSet(Expression): + arg_types = {"this": True, "default": False} + + +class With(Expression): + arg_types = {"expressions": True, "recursive": False} + + +class WithinGroup(Expression): + arg_types = {"this": True, "expression": False} + + +class CTE(DerivedTable): + arg_types = {"this": True, "alias": True} + + +class TableAlias(Expression): + arg_types = {"this": False, "columns": False} + + @property + def columns(self): + return self.args.get("columns") or [] + + +class BitString(Condition): + pass + + +class Column(Condition): + arg_types = {"this": True, "table": False} + + @property + def table(self): + return self.text("table") + + +class ColumnDef(Expression): + arg_types = { + "this": True, + "kind": True, + "constraints": False, + } + + +class ColumnConstraint(Expression): + arg_types = {"this": False, "kind": True} + + +class AutoIncrementColumnConstraint(Expression): + pass + + +class CheckColumnConstraint(Expression): + pass + + +class CollateColumnConstraint(Expression): + pass + + +class CommentColumnConstraint(Expression): + pass + + +class DefaultColumnConstraint(Expression): + pass + + +class NotNullColumnConstraint(Expression): + pass + + +class PrimaryKeyColumnConstraint(Expression): + pass + + +class UniqueColumnConstraint(Expression): + pass + + +class Constraint(Expression): + arg_types = {"this": True, "expressions": True} + + +class Delete(Expression): + arg_types = {"with": False, "this": True, "where": False} + + +class Drop(Expression): + arg_types = {"this": False, "kind": False, "exists": False} + + +class Filter(Expression): + arg_types = {"this": True, "expression": True} + + +class Check(Expression): + pass + + +class ForeignKey(Expression): + arg_types = { + "expressions": True, + "reference": False, + "delete": False, + "update": False, + } + + +class Unique(Expression): + arg_types = {"expressions": True} + + +class From(Expression): + arg_types = {"expressions": True} + + +class Having(Expression): + pass + + +class 
Hint(Expression): + arg_types = {"expressions": True} + + +class Identifier(Expression): + arg_types = {"this": True, "quoted": False} + + @property + def quoted(self): + return bool(self.args.get("quoted")) + + def __eq__(self, other): + return isinstance(other, self.__class__) and _norm_arg(self.this) == _norm_arg( + other.this + ) + + def __hash__(self): + return hash((self.key, self.this.lower())) + + +class Index(Expression): + arg_types = {"this": False, "table": False, "where": False, "columns": False} + + +class Insert(Expression): + arg_types = { + "with": False, + "this": True, + "expression": True, + "overwrite": False, + "exists": False, + "partition": False, + } + + +# https://dev.mysql.com/doc/refman/8.0/en/charset-introducer.html +class Introducer(Expression): + arg_types = {"this": True, "expression": True} + + +class Partition(Expression): + pass + + +class Fetch(Expression): + arg_types = {"direction": False, "count": True} + + +class Group(Expression): + arg_types = { + "expressions": False, + "grouping_sets": False, + "cube": False, + "rollup": False, + } + + +class Lambda(Expression): + arg_types = {"this": True, "expressions": True} + + +class Limit(Expression): + arg_types = {"this": False, "expression": True} + + +class Literal(Condition): + arg_types = {"this": True, "is_string": True} + + def __eq__(self, other): + return ( + isinstance(other, Literal) + and self.this == other.this + and self.args["is_string"] == other.args["is_string"] + ) + + def __hash__(self): + return hash((self.key, self.this, self.args["is_string"])) + + @classmethod + def number(cls, number): + return cls(this=str(number), is_string=False) + + @classmethod + def string(cls, string): + return cls(this=str(string), is_string=True) + + +class Join(Expression): + arg_types = { + "this": True, + "on": False, + "side": False, + "kind": False, + "using": False, + } + + @property + def kind(self): + return self.text("kind").upper() + + @property + def side(self): + return self.text("side").upper() + + def on(self, *expressions, append=True, dialect=None, copy=True, **opts): + """ + Append to or set the ON expressions. + + Example: + >>> import sqlglot + >>> sqlglot.parse_one("JOIN x", into=Join).on("y = 1").sql() + 'JOIN x ON y = 1' + + Args: + *expressions (str or Expression): the SQL code strings to parse. + If an `Expression` instance is passed, it will be used as-is. + Multiple expressions are combined with an AND operator. + append (bool): if `True`, AND the new expressions to any existing expression. + Otherwise, this resets the expression. + dialect (str): the dialect used to parse the input expressions. + copy (bool): if `False`, modify this expression instance in-place. + opts (kwargs): other options to use to parse the input expressions. + + Returns: + Join: the modified join expression. 
+ """ + join = _apply_conjunction_builder( + *expressions, + instance=self, + arg="on", + append=append, + dialect=dialect, + copy=copy, + **opts, + ) + + if join.kind == "CROSS": + join.set("kind", None) + + return join + + +class Lateral(DerivedTable): + arg_types = {"this": True, "outer": False, "alias": False} + + +# Clickhouse FROM FINAL modifier +# https://clickhouse.com/docs/en/sql-reference/statements/select/from/#final-modifier +class Final(Expression): + pass + + +class Offset(Expression): + arg_types = {"this": False, "expression": True} + + +class Order(Expression): + arg_types = {"this": False, "expressions": True} + + +# hive specific sorts +# https://cwiki.apache.org/confluence/display/Hive/LanguageManual+SortBy +class Cluster(Order): + pass + + +class Distribute(Order): + pass + + +class Sort(Order): + pass + + +class Ordered(Expression): + arg_types = {"this": True, "desc": True, "nulls_first": True} + + +class Properties(Expression): + arg_types = {"expressions": True} + + +class Property(Expression): + arg_types = {"this": True, "value": True} + + +class TableFormatProperty(Property): + pass + + +class PartitionedByProperty(Property): + pass + + +class FileFormatProperty(Property): + pass + + +class LocationProperty(Property): + pass + + +class EngineProperty(Property): + pass + + +class AutoIncrementProperty(Property): + pass + + +class CharacterSetProperty(Property): + arg_types = {"this": True, "value": True, "default": True} + + +class CollateProperty(Property): + pass + + +class SchemaCommentProperty(Property): + pass + + +class AnonymousProperty(Property): + pass + + +class Qualify(Expression): + pass + + +class Reference(Expression): + arg_types = {"this": True, "expressions": True} + + +class Table(Expression): + arg_types = {"this": True, "db": False, "catalog": False} + + +class Tuple(Expression): + arg_types = {"expressions": False} + + +class Subqueryable: + def subquery(self, alias=None, copy=True): + """ + Convert this expression to an aliased expression that can be used as a Subquery. + + Example: + >>> subquery = Select().select("x").from_("tbl").subquery() + >>> Select().select("x").from_(subquery).sql() + 'SELECT x FROM (SELECT x FROM tbl)' + + Args: + alias (str or Identifier): an optional alias for the subquery + copy (bool): if `False`, modify this expression instance in-place. + + Returns: + Alias: the subquery + """ + instance = _maybe_copy(self, copy) + return Subquery( + this=instance, + alias=TableAlias(this=to_identifier(alias)), + ) + + @property + def ctes(self): + with_ = self.args.get("with") + if not with_: + return [] + return with_.expressions + + def with_( + self, + alias, + as_, + recursive=None, + append=True, + dialect=None, + copy=True, + **opts, + ): + """ + Append to or set the common table expressions. + + Example: + >>> Select().with_("tbl2", as_="SELECT * FROM tbl").select("x").from_("tbl2").sql() + 'WITH tbl2 AS (SELECT * FROM tbl) SELECT x FROM tbl2' + + Args: + alias (str or Expression): the SQL code string to parse as the table name. + If an `Expression` instance is passed, this is used as-is. + as_ (str or Expression): the SQL code string to parse as the table expression. + If an `Expression` instance is passed, it will be used as-is. + recursive (bool): set the RECURSIVE part of the expression. Defaults to `False`. + append (bool): if `True`, add to any existing expressions. + Otherwise, this resets the expressions. + dialect (str): the dialect used to parse the input expression. 
+ copy (bool): if `False`, modify this expression instance in-place. + opts (kwargs): other options to use to parse the input expressions. + + Returns: + Select: the modified expression. + """ + alias_expression = maybe_parse( + alias, + dialect=dialect, + into=TableAlias, + **opts, + ) + as_expression = maybe_parse( + as_, + dialect=dialect, + **opts, + ) + cte = CTE( + this=as_expression, + alias=alias_expression, + ) + return _apply_child_list_builder( + cte, + instance=self, + arg="with", + append=append, + copy=copy, + into=With, + properties={"recursive": recursive or False}, + ) + + +QUERY_MODIFIERS = { + "laterals": False, + "joins": False, + "where": False, + "group": False, + "having": False, + "qualify": False, + "window": False, + "distribute": False, + "sort": False, + "cluster": False, + "order": False, + "limit": False, + "offset": False, +} + + +class Union(Subqueryable, Expression): + arg_types = { + "with": False, + "this": True, + "expression": True, + "distinct": False, + **QUERY_MODIFIERS, + } + + @property + def named_selects(self): + return self.args["this"].unnest().named_selects + + @property + def left(self): + return self.this + + @property + def right(self): + return self.expression + + +class Except(Union): + pass + + +class Intersect(Union): + pass + + +class Unnest(DerivedTable): + arg_types = { + "expressions": True, + "ordinality": False, + "alias": False, + } + + +class Update(Expression): + arg_types = { + "with": False, + "this": True, + "expressions": True, + "from": False, + "where": False, + } + + +class Values(Expression): + arg_types = {"expressions": True} + + +class Var(Expression): + pass + + +class Schema(Expression): + arg_types = {"this": False, "expressions": True} + + +class Select(Subqueryable, Expression): + arg_types = { + "with": False, + "expressions": False, + "hint": False, + "distinct": False, + "from": False, + **QUERY_MODIFIERS, + } + + def from_(self, *expressions, append=True, dialect=None, copy=True, **opts): + """ + Set the FROM expression. + + Example: + >>> Select().from_("tbl").select("x").sql() + 'SELECT x FROM tbl' + + Args: + *expressions (str or Expression): the SQL code strings to parse. + If a `From` instance is passed, this is used as-is. + If another `Expression` instance is passed, it will be wrapped in a `From`. + append (bool): if `True`, add to any existing expressions. + Otherwise, this flattens all the `From` expression into a single expression. + dialect (str): the dialect used to parse the input expression. + copy (bool): if `False`, modify this expression instance in-place. + opts (kwargs): other options to use to parse the input expressions. + + Returns: + Select: the modified expression. + """ + return _apply_child_list_builder( + *expressions, + instance=self, + arg="from", + append=append, + copy=copy, + prefix="FROM", + into=From, + dialect=dialect, + **opts, + ) + + def group_by(self, *expressions, append=True, dialect=None, copy=True, **opts): + """ + Set the GROUP BY expression. + + Example: + >>> Select().from_("tbl").select("x", "COUNT(1)").group_by("x").sql() + 'SELECT x, COUNT(1) FROM tbl GROUP BY x' + + Args: + *expressions (str or Expression): the SQL code strings to parse. + If a `Group` instance is passed, this is used as-is. + If another `Expression` instance is passed, it will be wrapped in a `Group`. + append (bool): if `True`, add to any existing expressions. + Otherwise, this flattens all the `Group` expression into a single expression. 
+ dialect (str): the dialect used to parse the input expression. + copy (bool): if `False`, modify this expression instance in-place. + opts (kwargs): other options to use to parse the input expressions. + + Returns: + Select: the modified expression. + """ + return _apply_child_list_builder( + *expressions, + instance=self, + arg="group", + append=append, + copy=copy, + prefix="GROUP BY", + into=Group, + dialect=dialect, + **opts, + ) + + def order_by(self, *expressions, append=True, dialect=None, copy=True, **opts): + """ + Set the ORDER BY expression. + + Example: + >>> Select().from_("tbl").select("x").order_by("x DESC").sql() + 'SELECT x FROM tbl ORDER BY x DESC' + + Args: + *expressions (str or Expression): the SQL code strings to parse. + If a `Group` instance is passed, this is used as-is. + If another `Expression` instance is passed, it will be wrapped in a `Order`. + append (bool): if `True`, add to any existing expressions. + Otherwise, this flattens all the `Order` expression into a single expression. + dialect (str): the dialect used to parse the input expression. + copy (bool): if `False`, modify this expression instance in-place. + opts (kwargs): other options to use to parse the input expressions. + + Returns: + Select: the modified expression. + """ + return _apply_child_list_builder( + *expressions, + instance=self, + arg="order", + append=append, + copy=copy, + prefix="ORDER BY", + into=Order, + dialect=dialect, + **opts, + ) + + def sort_by(self, *expressions, append=True, dialect=None, copy=True, **opts): + """ + Set the SORT BY expression. + + Example: + >>> Select().from_("tbl").select("x").sort_by("x DESC").sql() + 'SELECT x FROM tbl SORT BY x DESC' + + Args: + *expressions (str or Expression): the SQL code strings to parse. + If a `Group` instance is passed, this is used as-is. + If another `Expression` instance is passed, it will be wrapped in a `SORT`. + append (bool): if `True`, add to any existing expressions. + Otherwise, this flattens all the `Order` expression into a single expression. + dialect (str): the dialect used to parse the input expression. + copy (bool): if `False`, modify this expression instance in-place. + opts (kwargs): other options to use to parse the input expressions. + + Returns: + Select: the modified expression. + """ + return _apply_child_list_builder( + *expressions, + instance=self, + arg="sort", + append=append, + copy=copy, + prefix="SORT BY", + into=Sort, + dialect=dialect, + **opts, + ) + + def cluster_by(self, *expressions, append=True, dialect=None, copy=True, **opts): + """ + Set the CLUSTER BY expression. + + Example: + >>> Select().from_("tbl").select("x").cluster_by("x DESC").sql() + 'SELECT x FROM tbl CLUSTER BY x DESC' + + Args: + *expressions (str or Expression): the SQL code strings to parse. + If a `Group` instance is passed, this is used as-is. + If another `Expression` instance is passed, it will be wrapped in a `Cluster`. + append (bool): if `True`, add to any existing expressions. + Otherwise, this flattens all the `Order` expression into a single expression. + dialect (str): the dialect used to parse the input expression. + copy (bool): if `False`, modify this expression instance in-place. + opts (kwargs): other options to use to parse the input expressions. + + Returns: + Select: the modified expression. 
+ """ + return _apply_child_list_builder( + *expressions, + instance=self, + arg="cluster", + append=append, + copy=copy, + prefix="CLUSTER BY", + into=Cluster, + dialect=dialect, + **opts, + ) + + def limit(self, expression, dialect=None, copy=True, **opts): + """ + Set the LIMIT expression. + + Example: + >>> Select().from_("tbl").select("x").limit(10).sql() + 'SELECT x FROM tbl LIMIT 10' + + Args: + expression (str or int or Expression): the SQL code string to parse. + This can also be an integer. + If a `Limit` instance is passed, this is used as-is. + If another `Expression` instance is passed, it will be wrapped in a `Limit`. + dialect (str): the dialect used to parse the input expression. + copy (bool): if `False`, modify this expression instance in-place. + opts (kwargs): other options to use to parse the input expressions. + + Returns: + Select: the modified expression. + """ + return _apply_builder( + expression=expression, + instance=self, + arg="limit", + into=Limit, + prefix="LIMIT", + dialect=dialect, + copy=copy, + **opts, + ) + + def offset(self, expression, dialect=None, copy=True, **opts): + """ + Set the OFFSET expression. + + Example: + >>> Select().from_("tbl").select("x").offset(10).sql() + 'SELECT x FROM tbl OFFSET 10' + + Args: + expression (str or int or Expression): the SQL code string to parse. + This can also be an integer. + If a `Offset` instance is passed, this is used as-is. + If another `Expression` instance is passed, it will be wrapped in a `Offset`. + dialect (str): the dialect used to parse the input expression. + copy (bool): if `False`, modify this expression instance in-place. + opts (kwargs): other options to use to parse the input expressions. + + Returns: + Select: the modified expression. + """ + return _apply_builder( + expression=expression, + instance=self, + arg="offset", + into=Offset, + prefix="OFFSET", + dialect=dialect, + copy=copy, + **opts, + ) + + def select(self, *expressions, append=True, dialect=None, copy=True, **opts): + """ + Append to or set the SELECT expressions. + + Example: + >>> Select().select("x", "y").sql() + 'SELECT x, y' + + Args: + *expressions (str or Expression): the SQL code strings to parse. + If an `Expression` instance is passed, it will be used as-is. + append (bool): if `True`, add to any existing expressions. + Otherwise, this resets the expressions. + dialect (str): the dialect used to parse the input expressions. + copy (bool): if `False`, modify this expression instance in-place. + opts (kwargs): other options to use to parse the input expressions. + + Returns: + Select: the modified expression. + """ + return _apply_list_builder( + *expressions, + instance=self, + arg="expressions", + append=append, + dialect=dialect, + copy=copy, + **opts, + ) + + def lateral(self, *expressions, append=True, dialect=None, copy=True, **opts): + """ + Append to or set the LATERAL expressions. + + Example: + >>> Select().select("x").lateral("OUTER explode(y) tbl2 AS z").from_("tbl").sql() + 'SELECT x FROM tbl LATERAL VIEW OUTER EXPLODE(y) tbl2 AS z' + + Args: + *expressions (str or Expression): the SQL code strings to parse. + If an `Expression` instance is passed, it will be used as-is. + append (bool): if `True`, add to any existing expressions. + Otherwise, this resets the expressions. + dialect (str): the dialect used to parse the input expressions. + copy (bool): if `False`, modify this expression instance in-place. + opts (kwargs): other options to use to parse the input expressions. 
+ + Returns: + Select: the modified expression. + """ + return _apply_list_builder( + *expressions, + instance=self, + arg="laterals", + append=append, + into=Lateral, + prefix="LATERAL VIEW", + dialect=dialect, + copy=copy, + **opts, + ) + + def join( + self, + expression, + on=None, + append=True, + join_type=None, + join_alias=None, + dialect=None, + copy=True, + **opts, + ): + """ + Append to or set the JOIN expressions. + + Example: + >>> Select().select("*").from_("tbl").join("tbl2", on="tbl1.y = tbl2.y").sql() + 'SELECT * FROM tbl JOIN tbl2 ON tbl1.y = tbl2.y' + + Use `join_type` to change the type of join: + + >>> Select().select("*").from_("tbl").join("tbl2", on="tbl1.y = tbl2.y", join_type="left outer").sql() + 'SELECT * FROM tbl LEFT OUTER JOIN tbl2 ON tbl1.y = tbl2.y' + + Args: + expression (str or Expression): the SQL code string to parse. + If an `Expression` instance is passed, it will be used as-is. + on (str or Expression): optionally specify the join criteria as a SQL string. + If an `Expression` instance is passed, it will be used as-is. + append (bool): if `True`, add to any existing expressions. + Otherwise, this resets the expressions. + join_type (str): If set, alter the parsed join type + dialect (str): the dialect used to parse the input expressions. + copy (bool): if `False`, modify this expression instance in-place. + opts (kwargs): other options to use to parse the input expressions. + + Returns: + Select: the modified expression. + """ + parse_args = {"dialect": dialect, **opts} + + try: + expression = maybe_parse(expression, into=Join, prefix="JOIN", **parse_args) + except ParseError: + expression = maybe_parse(expression, into=(Join, Expression), **parse_args) + + join = expression if isinstance(expression, Join) else Join(this=expression) + + if isinstance(join.this, Select): + join.this.replace(join.this.subquery()) + + if join_type: + side, kind = maybe_parse(join_type, into="JOIN_TYPE", **parse_args) + if side: + join.set("side", side.text) + if kind: + join.set("kind", kind.text) + + if on: + on = and_(*ensure_list(on), dialect=dialect, **opts) + join.set("on", on) + + if join_alias: + join.set("this", alias_(join.args["this"], join_alias, table=True)) + return _apply_list_builder( + join, + instance=self, + arg="joins", + append=append, + copy=copy, + **opts, + ) + + def where(self, *expressions, append=True, dialect=None, copy=True, **opts): + """ + Append to or set the WHERE expressions. + + Example: + >>> Select().select("x").from_("tbl").where("x = 'a' OR x < 'b'").sql() + "SELECT x FROM tbl WHERE x = 'a' OR x < 'b'" + + Args: + *expressions (str or Expression): the SQL code strings to parse. + If an `Expression` instance is passed, it will be used as-is. + Multiple expressions are combined with an AND operator. + append (bool): if `True`, AND the new expressions to any existing expression. + Otherwise, this resets the expression. + dialect (str): the dialect used to parse the input expressions. + copy (bool): if `False`, modify this expression instance in-place. + opts (kwargs): other options to use to parse the input expressions. + + Returns: + Select: the modified expression. + """ + return _apply_conjunction_builder( + *expressions, + instance=self, + arg="where", + append=append, + into=Where, + dialect=dialect, + copy=copy, + **opts, + ) + + def having(self, *expressions, append=True, dialect=None, copy=True, **opts): + """ + Append to or set the HAVING expressions. 
+ + Example: + >>> Select().select("x", "COUNT(y)").from_("tbl").group_by("x").having("COUNT(y) > 3").sql() + 'SELECT x, COUNT(y) FROM tbl GROUP BY x HAVING COUNT(y) > 3' + + Args: + *expressions (str or Expression): the SQL code strings to parse. + If an `Expression` instance is passed, it will be used as-is. + Multiple expressions are combined with an AND operator. + append (bool): if `True`, AND the new expressions to any existing expression. + Otherwise, this resets the expression. + dialect (str): the dialect used to parse the input expressions. + copy (bool): if `False`, modify this expression instance in-place. + opts (kwargs): other options to use to parse the input expressions. + + Returns: + Select: the modified expression. + """ + return _apply_conjunction_builder( + *expressions, + instance=self, + arg="having", + append=append, + into=Having, + dialect=dialect, + copy=copy, + **opts, + ) + + def distinct(self, distinct=True, copy=True): + """ + Set the OFFSET expression. + + Example: + >>> Select().from_("tbl").select("x").distinct().sql() + 'SELECT DISTINCT x FROM tbl' + + Args: + distinct (bool): whether the Select should be distinct + copy (bool): if `False`, modify this expression instance in-place. + + Returns: + Select: the modified expression. + """ + instance = _maybe_copy(self, copy) + instance.set("distinct", Distinct() if distinct else None) + return instance + + def ctas(self, table, properties=None, dialect=None, copy=True, **opts): + """ + Convert this expression to a CREATE TABLE AS statement. + + Example: + >>> Select().select("*").from_("tbl").ctas("x").sql() + 'CREATE TABLE x AS SELECT * FROM tbl' + + Args: + table (str or Expression): the SQL code string to parse as the table name. + If another `Expression` instance is passed, it will be used as-is. + properties (dict): an optional mapping of table properties + dialect (str): the dialect used to parse the input table. + copy (bool): if `False`, modify this expression instance in-place. + opts (kwargs): other options to use to parse the input table. + + Returns: + Create: the CREATE TABLE AS expression + """ + instance = _maybe_copy(self, copy) + table_expression = maybe_parse( + table, + into=Table, + dialect=dialect, + **opts, + ) + properties_expression = None + if properties: + properties_str = " ".join( + [ + f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}" + for k, v in properties.items() + ] + ) + properties_expression = maybe_parse( + properties_str, + into=Properties, + dialect=dialect, + **opts, + ) + + return Create( + this=table_expression, + kind="table", + expression=instance, + properties=properties_expression, + ) + + @property + def named_selects(self): + return [e.alias_or_name for e in self.expressions if e.alias_or_name] + + @property + def selects(self): + return self.expressions + + +class Subquery(DerivedTable): + arg_types = { + "this": True, + "alias": False, + **QUERY_MODIFIERS, + } + + def unnest(self): + """ + Returns the first non subquery. 
+ """ + expression = self + while isinstance(expression, Subquery): + expression = expression.this + return expression + + +class TableSample(Expression): + arg_types = { + "this": False, + "method": False, + "bucket_numerator": False, + "bucket_denominator": False, + "bucket_field": False, + "percent": False, + "rows": False, + "size": False, + } + + +class Window(Expression): + arg_types = { + "this": True, + "partition_by": False, + "order": False, + "spec": False, + "alias": False, + } + + +class WindowSpec(Expression): + arg_types = { + "kind": False, + "start": False, + "start_side": False, + "end": False, + "end_side": False, + } + + +class Where(Expression): + pass + + +class Star(Expression): + arg_types = {"except": False, "replace": False} + + @property + def name(self): + return "*" + + +class Placeholder(Expression): + arg_types = {} + + +class Null(Condition): + arg_types = {} + + +class Boolean(Condition): + pass + + +class DataType(Expression): + arg_types = { + "this": True, + "expressions": False, + "nested": False, + } + + class Type(AutoName): + CHAR = auto() + NCHAR = auto() + VARCHAR = auto() + NVARCHAR = auto() + TEXT = auto() + BINARY = auto() + INT = auto() + TINYINT = auto() + SMALLINT = auto() + BIGINT = auto() + FLOAT = auto() + DOUBLE = auto() + DECIMAL = auto() + BOOLEAN = auto() + JSON = auto() + TIMESTAMP = auto() + TIMESTAMPTZ = auto() + DATE = auto() + DATETIME = auto() + ARRAY = auto() + MAP = auto() + UUID = auto() + GEOGRAPHY = auto() + STRUCT = auto() + NULLABLE = auto() + + @classmethod + def build(cls, dtype, **kwargs): + return DataType( + this=dtype + if isinstance(dtype, DataType.Type) + else DataType.Type[dtype.upper()], + **kwargs, + ) + + +class StructKwarg(Expression): + arg_types = {"this": True, "expression": True} + + +# WHERE x <OP> EXISTS|ALL|ANY|SOME(SELECT ...) 
+class SubqueryPredicate(Predicate): + pass + + +class All(SubqueryPredicate): + pass + + +class Any(SubqueryPredicate): + pass + + +class Exists(SubqueryPredicate): + pass + + +# Commands to interact with the databases or engines +# These expressions don't truly parse the expression and consume +# whatever exists as a string until the end or a semicolon +class Command(Expression): + arg_types = {"this": True, "expression": False} + + +# Binary Expressions +# (ADD a b) +# (FROM table selects) +class Binary(Expression): + arg_types = {"this": True, "expression": True} + + @property + def left(self): + return self.this + + @property + def right(self): + return self.expression + + +class Add(Binary): + pass + + +class Connector(Binary, Condition): + pass + + +class And(Connector): + pass + + +class Or(Connector): + pass + + +class BitwiseAnd(Binary): + pass + + +class BitwiseLeftShift(Binary): + pass + + +class BitwiseOr(Binary): + pass + + +class BitwiseRightShift(Binary): + pass + + +class BitwiseXor(Binary): + pass + + +class Div(Binary): + pass + + +class Dot(Binary): + pass + + +class DPipe(Binary): + pass + + +class EQ(Binary, Predicate): + pass + + +class Escape(Binary): + pass + + +class GT(Binary, Predicate): + pass + + +class GTE(Binary, Predicate): + pass + + +class ILike(Binary, Predicate): + pass + + +class IntDiv(Binary): + pass + + +class Is(Binary, Predicate): + pass + + +class Like(Binary, Predicate): + pass + + +class LT(Binary, Predicate): + pass + + +class LTE(Binary, Predicate): + pass + + +class Mod(Binary): + pass + + +class Mul(Binary): + pass + + +class NEQ(Binary, Predicate): + pass + + +class Sub(Binary): + pass + + +# Unary Expressions +# (NOT a) +class Unary(Expression): + pass + + +class BitwiseNot(Unary): + pass + + +class Not(Unary, Condition): + pass + + +class Paren(Unary, Condition): + pass + + +class Neg(Unary): + pass + + +# Special Functions +class Alias(Expression): + arg_types = {"this": True, "alias": False} + + +class Aliases(Expression): + arg_types = {"this": True, "expressions": True} + + @property + def aliases(self): + return self.expressions + + +class AtTimeZone(Expression): + arg_types = {"this": True, "zone": True} + + +class Between(Predicate): + arg_types = {"this": True, "low": True, "high": True} + + +class Bracket(Condition): + arg_types = {"this": True, "expressions": True} + + +class Distinct(Expression): + arg_types = {"this": False, "on": False} + + +class In(Predicate): + arg_types = {"this": True, "expressions": False, "query": False, "unnest": False} + + +class TimeUnit(Expression): + """Automatically converts unit arg into a var.""" + + arg_types = {"unit": False} + + def __init__(self, **args): + unit = args.get("unit") + if isinstance(unit, Column): + args["unit"] = Var(this=unit.name) + elif isinstance(unit, Week): + unit.set("this", Var(this=unit.this.name)) + super().__init__(**args) + + +class Interval(TimeUnit): + arg_types = {"this": True, "unit": False} + + +class IgnoreNulls(Expression): + pass + + +# Functions +class Func(Condition): + """ + The base class for all function expressions. + + Attributes + is_var_len_args (bool): if set to True the last argument defined in + arg_types will be treated as a variable length argument and the + argument's value will be stored as a list. + _sql_names (list): determines the SQL name (1st item in the list) and + aliases (subsequent items) for this function expression. 
These + values are used to map this node to a name during parsing as well + as to provide the function's name during SQL string generation. By + default the SQL name is set to the expression's class name transformed + to snake case. + """ + + is_var_len_args = False + + @classmethod + def from_arg_list(cls, args): + args_num = len(args) + + all_arg_keys = list(cls.arg_types) + # If this function supports variable length argument treat the last argument as such. + non_var_len_arg_keys = ( + all_arg_keys[:-1] if cls.is_var_len_args else all_arg_keys + ) + + args_dict = {} + arg_idx = 0 + for arg_key in non_var_len_arg_keys: + if arg_idx >= args_num: + break + if args[arg_idx] is not None: + args_dict[arg_key] = args[arg_idx] + arg_idx += 1 + + if arg_idx < args_num and cls.is_var_len_args: + args_dict[all_arg_keys[-1]] = args[arg_idx:] + return cls(**args_dict) + + @classmethod + def sql_names(cls): + if cls is Func: + raise NotImplementedError( + "SQL name is only supported by concrete function implementations" + ) + if not hasattr(cls, "_sql_names"): + cls._sql_names = [camel_to_snake_case(cls.__name__)] + return cls._sql_names + + @classmethod + def sql_name(cls): + return cls.sql_names()[0] + + @classmethod + def default_parser_mappings(cls): + return {name: cls.from_arg_list for name in cls.sql_names()} + + +class AggFunc(Func): + pass + + +class Abs(Func): + pass + + +class Anonymous(Func): + arg_types = {"this": True, "expressions": False} + is_var_len_args = True + + +class ApproxDistinct(AggFunc): + arg_types = {"this": True, "accuracy": False} + + +class Array(Func): + arg_types = {"expressions": False} + is_var_len_args = True + + +class ArrayAgg(AggFunc): + pass + + +class ArrayAll(Func): + arg_types = {"this": True, "expression": True} + + +class ArrayAny(Func): + arg_types = {"this": True, "expression": True} + + +class ArrayContains(Func): + arg_types = {"this": True, "expression": True} + + +class ArrayFilter(Func): + arg_types = {"this": True, "expression": True} + _sql_names = ["FILTER", "ARRAY_FILTER"] + + +class ArraySize(Func): + pass + + +class ArraySort(Func): + arg_types = {"this": True, "expression": False} + + +class ArraySum(Func): + pass + + +class ArrayUnionAgg(AggFunc): + pass + + +class Avg(AggFunc): + pass + + +class AnyValue(AggFunc): + pass + + +class Case(Func): + arg_types = {"this": False, "ifs": True, "default": False} + + +class Cast(Func): + arg_types = {"this": True, "to": True} + + +class TryCast(Cast): + pass + + +class Ceil(Func): + _sql_names = ["CEIL", "CEILING"] + + +class Coalesce(Func): + arg_types = {"this": True, "expressions": False} + is_var_len_args = True + + +class ConcatWs(Func): + arg_types = {"expressions": False} + is_var_len_args = True + + +class Count(AggFunc): + pass + + +class CurrentDate(Func): + arg_types = {"this": False} + + +class CurrentDatetime(Func): + arg_types = {"this": False} + + +class CurrentTime(Func): + arg_types = {"this": False} + + +class CurrentTimestamp(Func): + arg_types = {"this": False} + + +class DateAdd(Func, TimeUnit): + arg_types = {"this": True, "expression": True, "unit": False} + + +class DateSub(Func, TimeUnit): + arg_types = {"this": True, "expression": True, "unit": False} + + +class DateDiff(Func, TimeUnit): + arg_types = {"this": True, "expression": True, "unit": False} + + +class DateTrunc(Func, TimeUnit): + arg_types = {"this": True, "unit": True, "zone": False} + + +class DatetimeAdd(Func, TimeUnit): + arg_types = {"this": True, "expression": True, "unit": False} + + +class 
DatetimeSub(Func, TimeUnit): + arg_types = {"this": True, "expression": True, "unit": False} + + +class DatetimeDiff(Func, TimeUnit): + arg_types = {"this": True, "expression": True, "unit": False} + + +class DatetimeTrunc(Func, TimeUnit): + arg_types = {"this": True, "unit": True, "zone": False} + + +class Extract(Func): + arg_types = {"this": True, "expression": True} + + +class TimestampAdd(Func, TimeUnit): + arg_types = {"this": True, "expression": True, "unit": False} + + +class TimestampSub(Func, TimeUnit): + arg_types = {"this": True, "expression": True, "unit": False} + + +class TimestampDiff(Func, TimeUnit): + arg_types = {"this": True, "expression": True, "unit": False} + + +class TimestampTrunc(Func, TimeUnit): + arg_types = {"this": True, "unit": True, "zone": False} + + +class TimeAdd(Func, TimeUnit): + arg_types = {"this": True, "expression": True, "unit": False} + + +class TimeSub(Func, TimeUnit): + arg_types = {"this": True, "expression": True, "unit": False} + + +class TimeDiff(Func, TimeUnit): + arg_types = {"this": True, "expression": True, "unit": False} + + +class TimeTrunc(Func, TimeUnit): + arg_types = {"this": True, "unit": True, "zone": False} + + +class DateStrToDate(Func): + pass + + +class DateToDateStr(Func): + pass + + +class DateToDi(Func): + pass + + +class Day(Func): + pass + + +class DiToDate(Func): + pass + + +class Exp(Func): + pass + + +class Explode(Func): + pass + + +class Floor(Func): + pass + + +class Greatest(Func): + arg_types = {"this": True, "expressions": True} + is_var_len_args = True + + +class If(Func): + arg_types = {"this": True, "true": True, "false": False} + + +class IfNull(Func): + arg_types = {"this": True, "expression": False} + _sql_names = ["IFNULL", "NVL"] + + +class Initcap(Func): + pass + + +class JSONExtract(Func): + arg_types = {"this": True, "path": True} + _sql_names = ["JSON_EXTRACT"] + + +class JSONExtractScalar(JSONExtract): + _sql_names = ["JSON_EXTRACT_SCALAR"] + + +class JSONBExtract(JSONExtract): + _sql_names = ["JSONB_EXTRACT"] + + +class JSONBExtractScalar(JSONExtract): + _sql_names = ["JSONB_EXTRACT_SCALAR"] + + +class Least(Func): + arg_types = {"this": True, "expressions": True} + is_var_len_args = True + + +class Length(Func): + pass + + +class Levenshtein(Func): + arg_types = {"this": True, "expression": False} + + +class Ln(Func): + pass + + +class Log(Func): + arg_types = {"this": True, "expression": False} + + +class Log2(Func): + pass + + +class Log10(Func): + pass + + +class Lower(Func): + pass + + +class Map(Func): + arg_types = {"keys": True, "values": True} + + +class Max(AggFunc): + pass + + +class Min(AggFunc): + pass + + +class Month(Func): + pass + + +class Nvl2(Func): + arg_types = {"this": True, "true": True, "false": False} + + +class Posexplode(Func): + pass + + +class Pow(Func): + arg_types = {"this": True, "power": True} + _sql_names = ["POWER", "POW"] + + +class Quantile(AggFunc): + arg_types = {"this": True, "quantile": True} + + +class Reduce(Func): + arg_types = {"this": True, "initial": True, "merge": True, "finish": True} + + +class RegexpLike(Func): + arg_types = {"this": True, "expression": True} + + +class RegexpSplit(Func): + arg_types = {"this": True, "expression": True} + + +class Round(Func): + arg_types = {"this": True, "decimals": False} + + +class SafeDivide(Func): + arg_types = {"this": True, "expression": True} + + +class SetAgg(AggFunc): + pass + + +class SortArray(Func): + arg_types = {"this": True, "asc": False} + + +class Split(Func): + arg_types = {"this": True, 
"expression": True} + + +class Substring(Func): + arg_types = {"this": True, "start": True, "length": False} + + +class StrPosition(Func): + arg_types = {"this": True, "substr": True, "position": False} + + +class StrToDate(Func): + arg_types = {"this": True, "format": True} + + +class StrToTime(Func): + arg_types = {"this": True, "format": True} + + +class StrToUnix(Func): + arg_types = {"this": True, "format": True} + + +class Struct(Func): + arg_types = {"expressions": True} + is_var_len_args = True + + +class StructExtract(Func): + arg_types = {"this": True, "expression": True} + + +class Sum(AggFunc): + pass + + +class Sqrt(Func): + pass + + +class Stddev(AggFunc): + pass + + +class StddevPop(AggFunc): + pass + + +class StddevSamp(AggFunc): + pass + + +class TimeToStr(Func): + arg_types = {"this": True, "format": True} + + +class TimeToTimeStr(Func): + pass + + +class TimeToUnix(Func): + pass + + +class TimeStrToDate(Func): + pass + + +class TimeStrToTime(Func): + pass + + +class TimeStrToUnix(Func): + pass + + +class TsOrDsAdd(Func, TimeUnit): + arg_types = {"this": True, "expression": True, "unit": False} + + +class TsOrDsToDateStr(Func): + pass + + +class TsOrDsToDate(Func): + arg_types = {"this": True, "format": False} + + +class TsOrDiToDi(Func): + pass + + +class UnixToStr(Func): + arg_types = {"this": True, "format": True} + + +class UnixToTime(Func): + arg_types = {"this": True, "scale": False} + + SECONDS = Literal.string("seconds") + MILLIS = Literal.string("millis") + MICROS = Literal.string("micros") + + +class UnixToTimeStr(Func): + pass + + +class Upper(Func): + pass + + +class Variance(AggFunc): + _sql_names = ["VARIANCE", "VARIANCE_SAMP", "VAR_SAMP"] + + +class VariancePop(AggFunc): + _sql_names = ["VARIANCE_POP", "VAR_POP"] + + +class Week(Func): + arg_types = {"this": True, "mode": False} + + +class Year(Func): + pass + + +def _norm_args(expression): + args = {} + + for k, arg in expression.args.items(): + if isinstance(arg, list): + arg = [_norm_arg(a) for a in arg] + else: + arg = _norm_arg(arg) + + if arg is not None: + args[k] = arg + + return args + + +def _norm_arg(arg): + return arg.lower() if isinstance(arg, str) else arg + + +def _all_functions(): + return [ + obj + for _, obj in inspect.getmembers( + sys.modules[__name__], + lambda obj: inspect.isclass(obj) + and issubclass(obj, Func) + and obj not in (AggFunc, Anonymous, Func), + ) + ] + + +ALL_FUNCTIONS = _all_functions() + + +def maybe_parse( + sql_or_expression, + *, + into=None, + dialect=None, + prefix=None, + **opts, +): + """Gracefully handle a possible string or expression. + + Example: + >>> maybe_parse("1") + (LITERAL this: 1, is_string: False) + >>> maybe_parse(to_identifier("x")) + (IDENTIFIER this: x, quoted: False) + + Args: + sql_or_expression (str or Expression): the SQL code string or an expression + into (Expression): the SQLGlot Expression to parse into + dialect (str): the dialect used to parse the input expressions (in the case that an + input expression is a SQL string). + prefix (str): a string to prefix the sql with before it gets parsed + (automatically includes a space) + **opts: other options to use to parse the input expressions (again, in the case + that an input expression is a SQL string). + + Returns: + Expression: the parsed or given expression. 
+ """ + if isinstance(sql_or_expression, Expression): + return sql_or_expression + + import sqlglot + + sql = str(sql_or_expression) + if prefix: + sql = f"{prefix} {sql}" + return sqlglot.parse_one(sql, read=dialect, into=into, **opts) + + +def _maybe_copy(instance, copy=True): + return instance.copy() if copy else instance + + +def _is_wrong_expression(expression, into): + return isinstance(expression, Expression) and not isinstance(expression, into) + + +def _apply_builder( + expression, + instance, + arg, + copy=True, + prefix=None, + into=None, + dialect=None, + **opts, +): + if _is_wrong_expression(expression, into): + expression = into(this=expression) + instance = _maybe_copy(instance, copy) + expression = maybe_parse( + sql_or_expression=expression, + prefix=prefix, + into=into, + dialect=dialect, + **opts, + ) + instance.set(arg, expression) + return instance + + +def _apply_child_list_builder( + *expressions, + instance, + arg, + append=True, + copy=True, + prefix=None, + into=None, + dialect=None, + properties=None, + **opts, +): + instance = _maybe_copy(instance, copy) + parsed = [] + for expression in expressions: + if _is_wrong_expression(expression, into): + expression = into(expressions=[expression]) + expression = maybe_parse( + expression, + into=into, + dialect=dialect, + prefix=prefix, + **opts, + ) + parsed.extend(expression.expressions) + + existing = instance.args.get(arg) + if append and existing: + parsed = existing.expressions + parsed + + child = into(expressions=parsed) + for k, v in (properties or {}).items(): + child.set(k, v) + instance.set(arg, child) + return instance + + +def _apply_list_builder( + *expressions, + instance, + arg, + append=True, + copy=True, + prefix=None, + into=None, + dialect=None, + **opts, +): + inst = _maybe_copy(instance, copy) + + expressions = [ + maybe_parse( + sql_or_expression=expression, + into=into, + prefix=prefix, + dialect=dialect, + **opts, + ) + for expression in expressions + ] + + existing_expressions = inst.args.get(arg) + if append and existing_expressions: + expressions = existing_expressions + expressions + + inst.set(arg, expressions) + return inst + + +def _apply_conjunction_builder( + *expressions, + instance, + arg, + into=None, + append=True, + copy=True, + dialect=None, + **opts, +): + expressions = [exp for exp in expressions if exp is not None and exp != ""] + if not expressions: + return instance + + inst = _maybe_copy(instance, copy) + + existing = inst.args.get(arg) + if append and existing is not None: + expressions = [existing.this if into else existing] + list(expressions) + + node = and_(*expressions, dialect=dialect, **opts) + + inst.set(arg, into(this=node) if into else node) + return inst + + +def _combine(expressions, operator, dialect=None, **opts): + expressions = [ + condition(expression, dialect=dialect, **opts) for expression in expressions + ] + this = expressions[0] + if expressions[1:]: + this = _wrap_operator(this) + for expression in expressions[1:]: + this = operator(this=this, expression=_wrap_operator(expression)) + return this + + +def _wrap_operator(expression): + if isinstance(expression, (And, Or, Not)): + expression = Paren(this=expression) + return expression + + +def select(*expressions, dialect=None, **opts): + """ + Initializes a syntax tree from one or multiple SELECT expressions. 
+ + Example: + >>> select("col1", "col2").from_("tbl").sql() + 'SELECT col1, col2 FROM tbl' + + Args: + *expressions (str or Expression): the SQL code string to parse as the expressions of a + SELECT statement. If an Expression instance is passed, this is used as-is. + dialect (str): the dialect used to parse the input expressions (in the case that an + input expression is a SQL string). + **opts: other options to use to parse the input expressions (again, in the case + that an input expression is a SQL string). + + Returns: + Select: the syntax tree for the SELECT statement. + """ + return Select().select(*expressions, dialect=dialect, **opts) + + +def from_(*expressions, dialect=None, **opts): + """ + Initializes a syntax tree from a FROM expression. + + Example: + >>> from_("tbl").select("col1", "col2").sql() + 'SELECT col1, col2 FROM tbl' + + Args: + *expressions (str or Expression): the SQL code string to parse as the FROM expressions of a + SELECT statement. If an Expression instance is passed, this is used as-is. + dialect (str): the dialect used to parse the input expression (in the case that the + input expression is a SQL string). + **opts: other options to use to parse the input expressions (again, in the case + that the input expression is a SQL string). + + Returns: + Select: the syntax tree for the SELECT statement. + """ + return Select().from_(*expressions, dialect=dialect, **opts) + + +def condition(expression, dialect=None, **opts): + """ + Initialize a logical condition expression. + + Example: + >>> condition("x=1").sql() + 'x = 1' + + This is helpful for composing larger logical syntax trees: + >>> where = condition("x=1") + >>> where = where.and_("y=1") + >>> Select().from_("tbl").select("*").where(where).sql() + 'SELECT * FROM tbl WHERE x = 1 AND y = 1' + + Args: + *expression (str or Expression): the SQL code string to parse. + If an Expression instance is passed, this is used as-is. + dialect (str): the dialect used to parse the input expression (in the case that the + input expression is a SQL string). + **opts: other options to use to parse the input expressions (again, in the case + that the input expression is a SQL string). + + Returns: + Condition: the expression + """ + return maybe_parse( + expression, + into=Condition, + dialect=dialect, + **opts, + ) + + +def and_(*expressions, dialect=None, **opts): + """ + Combine multiple conditions with an AND logical operator. + + Example: + >>> and_("x=1", and_("y=1", "z=1")).sql() + 'x = 1 AND (y = 1 AND z = 1)' + + Args: + *expressions (str or Expression): the SQL code strings to parse. + If an Expression instance is passed, this is used as-is. + dialect (str): the dialect used to parse the input expression. + **opts: other options to use to parse the input expressions. + + Returns: + And: the new condition + """ + return _combine(expressions, And, dialect, **opts) + + +def or_(*expressions, dialect=None, **opts): + """ + Combine multiple conditions with an OR logical operator. + + Example: + >>> or_("x=1", or_("y=1", "z=1")).sql() + 'x = 1 OR (y = 1 OR z = 1)' + + Args: + *expressions (str or Expression): the SQL code strings to parse. + If an Expression instance is passed, this is used as-is. + dialect (str): the dialect used to parse the input expression. + **opts: other options to use to parse the input expressions. + + Returns: + Or: the new condition + """ + return _combine(expressions, Or, dialect, **opts) + + +def not_(expression, dialect=None, **opts): + """ + Wrap a condition with a NOT operator. 
+ + Example: + >>> not_("this_suit='black'").sql() + "NOT this_suit = 'black'" + + Args: + expression (str or Expression): the SQL code strings to parse. + If an Expression instance is passed, this is used as-is. + dialect (str): the dialect used to parse the input expression. + **opts: other options to use to parse the input expressions. + + Returns: + Not: the new condition + """ + this = condition( + expression, + dialect=dialect, + **opts, + ) + return Not(this=_wrap_operator(this)) + + +def paren(expression): + return Paren(this=expression) + + +SAFE_IDENTIFIER_RE = re.compile(r"^[a-zA-Z][\w]*$") + + +def to_identifier(alias, quoted=None): + if alias is None: + return None + if isinstance(alias, Identifier): + identifier = alias + elif isinstance(alias, str): + if quoted is None: + quoted = not re.match(SAFE_IDENTIFIER_RE, alias) + identifier = Identifier(this=alias, quoted=quoted) + else: + raise ValueError( + f"Alias needs to be a string or an Identifier, got: {alias.__class__}" + ) + return identifier + + +def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts): + """ + Create an Alias expression. + Expample: + >>> alias_('foo', 'bar').sql() + 'foo AS bar' + + Args: + expression (str or Expression): the SQL code strings to parse. + If an Expression instance is passed, this is used as-is. + alias (str or Identifier): the alias name to use. If the name has + special characters it is quoted. + table (boolean): create a table alias, default false + dialect (str): the dialect used to parse the input expression. + **opts: other options to use to parse the input expressions. + + Returns: + Alias: the aliased expression + """ + exp = maybe_parse(expression, dialect=dialect, **opts) + alias = to_identifier(alias, quoted=quoted) + alias = TableAlias(this=alias) if table else alias + + if "alias" in exp.arg_types: + exp = exp.copy() + exp.set("alias", alias) + return exp + return Alias(this=exp, alias=alias) + + +def subquery(expression, alias=None, dialect=None, **opts): + """ + Build a subquery expression. + Expample: + >>> subquery('select x from tbl', 'bar').select('x').sql() + 'SELECT x FROM (SELECT x FROM tbl) AS bar' + + Args: + expression (str or Expression): the SQL code strings to parse. + If an Expression instance is passed, this is used as-is. + alias (str or Expression): the alias name to use. + dialect (str): the dialect used to parse the input expression. + **opts: other options to use to parse the input expressions. + + Returns: + Select: a new select with the subquery expression included + """ + + expression = maybe_parse(expression, dialect=dialect, **opts).subquery(alias) + return Select().from_(expression, dialect=dialect, **opts) + + +def column(col, table=None, quoted=None): + """ + Build a Column. + Args: + col (str or Expression): column name + table (str or Expression): table name + Returns: + Column: column instance + """ + return Column( + this=to_identifier(col, quoted=quoted), + table=to_identifier(table, quoted=quoted), + ) + + +def table_(table, db=None, catalog=None, quoted=None): + """ + Build a Table. 
+ Args: + table (str or Expression): table name + db (str or Expression): db name + catalog (str or Expression): catalog name + Returns: + Table: table instance + """ + return Table( + this=to_identifier(table, quoted=quoted), + db=to_identifier(db, quoted=quoted), + catalog=to_identifier(catalog, quoted=quoted), + ) + + +def replace_children(expression, fun): + """ + Replace children of an expression with the result of a lambda fun(child) -> exp. + """ + for k, v in expression.args.items(): + is_list_arg = isinstance(v, list) + + child_nodes = v if is_list_arg else [v] + new_child_nodes = [] + + for cn in child_nodes: + if isinstance(cn, Expression): + cns = ensure_list(fun(cn)) + for child_node in cns: + new_child_nodes.append(child_node) + child_node.parent = expression + child_node.arg_key = k + else: + new_child_nodes.append(cn) + + expression.args[k] = new_child_nodes if is_list_arg else new_child_nodes[0] + + +def column_table_names(expression): + """ + Return all table names referenced through columns in an expression. + + Example: + >>> import sqlglot + >>> column_table_names(sqlglot.parse_one("a.b AND c.d AND c.e")) + ['c', 'a'] + + Args: + expression (sqlglot.Expression): expression to find table names + + Returns: + list: A list of unique names + """ + return list(dict.fromkeys(column.table for column in expression.find_all(Column))) + + +TRUE = Boolean(this=True) +FALSE = Boolean(this=False) +NULL = Null() diff --git a/sqlglot/generator.py b/sqlglot/generator.py new file mode 100644 index 0000000..793cff0 --- /dev/null +++ b/sqlglot/generator.py @@ -0,0 +1,1124 @@ +import logging + +from sqlglot import exp +from sqlglot.errors import ErrorLevel, UnsupportedError, concat_errors +from sqlglot.helper import apply_index_offset, csv, ensure_list +from sqlglot.time import format_time +from sqlglot.tokens import TokenType + +logger = logging.getLogger("sqlglot") + + +class Generator: + """ + Generator interprets the given syntax tree and produces a SQL string as an output. + + Args + time_mapping (dict): the dictionary of custom time mappings in which the key + represents a python time format and the output the target time format + time_trie (trie): a trie of the time_mapping keys + pretty (bool): if set to True the returned string will be formatted. Default: False. + quote_start (str): specifies which starting character to use to delimit quotes. Default: '. + quote_end (str): specifies which ending character to use to delimit quotes. Default: '. + identifier_start (str): specifies which starting character to use to delimit identifiers. Default: ". + identifier_end (str): specifies which ending character to use to delimit identifiers. Default: ". + identify (bool): if set to True all identifiers will be delimited by the corresponding + character. + normalize (bool): if set to True all identifiers will be lower cased + escape (str): specifies an escape character. Default: '. + pad (int): determines padding in a formatted string. Default: 2. + indent (int): determines the size of indentation in a formatted string. Default: 2. + unnest_column_only (bool): if true unnest table aliases are considered only as column aliases + normalize_functions (str): normalize function names, "upper", "lower", or None + Default: "upper" + alias_post_tablesample (bool): if the table alias comes after tablesample + Default: False + unsupported_level (ErrorLevel): determines the generator's behavior when it encounters + unsupported expressions. Default ErrorLevel.WARN.
+ null_ordering (str): Indicates the default null ordering method to use if not explicitly set. + Options are "nulls_are_small", "nulls_are_large", "nulls_are_last". + Default: "nulls_are_small" + max_unsupported (int): Maximum number of unsupported messages to include in a raised UnsupportedError. + This is only relevant if unsupported_level is ErrorLevel.RAISE. + Default: 3 + """ + + TRANSFORMS = { + exp.AnonymousProperty: lambda self, e: self.property_sql(e), + exp.AutoIncrementProperty: lambda self, e: f"AUTO_INCREMENT={self.sql(e, 'value')}", + exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'value')}", + exp.CollateProperty: lambda self, e: f"COLLATE={self.sql(e, 'value')}", + exp.DateAdd: lambda self, e: f"DATE_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})", + exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')})", + exp.EngineProperty: lambda self, e: f"ENGINE={self.sql(e, 'value')}", + exp.FileFormatProperty: lambda self, e: f"FORMAT={self.sql(e, 'value')}", + exp.LocationProperty: lambda self, e: f"LOCATION {self.sql(e, 'value')}", + exp.PartitionedByProperty: lambda self, e: f"PARTITIONED_BY={self.sql(e.args['value'])}", + exp.SchemaCommentProperty: lambda self, e: f"COMMENT={self.sql(e, 'value')}", + exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT={self.sql(e, 'value')}", + exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e, 'unit')})", + } + + NULL_ORDERING_SUPPORTED = True + + TYPE_MAPPING = { + exp.DataType.Type.NCHAR: "CHAR", + exp.DataType.Type.NVARCHAR: "VARCHAR", + } + + TOKEN_MAPPING = {} + + STRUCT_DELIMITER = ("<", ">") + + ROOT_PROPERTIES = [ + exp.AutoIncrementProperty, + exp.CharacterSetProperty, + exp.CollateProperty, + exp.EngineProperty, + exp.SchemaCommentProperty, + ] + WITH_PROPERTIES = [ + exp.AnonymousProperty, + exp.FileFormatProperty, + exp.PartitionedByProperty, + exp.TableFormatProperty, + ] + + __slots__ = ( + "time_mapping", + "time_trie", + "pretty", + "configured_pretty", + "quote_start", + "quote_end", + "identifier_start", + "identifier_end", + "identify", + "normalize", + "escape", + "pad", + "index_offset", + "unnest_column_only", + "alias_post_tablesample", + "normalize_functions", + "unsupported_level", + "unsupported_messages", + "null_ordering", + "max_unsupported", + "_indent", + "_replace_backslash", + "_escaped_quote_end", + ) + + def __init__( + self, + time_mapping=None, + time_trie=None, + pretty=None, + quote_start=None, + quote_end=None, + identifier_start=None, + identifier_end=None, + identify=False, + normalize=False, + escape=None, + pad=2, + indent=2, + index_offset=0, + unnest_column_only=False, + alias_post_tablesample=False, + normalize_functions="upper", + unsupported_level=ErrorLevel.WARN, + null_ordering=None, + max_unsupported=3, + ): + import sqlglot + + self.time_mapping = time_mapping or {} + self.time_trie = time_trie + self.pretty = pretty if pretty is not None else sqlglot.pretty + self.configured_pretty = self.pretty + self.quote_start = quote_start or "'" + self.quote_end = quote_end or "'" + self.identifier_start = identifier_start or '"' + self.identifier_end = identifier_end or '"' + self.identify = identify + self.normalize = normalize + self.escape = escape or "'" + self.pad = pad + self.index_offset = index_offset + self.unnest_column_only = unnest_column_only + self.alias_post_tablesample = 
alias_post_tablesample + self.normalize_functions = normalize_functions + self.unsupported_level = unsupported_level + self.unsupported_messages = [] + self.max_unsupported = max_unsupported + self.null_ordering = null_ordering + self._indent = indent + self._replace_backslash = self.escape == "\\" + self._escaped_quote_end = self.escape + self.quote_end + + def generate(self, expression): + """ + Generates a SQL string by interpreting the given syntax tree. + + Args + expression (Expression): the syntax tree. + + Returns + the SQL string. + """ + self.unsupported_messages = [] + sql = self.sql(expression).strip() + + if self.unsupported_level == ErrorLevel.IGNORE: + return sql + + if self.unsupported_level == ErrorLevel.WARN: + for msg in self.unsupported_messages: + logger.warning(msg) + elif self.unsupported_level == ErrorLevel.RAISE and self.unsupported_messages: + raise UnsupportedError( + concat_errors(self.unsupported_messages, self.max_unsupported) + ) + + return sql + + def unsupported(self, message): + if self.unsupported_level == ErrorLevel.IMMEDIATE: + raise UnsupportedError(message) + self.unsupported_messages.append(message) + + def sep(self, sep=" "): + return f"{sep.strip()}\n" if self.pretty else sep + + def seg(self, sql, sep=" "): + return f"{self.sep(sep)}{sql}" + + def wrap(self, expression): + this_sql = self.indent( + self.sql(expression) + if isinstance(expression, (exp.Select, exp.Union)) + else self.sql(expression, "this"), + level=1, + pad=0, + ) + return f"({self.sep('')}{this_sql}{self.seg(')', sep='')}" + + def no_identify(self, func): + original = self.identify + self.identify = False + result = func() + self.identify = original + return result + + def normalize_func(self, name): + if self.normalize_functions == "upper": + return name.upper() + if self.normalize_functions == "lower": + return name.lower() + return name + + def indent(self, sql, level=0, pad=None, skip_first=False, skip_last=False): + if not self.pretty: + return sql + + pad = self.pad if pad is None else pad + lines = sql.split("\n") + + return "\n".join( + line + if (skip_first and i == 0) or (skip_last and i == len(lines) - 1) + else f"{' ' * (level * self._indent + pad)}{line}" + for i, line in enumerate(lines) + ) + + def sql(self, expression, key=None): + if not expression: + return "" + + if isinstance(expression, str): + return expression + + if key: + return self.sql(expression.args.get(key)) + + transform = self.TRANSFORMS.get(expression.__class__) + + if callable(transform): + return transform(self, expression) + if transform: + return transform + + if not isinstance(expression, exp.Expression): + raise ValueError( + f"Expected an Expression. 
Received {type(expression)}: {expression}" + ) + + exp_handler_name = f"{expression.key}_sql" + if hasattr(self, exp_handler_name): + return getattr(self, exp_handler_name)(expression) + + if isinstance(expression, exp.Func): + return self.function_fallback_sql(expression) + + raise ValueError(f"Unsupported expression type {expression.__class__.__name__}") + + def annotation_sql(self, expression): + return self.sql(expression, "expression") + + def uncache_sql(self, expression): + table = self.sql(expression, "this") + exists_sql = " IF EXISTS" if expression.args.get("exists") else "" + return f"UNCACHE TABLE{exists_sql} {table}" + + def cache_sql(self, expression): + lazy = " LAZY" if expression.args.get("lazy") else "" + table = self.sql(expression, "this") + options = expression.args.get("options") + options = ( + f" OPTIONS({self.sql(options[0])} = {self.sql(options[1])})" + if options + else "" + ) + sql = self.sql(expression, "expression") + sql = f" AS{self.sep()}{sql}" if sql else "" + sql = f"CACHE{lazy} TABLE {table}{options}{sql}" + return self.prepend_ctes(expression, sql) + + def characterset_sql(self, expression): + if isinstance(expression.parent, exp.Cast): + return f"CHAR CHARACTER SET {self.sql(expression, 'this')}" + default = "DEFAULT " if expression.args.get("default") else "" + return f"{default}CHARACTER SET={self.sql(expression, 'this')}" + + def column_sql(self, expression): + return ".".join( + part + for part in [ + self.sql(expression, "db"), + self.sql(expression, "table"), + self.sql(expression, "this"), + ] + if part + ) + + def columndef_sql(self, expression): + column = self.sql(expression, "this") + kind = self.sql(expression, "kind") + constraints = self.expressions( + expression, key="constraints", sep=" ", flat=True + ) + + if not constraints: + return f"{column} {kind}" + return f"{column} {kind} {constraints}" + + def columnconstraint_sql(self, expression): + this = self.sql(expression, "this") + kind_sql = self.sql(expression, "kind") + return f"CONSTRAINT {this} {kind_sql}" if this else kind_sql + + def autoincrementcolumnconstraint_sql(self, _): + return self.token_sql(TokenType.AUTO_INCREMENT) + + def checkcolumnconstraint_sql(self, expression): + this = self.sql(expression, "this") + return f"CHECK ({this})" + + def commentcolumnconstraint_sql(self, expression): + comment = self.sql(expression, "this") + return f"COMMENT {comment}" + + def collatecolumnconstraint_sql(self, expression): + collate = self.sql(expression, "this") + return f"COLLATE {collate}" + + def defaultcolumnconstraint_sql(self, expression): + default = self.sql(expression, "this") + return f"DEFAULT {default}" + + def notnullcolumnconstraint_sql(self, _): + return "NOT NULL" + + def primarykeycolumnconstraint_sql(self, _): + return "PRIMARY KEY" + + def uniquecolumnconstraint_sql(self, _): + return "UNIQUE" + + def create_sql(self, expression): + this = self.sql(expression, "this") + kind = self.sql(expression, "kind").upper() + expression_sql = self.sql(expression, "expression") + expression_sql = f"AS{self.sep()}{expression_sql}" if expression_sql else "" + temporary = " TEMPORARY" if expression.args.get("temporary") else "" + replace = " OR REPLACE" if expression.args.get("replace") else "" + exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else "" + unique = " UNIQUE" if expression.args.get("unique") else "" + properties = self.sql(expression, "properties") + + expression_sql = f"CREATE{replace}{temporary}{unique} {kind}{exists_sql} {this}{properties} 
{expression_sql}" + return self.prepend_ctes(expression, expression_sql) + + def prepend_ctes(self, expression, sql): + with_ = self.sql(expression, "with") + if with_: + sql = f"{with_}{self.sep()}{sql}" + return sql + + def with_sql(self, expression): + sql = self.expressions(expression, flat=True) + recursive = "RECURSIVE " if expression.args.get("recursive") else "" + + return f"WITH {recursive}{sql}" + + def cte_sql(self, expression): + alias = self.sql(expression, "alias") + return f"{alias} AS {self.wrap(expression)}" + + def tablealias_sql(self, expression): + alias = self.sql(expression, "this") + columns = self.expressions(expression, key="columns", flat=True) + columns = f"({columns})" if columns else "" + return f"{alias}{columns}" + + def bitstring_sql(self, expression): + return f"b'{self.sql(expression, 'this')}'" + + def datatype_sql(self, expression): + type_value = expression.this + type_sql = self.TYPE_MAPPING.get(type_value, type_value.value) + nested = "" + interior = self.expressions(expression, flat=True) + if interior: + nested = ( + f"{self.STRUCT_DELIMITER[0]}{interior}{self.STRUCT_DELIMITER[1]}" + if expression.args.get("nested") + else f"({interior})" + ) + return f"{type_sql}{nested}" + + def delete_sql(self, expression): + this = self.sql(expression, "this") + where_sql = self.sql(expression, "where") + sql = f"DELETE FROM {this}{where_sql}" + return self.prepend_ctes(expression, sql) + + def drop_sql(self, expression): + this = self.sql(expression, "this") + kind = expression.args["kind"] + exists_sql = " IF EXISTS " if expression.args.get("exists") else " " + return f"DROP {kind}{exists_sql}{this}" + + def except_sql(self, expression): + return self.prepend_ctes( + expression, + self.set_operation(expression, self.except_op(expression)), + ) + + def except_op(self, expression): + return f"EXCEPT{'' if expression.args.get('distinct') else ' ALL'}" + + def fetch_sql(self, expression): + direction = expression.args.get("direction") + direction = f" {direction.upper()}" if direction else "" + count = expression.args.get("count") + count = f" {count}" if count else "" + return f"{self.seg('FETCH')}{direction}{count} ROWS ONLY" + + def filter_sql(self, expression): + this = self.sql(expression, "this") + where = self.sql(expression, "expression")[1:] # where has a leading space + return f"{this} FILTER({where})" + + def hint_sql(self, expression): + if self.sql(expression, "this"): + self.unsupported("Hints are not supported") + return "" + + def index_sql(self, expression): + this = self.sql(expression, "this") + table = self.sql(expression, "table") + columns = self.sql(expression, "columns") + return f"{this} ON {table} {columns}" + + def identifier_sql(self, expression): + value = expression.name + value = value.lower() if self.normalize else value + if expression.args.get("quoted") or self.identify: + return f"{self.identifier_start}{value}{self.identifier_end}" + return value + + def partition_sql(self, expression): + keys = csv( + *[ + f"{k.args['this']}='{v.args['this']}'" if v else k.args["this"] + for k, v in expression.args.get("this") + ] + ) + return f"PARTITION({keys})" + + def properties_sql(self, expression): + root_properties = [] + with_properties = [] + + for p in expression.expressions: + p_class = p.__class__ + if p_class in self.ROOT_PROPERTIES: + root_properties.append(p) + elif p_class in self.WITH_PROPERTIES: + with_properties.append(p) + + return self.root_properties( + exp.Properties(expressions=root_properties) + ) + 
self.with_properties(exp.Properties(expressions=with_properties)) + + def root_properties(self, properties): + if properties.expressions: + return self.sep() + self.expressions( + properties, + indent=False, + sep=" ", + ) + return "" + + def properties(self, properties, prefix="", sep=", "): + if properties.expressions: + expressions = self.expressions( + properties, + sep=sep, + indent=False, + ) + return f"{self.seg(prefix)}{' ' if prefix else ''}{self.wrap(expressions)}" + return "" + + def with_properties(self, properties): + return self.properties( + properties, + prefix="WITH", + ) + + def property_sql(self, expression): + key = expression.name + value = self.sql(expression, "value") + return f"{key} = {value}" + + def insert_sql(self, expression): + kind = "OVERWRITE TABLE" if expression.args.get("overwrite") else "INTO" + this = self.sql(expression, "this") + exists = " IF EXISTS " if expression.args.get("exists") else " " + partition_sql = ( + self.sql(expression, "partition") + if expression.args.get("partition") + else "" + ) + expression_sql = self.sql(expression, "expression") + sep = self.sep() if partition_sql else "" + sql = f"INSERT {kind} {this}{exists}{partition_sql}{sep}{expression_sql}" + return self.prepend_ctes(expression, sql) + + def intersect_sql(self, expression): + return self.prepend_ctes( + expression, + self.set_operation(expression, self.intersect_op(expression)), + ) + + def intersect_op(self, expression): + return f"INTERSECT{'' if expression.args.get('distinct') else ' ALL'}" + + def introducer_sql(self, expression): + return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}" + + def table_sql(self, expression): + return ".".join( + part + for part in [ + self.sql(expression, "catalog"), + self.sql(expression, "db"), + self.sql(expression, "this"), + ] + if part + ) + + def tablesample_sql(self, expression): + if self.alias_post_tablesample and isinstance(expression.this, exp.Alias): + this = self.sql(expression.this, "this") + alias = f" AS {self.sql(expression.this, 'alias')}" + else: + this = self.sql(expression, "this") + alias = "" + method = self.sql(expression, "method") + method = f" {method.upper()} " if method else "" + numerator = self.sql(expression, "bucket_numerator") + denominator = self.sql(expression, "bucket_denominator") + field = self.sql(expression, "bucket_field") + field = f" ON {field}" if field else "" + bucket = f"BUCKET {numerator} OUT OF {denominator}{field}" if numerator else "" + percent = self.sql(expression, "percent") + percent = f"{percent} PERCENT" if percent else "" + rows = self.sql(expression, "rows") + rows = f"{rows} ROWS" if rows else "" + size = self.sql(expression, "size") + return f"{this} TABLESAMPLE{method}({bucket}{percent}{rows}{size}){alias}" + + def tuple_sql(self, expression): + return f"({self.expressions(expression, flat=True)})" + + def update_sql(self, expression): + this = self.sql(expression, "this") + set_sql = self.expressions(expression, flat=True) + from_sql = self.sql(expression, "from") + where_sql = self.sql(expression, "where") + sql = f"UPDATE {this} SET {set_sql}{from_sql}{where_sql}" + return self.prepend_ctes(expression, sql) + + def values_sql(self, expression): + return f"VALUES{self.seg('')}{self.expressions(expression)}" + + def var_sql(self, expression): + return self.sql(expression, "this") + + def from_sql(self, expression): + expressions = self.expressions(expression, flat=True) + return f"{self.seg('FROM')} {expressions}" + + def group_sql(self, expression): 
+ group_by = self.op_expressions("GROUP BY", expression) + grouping_sets = self.expressions(expression, key="grouping_sets", indent=False) + grouping_sets = ( + f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" + if grouping_sets + else "" + ) + cube = self.expressions(expression, key="cube", indent=False) + cube = f"{self.seg('CUBE')} {self.wrap(cube)}" if cube else "" + rollup = self.expressions(expression, key="rollup", indent=False) + rollup = f"{self.seg('ROLLUP')} {self.wrap(rollup)}" if rollup else "" + return f"{group_by}{grouping_sets}{cube}{rollup}" + + def having_sql(self, expression): + this = self.indent(self.sql(expression, "this")) + return f"{self.seg('HAVING')}{self.sep()}{this}" + + def join_sql(self, expression): + op_sql = self.seg( + " ".join(op for op in (expression.side, expression.kind, "JOIN") if op) + ) + on_sql = self.sql(expression, "on") + using = expression.args.get("using") + + if not on_sql and using: + on_sql = csv(*(self.sql(column) for column in using)) + + if on_sql: + on_sql = self.indent(on_sql, skip_first=True) + space = self.seg(" " * self.pad) if self.pretty else " " + if using: + on_sql = f"{space}USING ({on_sql})" + else: + on_sql = f"{space}ON {on_sql}" + + expression_sql = self.sql(expression, "expression") + this_sql = self.sql(expression, "this") + return f"{expression_sql}{op_sql} {this_sql}{on_sql}" + + def lambda_sql(self, expression): + args = self.expressions(expression, flat=True) + args = f"({args})" if len(args.split(",")) > 1 else args + return self.no_identify(lambda: f"{args} -> {self.sql(expression, 'this')}") + + def lateral_sql(self, expression): + this = self.sql(expression, "this") + op_sql = self.seg( + f"LATERAL VIEW{' OUTER' if expression.args.get('outer') else ''}" + ) + alias = expression.args["alias"] + table = alias.name + table = f" {table}" if table else table + columns = self.expressions(alias, key="columns", flat=True) + columns = f" AS {columns}" if columns else "" + return f"{op_sql}{self.sep()}{this}{table}{columns}" + + def limit_sql(self, expression): + this = self.sql(expression, "this") + return f"{this}{self.seg('LIMIT')} {self.sql(expression, 'expression')}" + + def offset_sql(self, expression): + this = self.sql(expression, "this") + return f"{this}{self.seg('OFFSET')} {self.sql(expression, 'expression')}" + + def literal_sql(self, expression): + text = expression.this or "" + if expression.is_string: + if self._replace_backslash: + text = text.replace("\\", "\\\\") + text = text.replace(self.quote_end, self._escaped_quote_end) + return f"{self.quote_start}{text}{self.quote_end}" + return text + + def null_sql(self, *_): + return "NULL" + + def boolean_sql(self, expression): + return "TRUE" if expression.this else "FALSE" + + def order_sql(self, expression, flat=False): + this = self.sql(expression, "this") + this = f"{this} " if this else this + return self.op_expressions(f"{this}ORDER BY", expression, flat=this or flat) + + def cluster_sql(self, expression): + return self.op_expressions("CLUSTER BY", expression) + + def distribute_sql(self, expression): + return self.op_expressions("DISTRIBUTE BY", expression) + + def sort_sql(self, expression): + return self.op_expressions("SORT BY", expression) + + def ordered_sql(self, expression): + desc = expression.args.get("desc") + asc = not desc + nulls_first = expression.args.get("nulls_first") + nulls_last = not nulls_first + nulls_are_large = self.null_ordering == "nulls_are_large" + nulls_are_small = self.null_ordering == "nulls_are_small" + 
nulls_are_last = self.null_ordering == "nulls_are_last" + + sort_order = " DESC" if desc else "" + nulls_sort_change = "" + if nulls_first and ( + (asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last + ): + nulls_sort_change = " NULLS FIRST" + elif ( + nulls_last + and ((asc and nulls_are_small) or (desc and nulls_are_large)) + and not nulls_are_last + ): + nulls_sort_change = " NULLS LAST" + + if nulls_sort_change and not self.NULL_ORDERING_SUPPORTED: + self.unsupported( + "Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect" + ) + nulls_sort_change = "" + + return f"{self.sql(expression, 'this')}{sort_order}{nulls_sort_change}" + + def query_modifiers(self, expression, *sqls): + return csv( + *sqls, + *[self.sql(sql) for sql in expression.args.get("laterals", [])], + *[self.sql(sql) for sql in expression.args.get("joins", [])], + self.sql(expression, "where"), + self.sql(expression, "group"), + self.sql(expression, "having"), + self.sql(expression, "qualify"), + self.sql(expression, "window"), + self.sql(expression, "distribute"), + self.sql(expression, "sort"), + self.sql(expression, "cluster"), + self.sql(expression, "order"), + self.sql(expression, "limit"), + self.sql(expression, "offset"), + sep="", + ) + + def select_sql(self, expression): + hint = self.sql(expression, "hint") + distinct = self.sql(expression, "distinct") + distinct = f" {distinct}" if distinct else "" + expressions = self.expressions(expression) + expressions = f"{self.sep()}{expressions}" if expressions else expressions + sql = self.query_modifiers( + expression, + f"SELECT{hint}{distinct}{expressions}", + self.sql(expression, "from"), + ) + return self.prepend_ctes(expression, sql) + + def schema_sql(self, expression): + this = self.sql(expression, "this") + this = f"{this} " if this else "" + sql = f"({self.sep('')}{self.expressions(expression)}{self.seg(')', sep='')}" + return f"{this}{sql}" + + def star_sql(self, expression): + except_ = self.expressions(expression, key="except", flat=True) + except_ = f"{self.seg('EXCEPT')} ({except_})" if except_ else "" + replace = self.expressions(expression, key="replace", flat=True) + replace = f"{self.seg('REPLACE')} ({replace})" if replace else "" + return f"*{except_}{replace}" + + def structkwarg_sql(self, expression): + return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}" + + def placeholder_sql(self, *_): + return "?" 
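The individual *_sql handlers above are normally driven through Generator.generate rather than called directly. The snippet below is an editorial sketch, not part of the upstream diff, showing the Generator exercised on a parsed tree; it assumes sqlglot 6.x is installed and the SQL text is purely illustrative.

import sqlglot
from sqlglot.generator import Generator

tree = sqlglot.parse_one("SELECT a, b FROM t WHERE a > 1")
# pretty=True routes output through sep()/seg()/indent(); identify=True makes
# identifier_sql wrap every identifier in identifier_start/identifier_end quotes.
print(Generator(pretty=True, identify=True).generate(tree))

In everyday use the same rendering is usually reached via tree.sql(pretty=True), which lets a dialect supply its own Generator subclass with dialect-specific TRANSFORMS and TYPE_MAPPING overrides.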
+ + def subquery_sql(self, expression): + alias = self.sql(expression, "alias") + + return self.query_modifiers( + expression, + self.wrap(expression), + f" AS {alias}" if alias else "", + ) + + def qualify_sql(self, expression): + this = self.indent(self.sql(expression, "this")) + return f"{self.seg('QUALIFY')}{self.sep()}{this}" + + def union_sql(self, expression): + return self.prepend_ctes( + expression, + self.set_operation(expression, self.union_op(expression)), + ) + + def union_op(self, expression): + return f"UNION{'' if expression.args.get('distinct') else ' ALL'}" + + def unnest_sql(self, expression): + args = self.expressions(expression, flat=True) + alias = expression.args.get("alias") + if alias and self.unnest_column_only: + columns = alias.columns + alias = self.sql(columns[0]) if columns else "" + else: + alias = self.sql(expression, "alias") + alias = f" AS {alias}" if alias else alias + ordinality = " WITH ORDINALITY" if expression.args.get("ordinality") else "" + return f"UNNEST({args}){ordinality}{alias}" + + def where_sql(self, expression): + this = self.indent(self.sql(expression, "this")) + return f"{self.seg('WHERE')}{self.sep()}{this}" + + def window_sql(self, expression): + this = self.sql(expression, "this") + partition = self.expressions(expression, key="partition_by", flat=True) + partition = f"PARTITION BY {partition}" if partition else "" + order = expression.args.get("order") + order_sql = self.order_sql(order, flat=True) if order else "" + partition_sql = partition + " " if partition and order else partition + spec = expression.args.get("spec") + spec_sql = " " + self.window_spec_sql(spec) if spec else "" + alias = self.sql(expression, "alias") + if expression.arg_key == "window": + this = this = f"{self.seg('WINDOW')} {this} AS" + else: + this = f"{this} OVER" + + if not partition and not order and not spec and alias: + return f"{this} {alias}" + + return f"{this} ({alias}{partition_sql}{order_sql}{spec_sql})" + + def window_spec_sql(self, expression): + kind = self.sql(expression, "kind") + start = csv( + self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" " + ) + end = ( + csv(self.sql(expression, "end"), self.sql(expression, "end_side"), sep=" ") + or "CURRENT ROW" + ) + return f"{kind} BETWEEN {start} AND {end}" + + def withingroup_sql(self, expression): + this = self.sql(expression, "this") + expression = self.sql(expression, "expression")[1:] # order has a leading space + return f"{this} WITHIN GROUP ({expression})" + + def between_sql(self, expression): + this = self.sql(expression, "this") + low = self.sql(expression, "low") + high = self.sql(expression, "high") + return f"{this} BETWEEN {low} AND {high}" + + def bracket_sql(self, expression): + expressions = apply_index_offset(expression.expressions, self.index_offset) + expressions = ", ".join(self.sql(e) for e in expressions) + + return f"{self.sql(expression, 'this')}[{expressions}]" + + def all_sql(self, expression): + return f"ALL {self.wrap(expression)}" + + def any_sql(self, expression): + return f"ANY {self.wrap(expression)}" + + def exists_sql(self, expression): + return f"EXISTS{self.wrap(expression)}" + + def case_sql(self, expression): + this = self.indent(self.sql(expression, "this"), skip_first=True) + this = f" {this}" if this else "" + ifs = [] + + for e in expression.args["ifs"]: + ifs.append(self.indent(f"WHEN {self.sql(e, 'this')}")) + ifs.append(self.indent(f"THEN {self.sql(e, 'true')}")) + + if expression.args.get("default") is not None: + 
ifs.append(self.indent(f"ELSE {self.sql(expression, 'default')}")) + + ifs = "".join(self.seg(self.indent(e, skip_first=True)) for e in ifs) + statement = f"CASE{this}{ifs}{self.seg('END')}" + return statement + + def constraint_sql(self, expression): + this = self.sql(expression, "this") + expressions = self.expressions(expression, flat=True) + return f"CONSTRAINT {this} {expressions}" + + def extract_sql(self, expression): + this = self.sql(expression, "this") + expression_sql = self.sql(expression, "expression") + return f"EXTRACT({this} FROM {expression_sql})" + + def check_sql(self, expression): + this = self.sql(expression, key="this") + return f"CHECK ({this})" + + def foreignkey_sql(self, expression): + expressions = self.expressions(expression, flat=True) + reference = self.sql(expression, "reference") + reference = f" {reference}" if reference else "" + delete = self.sql(expression, "delete") + delete = f" ON DELETE {delete}" if delete else "" + update = self.sql(expression, "update") + update = f" ON UPDATE {update}" if update else "" + return f"FOREIGN KEY ({expressions}){reference}{delete}{update}" + + def unique_sql(self, expression): + columns = self.expressions(expression, key="expressions") + return f"UNIQUE ({columns})" + + def if_sql(self, expression): + return self.case_sql( + exp.Case(ifs=[expression], default=expression.args.get("false")) + ) + + def in_sql(self, expression): + query = expression.args.get("query") + unnest = expression.args.get("unnest") + if query: + in_sql = self.wrap(query) + elif unnest: + in_sql = self.in_unnest_op(unnest) + else: + in_sql = f"({self.expressions(expression, flat=True)})" + return f"{self.sql(expression, 'this')} IN {in_sql}" + + def in_unnest_op(self, unnest): + return f"(SELECT {self.sql(unnest)})" + + def interval_sql(self, expression): + return f"INTERVAL {self.sql(expression, 'this')} {self.sql(expression, 'unit')}" + + def reference_sql(self, expression): + this = self.sql(expression, "this") + expressions = self.expressions(expression, flat=True) + return f"REFERENCES {this}({expressions})" + + def anonymous_sql(self, expression): + args = self.indent( + self.expressions(expression, flat=True), skip_first=True, skip_last=True + ) + return f"{self.normalize_func(self.sql(expression, 'this'))}({args})" + + def paren_sql(self, expression): + if isinstance(expression.unnest(), exp.Select): + return self.wrap(expression) + sql = self.seg(self.indent(self.sql(expression, "this")), sep="") + return f"({sql}{self.seg(')', sep='')}" + + def neg_sql(self, expression): + return f"-{self.sql(expression, 'this')}" + + def not_sql(self, expression): + return f"NOT {self.sql(expression, 'this')}" + + def alias_sql(self, expression): + to_sql = self.sql(expression, "alias") + to_sql = f" AS {to_sql}" if to_sql else "" + return f"{self.sql(expression, 'this')}{to_sql}" + + def aliases_sql(self, expression): + return f"{self.sql(expression, 'this')} AS ({self.expressions(expression, flat=True)})" + + def attimezone_sql(self, expression): + this = self.sql(expression, "this") + zone = self.sql(expression, "zone") + return f"{this} AT TIME ZONE {zone}" + + def add_sql(self, expression): + return self.binary(expression, "+") + + def and_sql(self, expression): + return self.connector_sql(expression, "AND") + + def connector_sql(self, expression, op): + if not self.pretty: + return self.binary(expression, op) + + return f"\n{op} ".join(self.sql(e) for e in expression.flatten(unnest=False)) + + def bitwiseand_sql(self, expression): + return 
self.binary(expression, "&") + + def bitwiseleftshift_sql(self, expression): + return self.binary(expression, "<<") + + def bitwisenot_sql(self, expression): + return f"~{self.sql(expression, 'this')}" + + def bitwiseor_sql(self, expression): + return self.binary(expression, "|") + + def bitwiserightshift_sql(self, expression): + return self.binary(expression, ">>") + + def bitwisexor_sql(self, expression): + return self.binary(expression, "^") + + def cast_sql(self, expression): + return f"CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})" + + def currentdate_sql(self, expression): + zone = self.sql(expression, "this") + return f"CURRENT_DATE({zone})" if zone else "CURRENT_DATE" + + def command_sql(self, expression): + return f"{self.sql(expression, 'this').upper()} {expression.text('expression').strip()}" + + def distinct_sql(self, expression): + this = self.sql(expression, "this") + this = f" {this}" if this else "" + + on = self.sql(expression, "on") + on = f" ON {on}" if on else "" + return f"DISTINCT{this}{on}" + + def ignorenulls_sql(self, expression): + return f"{self.sql(expression, 'this')} IGNORE NULLS" + + def intdiv_sql(self, expression): + return self.sql( + exp.Cast( + this=exp.Div( + this=expression.args["this"], + expression=expression.args["expression"], + ), + to=exp.DataType(this=exp.DataType.Type.INT), + ) + ) + + def dpipe_sql(self, expression): + return self.binary(expression, "||") + + def div_sql(self, expression): + return self.binary(expression, "/") + + def dot_sql(self, expression): + return f"{self.sql(expression, 'this')}.{self.sql(expression, 'expression')}" + + def eq_sql(self, expression): + return self.binary(expression, "=") + + def escape_sql(self, expression): + return self.binary(expression, "ESCAPE") + + def gt_sql(self, expression): + return self.binary(expression, ">") + + def gte_sql(self, expression): + return self.binary(expression, ">=") + + def ilike_sql(self, expression): + return self.binary(expression, "ILIKE") + + def is_sql(self, expression): + return self.binary(expression, "IS") + + def like_sql(self, expression): + return self.binary(expression, "LIKE") + + def lt_sql(self, expression): + return self.binary(expression, "<") + + def lte_sql(self, expression): + return self.binary(expression, "<=") + + def mod_sql(self, expression): + return self.binary(expression, "%") + + def mul_sql(self, expression): + return self.binary(expression, "*") + + def neq_sql(self, expression): + return self.binary(expression, "<>") + + def or_sql(self, expression): + return self.connector_sql(expression, "OR") + + def sub_sql(self, expression): + return self.binary(expression, "-") + + def trycast_sql(self, expression): + return ( + f"TRY_CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})" + ) + + def binary(self, expression, op): + return ( + f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}" + ) + + def function_fallback_sql(self, expression): + args = [] + for arg_key in expression.arg_types: + arg_value = ensure_list(expression.args.get(arg_key) or []) + for a in arg_value: + args.append(self.sql(a)) + + args_str = self.indent(", ".join(args), skip_first=True, skip_last=True) + return f"{self.normalize_func(expression.sql_name())}({args_str})" + + def format_time(self, expression): + return format_time( + self.sql(expression, "format"), self.time_mapping, self.time_trie + ) + + def expressions(self, expression, key=None, flat=False, indent=True, sep=", "): + expressions = 
expression.args.get(key or "expressions") + + if not expressions: + return "" + + if flat: + return sep.join(self.sql(e) for e in expressions) + + expressions = self.sep(sep).join(self.sql(e) for e in expressions) + if indent: + return self.indent(expressions, skip_first=False) + return expressions + + def op_expressions(self, op, expression, flat=False): + expressions_sql = self.expressions(expression, flat=flat) + if flat: + return f"{op} {expressions_sql}" + return f"{self.seg(op)}{self.sep() if expressions_sql else ''}{expressions_sql}" + + def set_operation(self, expression, op): + this = self.sql(expression, "this") + op = self.seg(op) + return self.query_modifiers( + expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}" + ) + + def token_sql(self, token_type): + return self.TOKEN_MAPPING.get(token_type, token_type.name) diff --git a/sqlglot/helper.py b/sqlglot/helper.py new file mode 100644 index 0000000..5d90c49 --- /dev/null +++ b/sqlglot/helper.py @@ -0,0 +1,123 @@ +import logging +import re +from contextlib import contextmanager +from enum import Enum + +CAMEL_CASE_PATTERN = re.compile("(?<!^)(?=[A-Z])") +logger = logging.getLogger("sqlglot") + + +class AutoName(Enum): + def _generate_next_value_(name, _start, _count, _last_values): + return name + + +def list_get(arr, index): + try: + return arr[index] + except IndexError: + return None + + +def ensure_list(value): + if value is None: + return [] + return value if isinstance(value, (list, tuple, set)) else [value] + + +def csv(*args, sep=", "): + return sep.join(arg for arg in args if arg) + + +def apply_index_offset(expressions, offset): + if not offset or len(expressions) != 1: + return expressions + + expression = expressions[0] + + if expression.is_int: + expression = expression.copy() + logger.warning("Applying array index offset (%s)", offset) + expression.args["this"] = str(int(expression.args["this"]) + offset) + return [expression] + return expressions + + +def camel_to_snake_case(name): + return CAMEL_CASE_PATTERN.sub("_", name).upper() + + +def while_changing(expression, func): + while True: + start = hash(expression) + expression = func(expression) + if start == hash(expression): + break + return expression + + +def tsort(dag): + result = [] + + def visit(node, visited): + if node in result: + return + if node in visited: + raise ValueError("Cycle error") + + visited.add(node) + + for dep in dag.get(node, []): + visit(dep, visited) + + visited.remove(node) + result.append(node) + + for node in dag: + visit(node, set()) + + return result + + +def open_file(file_name): + """ + Open a file that may be compressed as gzip and return in newline mode. + """ + with open(file_name, "rb") as f: + gzipped = f.read(2) == b"\x1f\x8b" + + if gzipped: + import gzip + + return gzip.open(file_name, "rt", newline="") + + return open(file_name, "rt", encoding="utf-8", newline="") + + +@contextmanager +def csv_reader(table): + """ + Returns a csv reader given the expression READ_CSV(name, ['delimiter', '|', ...]) + + Args: + expression (Expression): An anonymous function READ_CSV + + Returns: + A python csv reader. 
+ """ + file, *args = table.this.expressions + file = file.name + file = open_file(file) + + delimiter = "," + args = iter(arg.name for arg in args) + for k, v in zip(args, args): + if k == "delimiter": + delimiter = v + + try: + import csv as csv_ + + yield csv_.reader(file, delimiter=delimiter) + finally: + file.close() diff --git a/sqlglot/optimizer/__init__.py b/sqlglot/optimizer/__init__.py new file mode 100644 index 0000000..a4c4cc2 --- /dev/null +++ b/sqlglot/optimizer/__init__.py @@ -0,0 +1,2 @@ +from sqlglot.optimizer.optimizer import optimize +from sqlglot.optimizer.schema import Schema diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py new file mode 100644 index 0000000..4bfb733 --- /dev/null +++ b/sqlglot/optimizer/eliminate_subqueries.py @@ -0,0 +1,48 @@ +import itertools + +from sqlglot import alias, exp, select, table +from sqlglot.optimizer.scope import traverse_scope +from sqlglot.optimizer.simplify import simplify + + +def eliminate_subqueries(expression): + """ + Rewrite duplicate subqueries from sqlglot AST. + + Example: + >>> import sqlglot + >>> expression = sqlglot.parse_one("SELECT 1 AS x, 2 AS y UNION ALL SELECT 1 AS x, 2 AS y") + >>> eliminate_subqueries(expression).sql() + 'WITH _e_0 AS (SELECT 1 AS x, 2 AS y) SELECT * FROM _e_0 UNION ALL SELECT * FROM _e_0' + + Args: + expression (sqlglot.Expression): expression to qualify + schema (dict|sqlglot.optimizer.Schema): Database schema + Returns: + sqlglot.Expression: qualified expression + """ + expression = simplify(expression) + queries = {} + + for scope in traverse_scope(expression): + query = scope.expression + queries[query] = queries.get(query, []) + [query] + + sequence = itertools.count() + + for query, duplicates in queries.items(): + if len(duplicates) == 1: + continue + + alias_ = f"_e_{next(sequence)}" + + for dup in duplicates: + parent = dup.parent + if isinstance(parent, exp.Subquery): + parent.replace(alias(table(alias_), parent.alias_or_name, table=True)) + elif isinstance(parent, exp.Union): + dup.replace(select("*").from_(alias_)) + + expression.with_(alias_, as_=query, copy=False) + + return expression diff --git a/sqlglot/optimizer/expand_multi_table_selects.py b/sqlglot/optimizer/expand_multi_table_selects.py new file mode 100644 index 0000000..ba562df --- /dev/null +++ b/sqlglot/optimizer/expand_multi_table_selects.py @@ -0,0 +1,16 @@ +from sqlglot import exp + + +def expand_multi_table_selects(expression): + for from_ in expression.find_all(exp.From): + parent = from_.parent + + for query in from_.expressions[1:]: + parent.join( + query, + join_type="CROSS", + copy=False, + ) + from_.expressions.remove(query) + + return expression diff --git a/sqlglot/optimizer/isolate_table_selects.py b/sqlglot/optimizer/isolate_table_selects.py new file mode 100644 index 0000000..c2e021e --- /dev/null +++ b/sqlglot/optimizer/isolate_table_selects.py @@ -0,0 +1,31 @@ +from sqlglot import alias, exp +from sqlglot.errors import OptimizeError +from sqlglot.optimizer.scope import traverse_scope + + +def isolate_table_selects(expression): + for scope in traverse_scope(expression): + if len(scope.selected_sources) == 1: + continue + + for (_, source) in scope.selected_sources.values(): + if not isinstance(source, exp.Table): + continue + + if not isinstance(source.parent, exp.Alias): + raise OptimizeError( + "Tables require an alias. Run qualify_tables optimization." 
+ ) + + parent = source.parent + + parent.replace( + exp.select("*") + .from_( + alias(source, source.name or parent.alias, table=True), + copy=False, + ) + .subquery(parent.alias, copy=False) + ) + + return expression diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py new file mode 100644 index 0000000..2c9f89c --- /dev/null +++ b/sqlglot/optimizer/normalize.py @@ -0,0 +1,136 @@ +from sqlglot import exp +from sqlglot.helper import while_changing +from sqlglot.optimizer.simplify import flatten, simplify, uniq_sort + + +def normalize(expression, dnf=False, max_distance=128): + """ + Rewrite sqlglot AST into conjunctive normal form. + + Example: + >>> import sqlglot + >>> expression = sqlglot.parse_one("(x AND y) OR z") + >>> normalize(expression).sql() + '(x OR z) AND (y OR z)' + + Args: + expression (sqlglot.Expression): expression to normalize + dnf (bool): rewrite in disjunctive normal form instead + max_distance (int): the maximal estimated distance from cnf to attempt conversion + Returns: + sqlglot.Expression: normalized expression + """ + expression = simplify(expression) + + expression = while_changing( + expression, lambda e: distributive_law(e, dnf, max_distance) + ) + return simplify(expression) + + +def normalized(expression, dnf=False): + ancestor, root = (exp.And, exp.Or) if dnf else (exp.Or, exp.And) + + return not any( + connector.find_ancestor(ancestor) for connector in expression.find_all(root) + ) + + +def normalization_distance(expression, dnf=False): + """ + The difference in the number of predicates between the current expression and the normalized form. + + This is used as an estimate of the cost of the conversion which is exponential in complexity. + + Example: + >>> import sqlglot + >>> expression = sqlglot.parse_one("(a AND b) OR (c AND d)") + >>> normalization_distance(expression) + 4 + + Args: + expression (sqlglot.Expression): expression to compute distance + dnf (bool): compute to dnf distance instead + Returns: + int: difference + """ + return sum(_predicate_lengths(expression, dnf)) - ( + len(list(expression.find_all(exp.Connector))) + 1 + ) + + +def _predicate_lengths(expression, dnf): + """ + Returns a list of predicate lengths when expanded to normalized form. + + (A AND B) OR C -> [2, 2] because len(A OR C), len(B OR C). 
+ """ + expression = expression.unnest() + + if not isinstance(expression, exp.Connector): + return [1] + + left, right = expression.args.values() + + if isinstance(expression, exp.And if dnf else exp.Or): + x = [ + a + b + for a in _predicate_lengths(left, dnf) + for b in _predicate_lengths(right, dnf) + ] + return x + return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf) + + +def distributive_law(expression, dnf, max_distance): + """ + x OR (y AND z) -> (x OR y) AND (x OR z) + (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z) + """ + if isinstance(expression.unnest(), exp.Connector): + if normalization_distance(expression, dnf) > max_distance: + return expression + + to_exp, from_exp = (exp.Or, exp.And) if dnf else (exp.And, exp.Or) + + exp.replace_children(expression, lambda e: distributive_law(e, dnf, max_distance)) + + if isinstance(expression, from_exp): + a, b = expression.unnest_operands() + + from_func = exp.and_ if from_exp == exp.And else exp.or_ + to_func = exp.and_ if to_exp == exp.And else exp.or_ + + if isinstance(a, to_exp) and isinstance(b, to_exp): + if len(tuple(a.find_all(exp.Connector))) > len( + tuple(b.find_all(exp.Connector)) + ): + return _distribute(a, b, from_func, to_func) + return _distribute(b, a, from_func, to_func) + if isinstance(a, to_exp): + return _distribute(b, a, from_func, to_func) + if isinstance(b, to_exp): + return _distribute(a, b, from_func, to_func) + + return expression + + +def _distribute(a, b, from_func, to_func): + if isinstance(a, exp.Connector): + exp.replace_children( + a, + lambda c: to_func( + exp.paren(from_func(c, b.left)), + exp.paren(from_func(c, b.right)), + ), + ) + else: + a = to_func(from_func(a, b.left), from_func(a, b.right)) + + return _simplify(a) + + +def _simplify(node): + node = uniq_sort(flatten(node)) + exp.replace_children(node, _simplify) + return node diff --git a/sqlglot/optimizer/optimize_joins.py b/sqlglot/optimizer/optimize_joins.py new file mode 100644 index 0000000..40e4ab1 --- /dev/null +++ b/sqlglot/optimizer/optimize_joins.py @@ -0,0 +1,75 @@ +from sqlglot import exp +from sqlglot.helper import tsort +from sqlglot.optimizer.simplify import simplify + + +def optimize_joins(expression): + """ + Removes cross joins if possible and reorder joins based on predicate dependencies. + """ + for select in expression.find_all(exp.Select): + references = {} + cross_joins = [] + + for join in select.args.get("joins", []): + name = join.this.alias_or_name + tables = other_table_names(join, name) + + if tables: + for table in tables: + references[table] = references.get(table, []) + [join] + else: + cross_joins.append((name, join)) + + for name, join in cross_joins: + for dep in references.get(name, []): + on = dep.args["on"] + on = on.replace(simplify(on)) + + if isinstance(on, exp.Connector): + for predicate in on.flatten(): + if name in exp.column_table_names(predicate): + predicate.replace(exp.TRUE) + join.on(predicate, copy=False) + + expression = reorder_joins(expression) + expression = normalize(expression) + return expression + + +def reorder_joins(expression): + """ + Reorder joins by topological sort order based on predicate references. 
+ """ + for from_ in expression.find_all(exp.From): + head = from_.expressions[0] + parent = from_.parent + joins = {join.this.alias_or_name: join for join in parent.args.get("joins", [])} + dag = {head.alias_or_name: []} + + for name, join in joins.items(): + dag[name] = other_table_names(join, name) + + parent.set( + "joins", + [joins[name] for name in tsort(dag) if name != head.alias_or_name], + ) + return expression + + +def normalize(expression): + """ + Remove INNER and OUTER from joins as they are optional. + """ + for join in expression.find_all(exp.Join): + if join.kind != "CROSS": + join.set("kind", None) + return expression + + +def other_table_names(join, exclude): + return [ + name + for name in (exp.column_table_names(join.args.get("on") or exp.TRUE)) + if name != exclude + ] diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py new file mode 100644 index 0000000..c03fe3c --- /dev/null +++ b/sqlglot/optimizer/optimizer.py @@ -0,0 +1,43 @@ +from sqlglot.optimizer.eliminate_subqueries import eliminate_subqueries +from sqlglot.optimizer.expand_multi_table_selects import expand_multi_table_selects +from sqlglot.optimizer.isolate_table_selects import isolate_table_selects +from sqlglot.optimizer.normalize import normalize +from sqlglot.optimizer.optimize_joins import optimize_joins +from sqlglot.optimizer.pushdown_predicates import pushdown_predicates +from sqlglot.optimizer.pushdown_projections import pushdown_projections +from sqlglot.optimizer.qualify_columns import qualify_columns +from sqlglot.optimizer.qualify_tables import qualify_tables +from sqlglot.optimizer.quote_identities import quote_identities +from sqlglot.optimizer.unnest_subqueries import unnest_subqueries + + +def optimize(expression, schema=None, db=None, catalog=None): + """ + Rewrite a sqlglot AST into an optimized form. + + Args: + expression (sqlglot.Expression): expression to optimize + schema (dict|sqlglot.optimizer.Schema): database schema. + This can either be an instance of `sqlglot.optimizer.Schema` or a mapping in one of + the following forms: + 1. {table: {col: type}} + 2. {db: {table: {col: type}}} + 3. 
{catalog: {db: {table: {col: type}}}} + db (str): specify the default database, as might be set by a `USE DATABASE db` statement + catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement + Returns: + sqlglot.Expression: optimized expression + """ + expression = expression.copy() + expression = qualify_tables(expression, db=db, catalog=catalog) + expression = isolate_table_selects(expression) + expression = qualify_columns(expression, schema) + expression = pushdown_projections(expression) + expression = normalize(expression) + expression = unnest_subqueries(expression) + expression = expand_multi_table_selects(expression) + expression = pushdown_predicates(expression) + expression = optimize_joins(expression) + expression = eliminate_subqueries(expression) + expression = quote_identities(expression) + return expression diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py new file mode 100644 index 0000000..e757322 --- /dev/null +++ b/sqlglot/optimizer/pushdown_predicates.py @@ -0,0 +1,176 @@ +from sqlglot import exp +from sqlglot.optimizer.normalize import normalized +from sqlglot.optimizer.scope import traverse_scope +from sqlglot.optimizer.simplify import simplify + + +def pushdown_predicates(expression): + """ + Rewrite sqlglot AST to pushdown predicates in FROMS and JOINS + + Example: + >>> import sqlglot + >>> sql = "SELECT * FROM (SELECT * FROM x AS x) AS y WHERE y.a = 1" + >>> expression = sqlglot.parse_one(sql) + >>> pushdown_predicates(expression).sql() + 'SELECT * FROM (SELECT * FROM x AS x WHERE y.a = 1) AS y WHERE TRUE' + + Args: + expression (sqlglot.Expression): expression to optimize + Returns: + sqlglot.Expression: optimized expression + """ + for scope in reversed(traverse_scope(expression)): + select = scope.expression + where = select.args.get("where") + if where: + pushdown(where.this, scope.selected_sources) + + # joins should only pushdown into itself, not to other joins + # so we limit the selected sources to only itself + for join in select.args.get("joins") or []: + name = join.this.alias_or_name + pushdown(join.args.get("on"), {name: scope.selected_sources[name]}) + + return expression + + +def pushdown(condition, sources): + if not condition: + return + + condition = condition.replace(simplify(condition)) + cnf_like = normalized(condition) or not normalized(condition, dnf=True) + + predicates = list( + condition.flatten() + if isinstance(condition, exp.And if cnf_like else exp.Or) + else [condition] + ) + + if cnf_like: + pushdown_cnf(predicates, sources) + else: + pushdown_dnf(predicates, sources) + + +def pushdown_cnf(predicates, scope): + """ + If the predicates are in CNF like form, we can simply replace each block in the parent. + """ + for predicate in predicates: + for node in nodes_for_predicate(predicate, scope).values(): + if isinstance(node, exp.Join): + predicate.replace(exp.TRUE) + node.on(predicate, copy=False) + break + if isinstance(node, exp.Select): + predicate.replace(exp.TRUE) + node.where(replace_aliases(node, predicate), copy=False) + + +def pushdown_dnf(predicates, scope): + """ + If the predicates are in DNF form, we can only push down conditions that are in all blocks. + Additionally, we can't remove predicates from their original form. 
+ """ + # find all the tables that can be pushdown too + # these are tables that are referenced in all blocks of a DNF + # (a.x AND b.x) OR (a.y AND c.y) + # only table a can be push down + pushdown_tables = set() + + for a in predicates: + a_tables = set(exp.column_table_names(a)) + + for b in predicates: + a_tables &= set(exp.column_table_names(b)) + + pushdown_tables.update(a_tables) + + conditions = {} + + # for every pushdown table, find all related conditions in all predicates + # combine them with ORS + # (a.x AND and a.y AND b.x) OR (a.z AND c.y) -> (a.x AND a.y) OR (a.z) + for table in sorted(pushdown_tables): + for predicate in predicates: + nodes = nodes_for_predicate(predicate, scope) + + if table not in nodes: + continue + + predicate_condition = None + + for column in predicate.find_all(exp.Column): + if column.table == table: + condition = column.find_ancestor(exp.Condition) + predicate_condition = ( + exp.and_(predicate_condition, condition) + if predicate_condition + else condition + ) + + if predicate_condition: + conditions[table] = ( + exp.or_(conditions[table], predicate_condition) + if table in conditions + else predicate_condition + ) + + for name, node in nodes.items(): + if name not in conditions: + continue + + predicate = conditions[name] + + if isinstance(node, exp.Join): + node.on(predicate, copy=False) + elif isinstance(node, exp.Select): + node.where(replace_aliases(node, predicate), copy=False) + + +def nodes_for_predicate(predicate, sources): + nodes = {} + tables = exp.column_table_names(predicate) + where_condition = isinstance( + predicate.find_ancestor(exp.Join, exp.Where), exp.Where + ) + + for table in tables: + node, source = sources.get(table) or (None, None) + + # if the predicate is in a where statement we can try to push it down + # we want to find the root join or from statement + if node and where_condition: + node = node.find_ancestor(exp.Join, exp.From) + + # a node can reference a CTE which should be push down + if isinstance(node, exp.From) and not isinstance(source, exp.Table): + node = source.expression + + if isinstance(node, exp.Join): + if node.side: + return {} + nodes[table] = node + elif isinstance(node, exp.Select) and len(tables) == 1: + if not node.args.get("group"): + nodes[table] = node + return nodes + + +def replace_aliases(source, predicate): + aliases = {} + + for select in source.selects: + if isinstance(select, exp.Alias): + aliases[select.alias] = select.this + else: + aliases[select.name] = select + + def _replace_alias(column): + if isinstance(column, exp.Column) and column.name in aliases: + return aliases[column.name] + return column + + return predicate.transform(_replace_alias) diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py new file mode 100644 index 0000000..097ce04 --- /dev/null +++ b/sqlglot/optimizer/pushdown_projections.py @@ -0,0 +1,85 @@ +from collections import defaultdict + +from sqlglot import alias, exp +from sqlglot.optimizer.scope import Scope, traverse_scope + +# Sentinel value that means an outer query selecting ALL columns +SELECT_ALL = object() + + +def pushdown_projections(expression): + """ + Rewrite sqlglot AST to remove unused columns projections. 
+ + Example: + >>> import sqlglot + >>> sql = "SELECT y.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x) AS y" + >>> expression = sqlglot.parse_one(sql) + >>> pushdown_projections(expression).sql() + 'SELECT y.a AS a FROM (SELECT x.a AS a FROM x) AS y' + + Args: + expression (sqlglot.Expression): expression to optimize + Returns: + sqlglot.Expression: optimized expression + """ + # Map of Scope to all columns being selected by outer queries. + referenced_columns = defaultdict(set) + + # We build the scope tree (which is traversed in DFS postorder), then iterate + # over the result in reverse order. This should ensure that the set of selected + # columns for a particular scope are completely build by the time we get to it. + for scope in reversed(traverse_scope(expression)): + parent_selections = referenced_columns.get(scope, {SELECT_ALL}) + + if scope.expression.args.get("distinct"): + # We can't remove columns SELECT DISTINCT nor UNION DISTINCT + parent_selections = {SELECT_ALL} + + if isinstance(scope.expression, exp.Union): + left, right = scope.union + referenced_columns[left] = parent_selections + referenced_columns[right] = parent_selections + + if isinstance(scope.expression, exp.Select): + _remove_unused_selections(scope, parent_selections) + + # Group columns by source name + selects = defaultdict(set) + for col in scope.columns: + table_name = col.table + col_name = col.name + selects[table_name].add(col_name) + + # Push the selected columns down to the next scope + for name, (_, source) in scope.selected_sources.items(): + if isinstance(source, Scope): + columns = selects.get(name) or set() + referenced_columns[source].update(columns) + + return expression + + +def _remove_unused_selections(scope, parent_selections): + order = scope.expression.args.get("order") + + if order: + # Assume columns without a qualified table are references to output columns + order_refs = {c.name for c in order.find_all(exp.Column) if not c.table} + else: + order_refs = set() + + new_selections = [] + for selection in scope.selects: + if ( + SELECT_ALL in parent_selections + or selection.alias_or_name in parent_selections + or selection.alias_or_name in order_refs + ): + new_selections.append(selection) + + # If there are no remaining selections, just select a single constant + if not new_selections: + new_selections.append(alias("1", "_")) + + scope.expression.set("expressions", new_selections) diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py new file mode 100644 index 0000000..394f49e --- /dev/null +++ b/sqlglot/optimizer/qualify_columns.py @@ -0,0 +1,422 @@ +import itertools + +from sqlglot import alias, exp +from sqlglot.errors import OptimizeError +from sqlglot.optimizer.schema import ensure_schema +from sqlglot.optimizer.scope import traverse_scope + +SKIP_QUALIFY = (exp.Unnest, exp.Lateral) + + +def qualify_columns(expression, schema): + """ + Rewrite sqlglot AST to have fully qualified columns. 
+ + Example: + >>> import sqlglot + >>> schema = {"tbl": {"col": "INT"}} + >>> expression = sqlglot.parse_one("SELECT col FROM tbl") + >>> qualify_columns(expression, schema).sql() + 'SELECT tbl.col AS col FROM tbl' + + Args: + expression (sqlglot.Expression): expression to qualify + schema (dict|sqlglot.optimizer.Schema): Database schema + Returns: + sqlglot.Expression: qualified expression + """ + schema = ensure_schema(schema) + + for scope in traverse_scope(expression): + resolver = _Resolver(scope, schema) + _pop_table_column_aliases(scope.ctes) + _pop_table_column_aliases(scope.derived_tables) + _expand_using(scope, resolver) + _expand_group_by(scope, resolver) + _expand_order_by(scope) + _qualify_columns(scope, resolver) + if not isinstance(scope.expression, SKIP_QUALIFY): + _expand_stars(scope, resolver) + _qualify_outputs(scope) + _check_unknown_tables(scope) + + return expression + + +def _pop_table_column_aliases(derived_tables): + """ + Remove table column aliases. + + (e.g. SELECT ... FROM (SELECT ...) AS foo(col1, col2) + """ + for derived_table in derived_tables: + if isinstance(derived_table, SKIP_QUALIFY): + continue + table_alias = derived_table.args.get("alias") + if table_alias: + table_alias.args.pop("columns", None) + + +def _expand_using(scope, resolver): + joins = list(scope.expression.find_all(exp.Join)) + names = {join.this.alias for join in joins} + ordered = [key for key in scope.selected_sources if key not in names] + + # Mapping of automatically joined column names to source names + column_tables = {} + + for join in joins: + using = join.args.get("using") + + if not using: + continue + + join_table = join.this.alias_or_name + + columns = {} + + for k in scope.selected_sources: + if k in ordered: + for column in resolver.get_source_columns(k): + if column not in columns: + columns[column] = k + + ordered.append(join_table) + join_columns = resolver.get_source_columns(join_table) + conditions = [] + + for identifier in using: + identifier = identifier.name + table = columns.get(identifier) + + if not table or identifier not in join_columns: + raise OptimizeError(f"Cannot automatically join: {identifier}") + + conditions.append( + exp.condition( + exp.EQ( + this=exp.column(identifier, table=table), + expression=exp.column(identifier, table=join_table), + ) + ) + ) + + tables = column_tables.setdefault(identifier, []) + if table not in tables: + tables.append(table) + if join_table not in tables: + tables.append(join_table) + + join.args.pop("using") + join.set("on", exp.and_(*conditions)) + + if column_tables: + for column in scope.columns: + if not column.table and column.name in column_tables: + tables = column_tables[column.name] + coalesce = [exp.column(column.name, table=table) for table in tables] + replacement = exp.Coalesce(this=coalesce[0], expressions=coalesce[1:]) + + # Ensure selects keep their output name + if isinstance(column.parent, exp.Select): + replacement = exp.alias_(replacement, alias=column.name) + + scope.replace(column, replacement) + + +def _expand_group_by(scope, resolver): + group = scope.expression.args.get("group") + if not group: + return + + # Replace references to select aliases + def transform(node, *_): + if isinstance(node, exp.Column) and not node.table: + table = resolver.get_table(node.name) + + # Source columns get priority over select aliases + if table: + node.set("table", exp.to_identifier(table)) + return node + + selects = {s.alias_or_name: s for s in scope.selects} + + select = selects.get(node.name) + if select: + 
scope.clear_cache() + if isinstance(select, exp.Alias): + select = select.this + return select.copy() + + return node + + group.transform(transform, copy=False) + group.set("expressions", _expand_positional_references(scope, group.expressions)) + scope.expression.set("group", group) + + +def _expand_order_by(scope): + order = scope.expression.args.get("order") + if not order: + return + + ordereds = order.expressions + for ordered, new_expression in zip( + ordereds, + _expand_positional_references(scope, (o.this for o in ordereds)), + ): + ordered.set("this", new_expression) + + +def _expand_positional_references(scope, expressions): + new_nodes = [] + for node in expressions: + if node.is_int: + try: + select = scope.selects[int(node.name) - 1] + except IndexError: + raise OptimizeError(f"Unknown output column: {node.name}") + if isinstance(select, exp.Alias): + select = select.this + new_nodes.append(select.copy()) + scope.clear_cache() + else: + new_nodes.append(node) + + return new_nodes + + +def _qualify_columns(scope, resolver): + """Disambiguate columns, ensuring each column specifies a source""" + for column in scope.columns: + column_table = column.table + column_name = column.name + + if ( + column_table + and column_table in scope.sources + and column_name not in resolver.get_source_columns(column_table) + ): + raise OptimizeError(f"Unknown column: {column_name}") + + if not column_table: + column_table = resolver.get_table(column_name) + + if not scope.is_subquery and not scope.is_unnest: + if column_name not in resolver.all_columns: + raise OptimizeError(f"Unknown column: {column_name}") + + if column_table is None: + raise OptimizeError(f"Ambiguous column: {column_name}") + + # column_table can be a '' because bigquery unnest has no table alias + if column_table: + column.set("table", exp.to_identifier(column_table)) + + +def _expand_stars(scope, resolver): + """Expand stars to lists of column selections""" + + new_selections = [] + except_columns = {} + replace_columns = {} + + for expression in scope.selects: + if isinstance(expression, exp.Star): + tables = list(scope.selected_sources) + _add_except_columns(expression, tables, except_columns) + _add_replace_columns(expression, tables, replace_columns) + elif isinstance(expression, exp.Column) and isinstance( + expression.this, exp.Star + ): + tables = [expression.table] + _add_except_columns(expression.this, tables, except_columns) + _add_replace_columns(expression.this, tables, replace_columns) + else: + new_selections.append(expression) + continue + + for table in tables: + if table not in scope.sources: + raise OptimizeError(f"Unknown table: {table}") + columns = resolver.get_source_columns(table) + table_id = id(table) + for name in columns: + if name not in except_columns.get(table_id, set()): + alias_ = replace_columns.get(table_id, {}).get(name, name) + column = exp.column(name, table) + new_selections.append( + alias(column, alias_) if alias_ != name else column + ) + + scope.expression.set("expressions", new_selections) + + +def _add_except_columns(expression, tables, except_columns): + except_ = expression.args.get("except") + + if not except_: + return + + columns = {e.name for e in except_} + + for table in tables: + except_columns[id(table)] = columns + + +def _add_replace_columns(expression, tables, replace_columns): + replace = expression.args.get("replace") + + if not replace: + return + + columns = {e.this.name: e.alias for e in replace} + + for table in tables: + replace_columns[id(table)] = columns + + 
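A minimal sketch of the star expansion performed by _expand_stars together with the output aliasing in _qualify_outputs below (the schema and table names are made up for illustration, and the rendered output is indicative):

    import sqlglot
    from sqlglot.optimizer.qualify_columns import qualify_columns

    schema = {"tbl": {"a": "INT", "b": "INT"}}
    expression = sqlglot.parse_one("SELECT * FROM tbl")
    # The star is replaced with one column per schema entry and each output
    # is aliased, e.g. SELECT tbl.a AS a, tbl.b AS b FROM tbl
    print(qualify_columns(expression, schema).sql())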
+def _qualify_outputs(scope): + """Ensure all output columns are aliased""" + new_selections = [] + + for i, (selection, aliased_column) in enumerate( + itertools.zip_longest(scope.selects, scope.outer_column_list) + ): + if isinstance(selection, exp.Column): + # convoluted setter because a simple selection.replace(alias) would require a copy + alias_ = alias(exp.column(""), alias=selection.name) + alias_.set("this", selection) + selection = alias_ + elif not isinstance(selection, exp.Alias): + alias_ = alias(exp.column(""), f"_col_{i}") + alias_.set("this", selection) + selection = alias_ + + if aliased_column: + selection.set("alias", exp.to_identifier(aliased_column)) + + new_selections.append(selection) + + scope.expression.set("expressions", new_selections) + + +def _check_unknown_tables(scope): + if ( + scope.external_columns + and not scope.is_unnest + and not scope.is_correlated_subquery + ): + raise OptimizeError(f"Unknown table: {scope.external_columns[0].text('table')}") + + +class _Resolver: + """ + Helper for resolving columns. + + This is a class so we can lazily load some things and easily share them across functions. + """ + + def __init__(self, scope, schema): + self.scope = scope + self.schema = schema + self._source_columns = None + self._unambiguous_columns = None + self._all_columns = None + + def get_table(self, column_name): + """ + Get the table for a column name. + + Args: + column_name (str) + Returns: + (str) table name + """ + if self._unambiguous_columns is None: + self._unambiguous_columns = self._get_unambiguous_columns( + self._get_all_source_columns() + ) + return self._unambiguous_columns.get(column_name) + + @property + def all_columns(self): + """All available columns of all sources in this scope""" + if self._all_columns is None: + self._all_columns = set( + column + for columns in self._get_all_source_columns().values() + for column in columns + ) + return self._all_columns + + def get_source_columns(self, name): + """Resolve the source columns for a given source `name`""" + if name not in self.scope.sources: + raise OptimizeError(f"Unknown table: {name}") + + source = self.scope.sources[name] + + # If referencing a table, return the columns from the schema + if isinstance(source, exp.Table): + try: + return self.schema.column_names(source) + except Exception as e: + raise OptimizeError(str(e)) from e + + # Otherwise, if referencing another scope, return that scope's named selects + return source.expression.named_selects + + def _get_all_source_columns(self): + if self._source_columns is None: + self._source_columns = { + k: self.get_source_columns(k) for k in self.scope.selected_sources + } + return self._source_columns + + def _get_unambiguous_columns(self, source_columns): + """ + Find all the unambiguous columns in sources. 
+ + Args: + source_columns (dict): Mapping of names to source columns + Returns: + dict: Mapping of column name to source name + """ + if not source_columns: + return {} + + source_columns = list(source_columns.items()) + + first_table, first_columns = source_columns[0] + unambiguous_columns = { + col: first_table for col in self._find_unique_columns(first_columns) + } + all_columns = set(unambiguous_columns) + + for table, columns in source_columns[1:]: + unique = self._find_unique_columns(columns) + ambiguous = set(all_columns).intersection(unique) + all_columns.update(columns) + for column in ambiguous: + unambiguous_columns.pop(column, None) + for column in unique.difference(ambiguous): + unambiguous_columns[column] = table + + return unambiguous_columns + + @staticmethod + def _find_unique_columns(columns): + """ + Find the unique columns in a list of columns. + + Example: + >>> sorted(_Resolver._find_unique_columns(["a", "b", "b", "c"])) + ['a', 'c'] + + This is necessary because duplicate column names are ambiguous. + """ + counts = {} + for column in columns: + counts[column] = counts.get(column, 0) + 1 + return {column for column, count in counts.items() if count == 1} diff --git a/sqlglot/optimizer/qualify_tables.py b/sqlglot/optimizer/qualify_tables.py new file mode 100644 index 0000000..9f8b9f5 --- /dev/null +++ b/sqlglot/optimizer/qualify_tables.py @@ -0,0 +1,54 @@ +import itertools + +from sqlglot import alias, exp +from sqlglot.optimizer.scope import traverse_scope + + +def qualify_tables(expression, db=None, catalog=None): + """ + Rewrite sqlglot AST to have fully qualified tables. + + Example: + >>> import sqlglot + >>> expression = sqlglot.parse_one("SELECT 1 FROM tbl") + >>> qualify_tables(expression, db="db").sql() + 'SELECT 1 FROM db.tbl AS tbl' + + Args: + expression (sqlglot.Expression): expression to qualify + db (str): Database name + catalog (str): Catalog name + Returns: + sqlglot.Expression: qualified expression + """ + sequence = itertools.count() + + for scope in traverse_scope(expression): + for derived_table in scope.ctes + scope.derived_tables: + if not derived_table.args.get("alias"): + alias_ = f"_q_{next(sequence)}" + derived_table.set( + "alias", exp.TableAlias(this=exp.to_identifier(alias_)) + ) + scope.rename_source(None, alias_) + + for source in scope.sources.values(): + if isinstance(source, exp.Table): + identifier = isinstance(source.this, exp.Identifier) + + if identifier: + if not source.args.get("db"): + source.set("db", exp.to_identifier(db)) + if not source.args.get("catalog"): + source.set("catalog", exp.to_identifier(catalog)) + + if not isinstance(source.parent, exp.Alias): + source.replace( + alias( + source.copy(), + source.this if identifier else f"_q_{next(sequence)}", + table=True, + ) + ) + + return expression diff --git a/sqlglot/optimizer/quote_identities.py b/sqlglot/optimizer/quote_identities.py new file mode 100644 index 0000000..17623cc --- /dev/null +++ b/sqlglot/optimizer/quote_identities.py @@ -0,0 +1,25 @@ +from sqlglot import exp + + +def quote_identities(expression): + """ + Rewrite sqlglot AST to ensure all identities are quoted. 
+ + Example: + >>> import sqlglot + >>> expression = sqlglot.parse_one("SELECT x.a AS a FROM db.x") + >>> quote_identities(expression).sql() + 'SELECT "x"."a" AS "a" FROM "db"."x"' + + Args: + expression (sqlglot.Expression): expression to quote + Returns: + sqlglot.Expression: quoted expression + """ + + def qualify(node): + if isinstance(node, exp.Identifier): + node.set("quoted", True) + return node + + return expression.transform(qualify, copy=False) diff --git a/sqlglot/optimizer/schema.py b/sqlglot/optimizer/schema.py new file mode 100644 index 0000000..9968108 --- /dev/null +++ b/sqlglot/optimizer/schema.py @@ -0,0 +1,129 @@ +import abc + +from sqlglot import exp +from sqlglot.errors import OptimizeError +from sqlglot.helper import csv_reader + + +class Schema(abc.ABC): + """Abstract base class for database schemas""" + + @abc.abstractmethod + def column_names(self, table): + """ + Get the column names for a table. + + Args: + table (sqlglot.expressions.Table): Table expression instance + Returns: + list[str]: list of column names + """ + + +class MappingSchema(Schema): + """ + Schema based on a nested mapping. + + Args: + schema (dict): Mapping in one of the following forms: + 1. {table: {col: type}} + 2. {db: {table: {col: type}}} + 3. {catalog: {db: {table: {col: type}}}} + """ + + def __init__(self, schema): + self.schema = schema + + depth = _dict_depth(schema) + + if not depth: # {} + self.supported_table_args = [] + elif depth == 2: # {table: {col: type}} + self.supported_table_args = ("this",) + elif depth == 3: # {db: {table: {col: type}}} + self.supported_table_args = ("db", "this") + elif depth == 4: # {catalog: {db: {table: {col: type}}}} + self.supported_table_args = ("catalog", "db", "this") + else: + raise OptimizeError(f"Invalid schema shape. Depth: {depth}") + + self.forbidden_args = {"catalog", "db", "this"} - set(self.supported_table_args) + + def column_names(self, table): + if not isinstance(table.this, exp.Identifier): + return fs_get(table) + + args = tuple(table.text(p) for p in self.supported_table_args) + + for forbidden in self.forbidden_args: + if table.text(forbidden): + raise ValueError( + f"Schema doesn't support {forbidden}. Received: {table.sql()}" + ) + return list(_nested_get(self.schema, *zip(self.supported_table_args, args))) + + +def ensure_schema(schema): + if isinstance(schema, Schema): + return schema + + return MappingSchema(schema) + + +def fs_get(table): + name = table.this.name.upper() + + if name.upper() == "READ_CSV": + with csv_reader(table) as reader: + return next(reader) + + raise ValueError(f"Cannot read schema for {table}") + + +def _nested_get(d, *path): + """ + Get a value for a nested dictionary. + + Args: + d (dict): dictionary + *path (tuple[str, str]): tuples of (name, key) + `key` is the key in the dictionary to get. + `name` is a string to use in the error if `key` isn't found. + """ + for name, key in path: + d = d.get(key) + if d is None: + name = "table" if name == "this" else name + raise ValueError(f"Unknown {name}") + return d + + +def _dict_depth(d): + """ + Get the nesting depth of a dictionary. 
+ + For example: + >>> _dict_depth(None) + 0 + >>> _dict_depth({}) + 1 + >>> _dict_depth({"a": "b"}) + 1 + >>> _dict_depth({"a": {}}) + 2 + >>> _dict_depth({"a": {"b": {}}}) + 3 + + Args: + d (dict): dictionary + Returns: + int: depth + """ + try: + return 1 + _dict_depth(next(iter(d.values()))) + except AttributeError: + # d doesn't have attribute "values" + return 0 + except StopIteration: + # d.values() returns an empty sequence + return 1 diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py new file mode 100644 index 0000000..f6f59e8 --- /dev/null +++ b/sqlglot/optimizer/scope.py @@ -0,0 +1,438 @@ +from copy import copy +from enum import Enum, auto + +from sqlglot import exp +from sqlglot.errors import OptimizeError + + +class ScopeType(Enum): + ROOT = auto() + SUBQUERY = auto() + DERIVED_TABLE = auto() + CTE = auto() + UNION = auto() + UNNEST = auto() + + +class Scope: + """ + Selection scope. + + Attributes: + expression (exp.Select|exp.Union): Root expression of this scope + sources (dict[str, exp.Table|Scope]): Mapping of source name to either + a Table expression or another Scope instance. For example: + SELECT * FROM x {"x": Table(this="x")} + SELECT * FROM x AS y {"y": Table(this="x")} + SELECT * FROM (SELECT ...) AS y {"y": Scope(...)} + outer_column_list (list[str]): If this is a derived table or CTE, and the outer query + defines a column list of it's alias of this scope, this is that list of columns. + For example: + SELECT * FROM (SELECT ...) AS y(col1, col2) + The inner query would have `["col1", "col2"]` for its `outer_column_list` + parent (Scope): Parent scope + scope_type (ScopeType): Type of this scope, relative to it's parent + subquery_scopes (list[Scope]): List of all child scopes for subqueries. + This does not include derived tables or CTEs. + union (tuple[Scope, Scope]): If this Scope is for a Union expression, this will be + a tuple of the left and right child scopes. + """ + + def __init__( + self, + expression, + sources=None, + outer_column_list=None, + parent=None, + scope_type=ScopeType.ROOT, + ): + self.expression = expression + self.sources = sources or {} + self.outer_column_list = outer_column_list or [] + self.parent = parent + self.scope_type = scope_type + self.subquery_scopes = [] + self.union = None + self.clear_cache() + + def clear_cache(self): + self._collected = False + self._raw_columns = None + self._derived_tables = None + self._tables = None + self._ctes = None + self._subqueries = None + self._selected_sources = None + self._columns = None + self._external_columns = None + + def branch(self, expression, scope_type, add_sources=None, **kwargs): + """Branch from the current scope to a new, inner scope""" + sources = copy(self.sources) + if add_sources: + sources.update(add_sources) + return Scope( + expression=expression.unnest(), + sources=sources, + parent=self, + scope_type=scope_type, + **kwargs, + ) + + def _collect(self): + self._tables = [] + self._ctes = [] + self._subqueries = [] + self._derived_tables = [] + self._raw_columns = [] + + # We'll use this variable to pass state into the dfs generator. + # Whenever we set it to True, we exclude a subtree from traversal. 
+ prune = False + + for node, parent, _ in self.expression.dfs(prune=lambda *_: prune): + prune = False + + if node is self.expression: + continue + if isinstance(node, exp.Column) and not isinstance(node.this, exp.Star): + self._raw_columns.append(node) + elif isinstance(node, exp.Table): + self._tables.append(node) + elif isinstance(node, (exp.Unnest, exp.Lateral)): + self._derived_tables.append(node) + elif isinstance(node, exp.CTE): + self._ctes.append(node) + prune = True + elif isinstance(node, exp.Subquery) and isinstance( + parent, (exp.From, exp.Join) + ): + self._derived_tables.append(node) + prune = True + elif isinstance(node, exp.Subqueryable): + self._subqueries.append(node) + prune = True + + self._collected = True + + def _ensure_collected(self): + if not self._collected: + self._collect() + + def replace(self, old, new): + """ + Replace `old` with `new`. + + This can be used instead of `exp.Expression.replace` to ensure the `Scope` is kept up-to-date. + + Args: + old (exp.Expression): old node + new (exp.Expression): new node + """ + old.replace(new) + self.clear_cache() + + @property + def tables(self): + """ + List of tables in this scope. + + Returns: + list[exp.Table]: tables + """ + self._ensure_collected() + return self._tables + + @property + def ctes(self): + """ + List of CTEs in this scope. + + Returns: + list[exp.CTE]: ctes + """ + self._ensure_collected() + return self._ctes + + @property + def derived_tables(self): + """ + List of derived tables in this scope. + + For example: + SELECT * FROM (SELECT ...) <- that's a derived table + + Returns: + list[exp.Subquery]: derived tables + """ + self._ensure_collected() + return self._derived_tables + + @property + def subqueries(self): + """ + List of subqueries in this scope. + + For example: + SELECT * FROM x WHERE a IN (SELECT ...) <- that's a subquery + + Returns: + list[exp.Subqueryable]: subqueries + """ + self._ensure_collected() + return self._subqueries + + @property + def columns(self): + """ + List of columns in this scope. + + Returns: + list[exp.Column]: Column instances in this scope, plus any + Columns that reference this scope from correlated subqueries. + """ + if self._columns is None: + self._ensure_collected() + columns = self._raw_columns + + external_columns = [ + column + for scope in self.subquery_scopes + for column in scope.external_columns + ] + + named_outputs = {e.alias_or_name for e in self.expression.expressions} + + self._columns = [ + c + for c in columns + external_columns + if not ( + c.find_ancestor(exp.Qualify, exp.Order) and c.name in named_outputs + ) + ] + return self._columns + + @property + def selected_sources(self): + """ + Mapping of nodes and sources that are actually selected from in this scope. + + That is, all tables in a schema are selectable at any point. But a + table only becomes a selected source if it's included in a FROM or JOIN clause. 
+ + Returns: + dict[str, (exp.Table|exp.Subquery, exp.Table|Scope)]: selected sources and nodes + """ + if self._selected_sources is None: + referenced_names = [] + + for table in self.tables: + referenced_names.append( + ( + table.parent.alias + if isinstance(table.parent, exp.Alias) + else table.name, + table, + ) + ) + for derived_table in self.derived_tables: + referenced_names.append((derived_table.alias, derived_table.unnest())) + + result = {} + + for name, node in referenced_names: + if name in self.sources: + result[name] = (node, self.sources[name]) + + self._selected_sources = result + return self._selected_sources + + @property + def selects(self): + """ + Select expressions of this scope. + + For example, for the following expression: + SELECT 1 as a, 2 as b FROM x + + The outputs are the "1 as a" and "2 as b" expressions. + + Returns: + list[exp.Expression]: expressions + """ + if isinstance(self.expression, exp.Union): + return [] + return self.expression.selects + + @property + def external_columns(self): + """ + Columns that appear to reference sources in outer scopes. + + Returns: + list[exp.Column]: Column instances that don't reference + sources in the current scope. + """ + if self._external_columns is None: + self._external_columns = [ + c for c in self.columns if c.table not in self.selected_sources + ] + return self._external_columns + + def source_columns(self, source_name): + """ + Get all columns in the current scope for a particular source. + + Args: + source_name (str): Name of the source + Returns: + list[exp.Column]: Column instances that reference `source_name` + """ + return [column for column in self.columns if column.table == source_name] + + @property + def is_subquery(self): + """Determine if this scope is a subquery""" + return self.scope_type == ScopeType.SUBQUERY + + @property + def is_unnest(self): + """Determine if this scope is an unnest""" + return self.scope_type == ScopeType.UNNEST + + @property + def is_correlated_subquery(self): + """Determine if this scope is a correlated subquery""" + return bool(self.is_subquery and self.external_columns) + + def rename_source(self, old_name, new_name): + """Rename a source in this scope""" + columns = self.sources.pop(old_name or "", []) + self.sources[new_name] = columns + + +def traverse_scope(expression): + """ + Traverse an expression by it's "scopes". + + "Scope" represents the current context of a Select statement. + + This is helpful for optimizing queries, where we need more information than + the expression tree itself. For example, we might care about the source + names within a subquery. Returns a list because a generator could result in + incomplete properties which is confusing. 
+ + Examples: + >>> import sqlglot + >>> expression = sqlglot.parse_one("SELECT a FROM (SELECT a FROM x) AS y") + >>> scopes = traverse_scope(expression) + >>> scopes[0].expression.sql(), list(scopes[0].sources) + ('SELECT a FROM x', ['x']) + >>> scopes[1].expression.sql(), list(scopes[1].sources) + ('SELECT a FROM (SELECT a FROM x) AS y', ['y']) + + Args: + expression (exp.Expression): expression to traverse + Returns: + List[Scope]: scope instances + """ + return list(_traverse_scope(Scope(expression))) + + +def _traverse_scope(scope): + if isinstance(scope.expression, exp.Select): + yield from _traverse_select(scope) + elif isinstance(scope.expression, exp.Union): + yield from _traverse_union(scope) + elif isinstance(scope.expression, (exp.Lateral, exp.Unnest)): + pass + elif isinstance(scope.expression, exp.Subquery): + yield from _traverse_subqueries(scope) + else: + raise OptimizeError(f"Unexpected expression type: {type(scope.expression)}") + yield scope + + +def _traverse_select(scope): + yield from _traverse_derived_tables(scope.ctes, scope, ScopeType.CTE) + yield from _traverse_subqueries(scope) + yield from _traverse_derived_tables( + scope.derived_tables, scope, ScopeType.DERIVED_TABLE + ) + _add_table_sources(scope) + + +def _traverse_union(scope): + yield from _traverse_derived_tables(scope.ctes, scope, scope_type=ScopeType.CTE) + + # The last scope to be yield should be the top most scope + left = None + for left in _traverse_scope( + scope.branch(scope.expression.left, scope_type=ScopeType.UNION) + ): + yield left + + right = None + for right in _traverse_scope( + scope.branch(scope.expression.right, scope_type=ScopeType.UNION) + ): + yield right + + scope.union = (left, right) + + +def _traverse_derived_tables(derived_tables, scope, scope_type): + sources = {} + + for derived_table in derived_tables: + for child_scope in _traverse_scope( + scope.branch( + derived_table + if isinstance(derived_table, (exp.Unnest, exp.Lateral)) + else derived_table.this, + add_sources=sources if scope_type == ScopeType.CTE else None, + outer_column_list=derived_table.alias_column_names, + scope_type=ScopeType.UNNEST + if isinstance(derived_table, exp.Unnest) + else scope_type, + ) + ): + yield child_scope + # Tables without aliases will be set as "" + # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything. + # Until then, this means that only a single, unaliased derived table is allowed (rather, + # the latest one wins. + sources[derived_table.alias] = child_scope + scope.sources.update(sources) + + +def _add_table_sources(scope): + sources = {} + for table in scope.tables: + table_name = table.name + + if isinstance(table.parent, exp.Alias): + source_name = table.parent.alias + else: + source_name = table_name + + if table_name in scope.sources: + # This is a reference to a parent source (e.g. a CTE), not an actual table. 
+ scope.sources[source_name] = scope.sources[table_name] + elif source_name in scope.sources: + raise OptimizeError(f"Duplicate table name: {source_name}") + else: + sources[source_name] = table + + scope.sources.update(sources) + + +def _traverse_subqueries(scope): + for subquery in scope.subqueries: + top = None + for child_scope in _traverse_scope( + scope.branch(subquery, scope_type=ScopeType.SUBQUERY) + ): + yield child_scope + top = child_scope + scope.subquery_scopes.append(top) diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py new file mode 100644 index 0000000..6771153 --- /dev/null +++ b/sqlglot/optimizer/simplify.py @@ -0,0 +1,383 @@ +import datetime +import functools +import itertools +from collections import deque +from decimal import Decimal + +from sqlglot import exp +from sqlglot.expressions import FALSE, NULL, TRUE +from sqlglot.generator import Generator +from sqlglot.helper import while_changing + +GENERATOR = Generator(normalize=True, identify=True) + + +def simplify(expression): + """ + Rewrite sqlglot AST to simplify expressions. + + Example: + >>> import sqlglot + >>> expression = sqlglot.parse_one("TRUE AND TRUE") + >>> simplify(expression).sql() + 'TRUE' + + Args: + expression (sqlglot.Expression): expression to simplify + Returns: + sqlglot.Expression: simplified expression + """ + + def _simplify(expression, root=True): + node = expression + node = uniq_sort(node) + node = absorb_and_eliminate(node) + exp.replace_children(node, lambda e: _simplify(e, False)) + node = simplify_not(node) + node = flatten(node) + node = simplify_connectors(node) + node = remove_compliments(node) + node.parent = expression.parent + node = simplify_literals(node) + node = simplify_parens(node) + if root: + expression.replace(node) + return node + + expression = while_changing(expression, _simplify) + remove_where_true(expression) + return expression + + +def simplify_not(expression): + """ + Demorgan's Law + NOT (x OR y) -> NOT x AND NOT y + NOT (x AND y) -> NOT x OR NOT y + """ + if isinstance(expression, exp.Not): + if isinstance(expression.this, exp.Paren): + condition = expression.this.unnest() + if isinstance(condition, exp.And): + return exp.or_(exp.not_(condition.left), exp.not_(condition.right)) + if isinstance(condition, exp.Or): + return exp.and_(exp.not_(condition.left), exp.not_(condition.right)) + if always_true(expression.this): + return FALSE + if expression.this == FALSE: + return TRUE + if isinstance(expression.this, exp.Not): + # double negation + # NOT NOT x -> x + return expression.this.this + return expression + + +def flatten(expression): + """ + A AND (B AND C) -> A AND B AND C + A OR (B OR C) -> A OR B OR C + """ + if isinstance(expression, exp.Connector): + for node in expression.args.values(): + child = node.unnest() + if isinstance(child, expression.__class__): + node.replace(child) + return expression + + +def simplify_connectors(expression): + if isinstance(expression, exp.Connector): + left = expression.left + right = expression.right + + if left == right: + return left + + if isinstance(expression, exp.And): + if NULL in (left, right): + return NULL + if FALSE in (left, right): + return FALSE + if always_true(left) and always_true(right): + return TRUE + if always_true(left): + return right + if always_true(right): + return left + elif isinstance(expression, exp.Or): + if always_true(left) or always_true(right): + return TRUE + if left == FALSE and right == FALSE: + return FALSE + if ( + (left == NULL and right == NULL) + or 
(left == NULL and right == FALSE) + or (left == FALSE and right == NULL) + ): + return NULL + if left == FALSE: + return right + if right == FALSE: + return left + return expression + + +def remove_compliments(expression): + """ + Removing compliments. + + A AND NOT A -> FALSE + A OR NOT A -> TRUE + """ + if isinstance(expression, exp.Connector): + compliment = FALSE if isinstance(expression, exp.And) else TRUE + + for a, b in itertools.permutations(expression.flatten(), 2): + if is_complement(a, b): + return compliment + return expression + + +def uniq_sort(expression): + """ + Uniq and sort a connector. + + C AND A AND B AND B -> A AND B AND C + """ + if isinstance(expression, exp.Connector): + result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_ + flattened = tuple(expression.flatten()) + deduped = {GENERATOR.generate(e): e for e in flattened} + arr = tuple(deduped.items()) + + # check if the operands are already sorted, if not sort them + # A AND C AND B -> A AND B AND C + for i, (sql, e) in enumerate(arr[1:]): + if sql < arr[i][0]: + expression = result_func(*(deduped[sql] for sql in sorted(deduped))) + break + else: + # we didn't have to sort but maybe we need to dedup + if len(deduped) < len(flattened): + expression = result_func(*deduped.values()) + + return expression + + +def absorb_and_eliminate(expression): + """ + absorption: + A AND (A OR B) -> A + A OR (A AND B) -> A + A AND (NOT A OR B) -> A AND B + A OR (NOT A AND B) -> A OR B + elimination: + (A AND B) OR (A AND NOT B) -> A + (A OR B) AND (A OR NOT B) -> A + """ + if isinstance(expression, exp.Connector): + kind = exp.Or if isinstance(expression, exp.And) else exp.And + + for a, b in itertools.permutations(expression.flatten(), 2): + if isinstance(a, kind): + aa, ab = a.unnest_operands() + + # absorb + if is_complement(b, aa): + aa.replace(exp.TRUE if kind == exp.And else exp.FALSE) + elif is_complement(b, ab): + ab.replace(exp.TRUE if kind == exp.And else exp.FALSE) + elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set( + a.flatten() + ): + a.replace(exp.FALSE if kind == exp.And else exp.TRUE) + elif isinstance(b, kind): + # eliminate + rhs = b.unnest_operands() + ba, bb = rhs + + if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)): + a.replace(aa) + b.replace(aa) + elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)): + a.replace(ab) + b.replace(ab) + + return expression + + +def simplify_literals(expression): + if isinstance(expression, exp.Binary): + operands = [] + queue = deque(expression.flatten(unnest=False)) + size = len(queue) + + while queue: + a = queue.popleft() + + for b in queue: + result = _simplify_binary(expression, a, b) + + if result: + queue.remove(b) + queue.append(result) + break + else: + operands.append(a) + + if len(operands) < size: + return functools.reduce( + lambda a, b: expression.__class__(this=a, expression=b), operands + ) + elif isinstance(expression, exp.Neg): + this = expression.this + if this.is_number: + value = this.name + if value[0] == "-": + return exp.Literal.number(value[1:]) + return exp.Literal.number(f"-{value}") + + return expression + + +def _simplify_binary(expression, a, b): + if isinstance(expression, exp.Is): + if isinstance(b, exp.Not): + c = b.this + not_ = True + else: + c = b + not_ = False + + if c == NULL: + if isinstance(a, exp.Literal): + return TRUE if not_ else FALSE + if a == NULL: + return FALSE if not_ else TRUE + elif NULL in (a, b): + return NULL + + if isinstance(expression, exp.EQ) and a == b: + 
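+        # The NULL literal cases were handled above; syntactically identical operands fold to TRUE.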
return TRUE + + if a.is_number and b.is_number: + a = int(a.name) if a.is_int else Decimal(a.name) + b = int(b.name) if b.is_int else Decimal(b.name) + + if isinstance(expression, exp.Add): + return exp.Literal.number(a + b) + if isinstance(expression, exp.Sub): + return exp.Literal.number(a - b) + if isinstance(expression, exp.Mul): + return exp.Literal.number(a * b) + if isinstance(expression, exp.Div): + if isinstance(a, int) and isinstance(b, int): + return exp.Literal.number(a // b) + return exp.Literal.number(a / b) + + boolean = eval_boolean(expression, a, b) + + if boolean: + return boolean + elif a.is_string and b.is_string: + boolean = eval_boolean(expression, a, b) + + if boolean: + return boolean + elif isinstance(a, exp.Cast) and isinstance(b, exp.Interval): + a, b = extract_date(a), extract_interval(b) + if b: + if isinstance(expression, exp.Add): + return date_literal(a + b) + if isinstance(expression, exp.Sub): + return date_literal(a - b) + elif isinstance(a, exp.Interval) and isinstance(b, exp.Cast): + a, b = extract_interval(a), extract_date(b) + # you cannot subtract a date from an interval + if a and isinstance(expression, exp.Add): + return date_literal(a + b) + + return None + + +def simplify_parens(expression): + if ( + isinstance(expression, exp.Paren) + and not isinstance(expression.this, exp.Select) + and ( + not isinstance(expression.parent, (exp.Condition, exp.Binary)) + or isinstance(expression.this, (exp.Is, exp.Like)) + or not isinstance(expression.this, exp.Binary) + ) + ): + return expression.this + return expression + + +def remove_where_true(expression): + for where in expression.find_all(exp.Where): + if always_true(where.this): + where.parent.set("where", None) + for join in expression.find_all(exp.Join): + if always_true(join.args.get("on")): + join.set("kind", "CROSS") + join.set("on", None) + + +def always_true(expression): + return expression == TRUE or isinstance(expression, exp.Literal) + + +def is_complement(a, b): + return isinstance(b, exp.Not) and b.this == a + + +def eval_boolean(expression, a, b): + if isinstance(expression, (exp.EQ, exp.Is)): + return boolean_literal(a == b) + if isinstance(expression, exp.NEQ): + return boolean_literal(a != b) + if isinstance(expression, exp.GT): + return boolean_literal(a > b) + if isinstance(expression, exp.GTE): + return boolean_literal(a >= b) + if isinstance(expression, exp.LT): + return boolean_literal(a < b) + if isinstance(expression, exp.LTE): + return boolean_literal(a <= b) + return None + + +def extract_date(cast): + if cast.args["to"].this == exp.DataType.Type.DATE: + return datetime.date.fromisoformat(cast.name) + return None + + +def extract_interval(interval): + try: + from dateutil.relativedelta import relativedelta + except ModuleNotFoundError: + return None + + n = int(interval.name) + unit = interval.text("unit").lower() + + if unit == "year": + return relativedelta(years=n) + if unit == "month": + return relativedelta(months=n) + if unit == "week": + return relativedelta(weeks=n) + if unit == "day": + return relativedelta(days=n) + return None + + +def date_literal(date): + return exp.Cast(this=exp.Literal.string(date), to=exp.DataType.build("DATE")) + + +def boolean_literal(condition): + return TRUE if condition else FALSE diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py new file mode 100644 index 0000000..55c81c5 --- /dev/null +++ b/sqlglot/optimizer/unnest_subqueries.py @@ -0,0 +1,220 @@ +import itertools + +from sqlglot import exp +from 
sqlglot.optimizer.scope import traverse_scope + + +def unnest_subqueries(expression): + """ + Rewrite sqlglot AST to convert some predicates with subqueries into joins. + + Convert the subquery into a group by so it is not a many to many left join. + Unnesting can only occur if the subquery does not have LIMIT or OFFSET. + Unnesting non correlated subqueries only happens on IN statements or = ANY statements. + + Example: + >>> import sqlglot + >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ") + >>> unnest_subqueries(expression).sql() + 'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a)\ + AS "_u_0" ON x.a = "_u_0".a WHERE ("_u_0".a = 1 AND NOT "_u_0".a IS NULL)' + + Args: + expression (sqlglot.Expression): expression to unnest + Returns: + sqlglot.Expression: unnested expression + """ + sequence = itertools.count() + + for scope in traverse_scope(expression): + select = scope.expression + parent = select.parent_select + if scope.external_columns: + decorrelate(select, parent, scope.external_columns, sequence) + else: + unnest(select, parent, sequence) + + return expression + + +def unnest(select, parent_select, sequence): + predicate = select.find_ancestor(exp.In, exp.Any) + + if not predicate or parent_select is not predicate.parent_select: + return + + if len(select.selects) > 1 or select.find(exp.Limit, exp.Offset): + return + + if isinstance(predicate, exp.Any): + predicate = predicate.find_ancestor(exp.EQ) + + if not predicate or parent_select is not predicate.parent_select: + return + + column = _other_operand(predicate) + value = select.selects[0] + alias = _alias(sequence) + + on = exp.condition(f'{column} = "{alias}"."{value.alias}"') + _replace(predicate, f"NOT {on.right} IS NULL") + + parent_select.join( + select.group_by(value.this, copy=False), + on=on, + join_type="LEFT", + join_alias=alias, + copy=False, + ) + + +def decorrelate(select, parent_select, external_columns, sequence): + where = select.args.get("where") + + if not where or where.find(exp.Or) or select.find(exp.Limit, exp.Offset): + return + + table_alias = _alias(sequence) + keys = [] + + # for all external columns in the where statement, + # split out the relevant data to convert it into a join + for column in external_columns: + if column.find_ancestor(exp.Where) is not where: + return + + predicate = column.find_ancestor(exp.Predicate) + + if not predicate or predicate.find_ancestor(exp.Where) is not where: + return + + if isinstance(predicate, exp.Binary): + key = ( + predicate.right + if any(node is column for node, *_ in predicate.left.walk()) + else predicate.left + ) + else: + return + + keys.append((key, column, predicate)) + + if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys): + return + + value = select.selects[0] + key_aliases = {} + group_by = [] + + for key, _, predicate in keys: + # if we filter on the value of the subquery, it needs to be unique + if key == value.this: + key_aliases[key] = value.alias + group_by.append(key) + else: + if key not in key_aliases: + key_aliases[key] = _alias(sequence) + # all predicates that are equalities must also be in the unique + # so that we don't do a many to many join + if isinstance(predicate, exp.EQ) and key not in group_by: + group_by.append(key) + + parent_predicate = select.find_ancestor(exp.Predicate) + + # if the value of the subquery is not an agg or a key, we need to collect it into an array + # so that it can be grouped + if not 
value.find(exp.AggFunc) and value.this not in group_by: + select.select( + f"ARRAY_AGG({value.this}) AS {value.alias}", append=False, copy=False + ) + + # exists queries should not have any selects as it only checks if there are any rows + # all selects will be added by the optimizer and only used for join keys + if isinstance(parent_predicate, exp.Exists): + select.args["expressions"] = [] + + for key, alias in key_aliases.items(): + if key in group_by: + # add all keys to the projections of the subquery + # so that we can use it as a join key + if isinstance(parent_predicate, exp.Exists) or key != value.this: + select.select(f"{key} AS {alias}", copy=False) + else: + select.select(f"ARRAY_AGG({key}) AS {alias}", copy=False) + + alias = exp.column(value.alias, table_alias) + other = _other_operand(parent_predicate) + + if isinstance(parent_predicate, exp.Exists): + if value.this in group_by: + parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL") + else: + parent_predicate = _replace(parent_predicate, "TRUE") + elif isinstance(parent_predicate, exp.All): + parent_predicate = _replace( + parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})" + ) + elif isinstance(parent_predicate, exp.Any): + if value.this in group_by: + parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}") + else: + parent_predicate = _replace( + parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})" + ) + elif isinstance(parent_predicate, exp.In): + if value.this in group_by: + parent_predicate = _replace(parent_predicate, f"{other} = {alias}") + else: + parent_predicate = _replace( + parent_predicate, + f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})", + ) + else: + select.parent.replace(alias) + + for key, column, predicate in keys: + predicate.replace(exp.TRUE) + nested = exp.column(key_aliases[key], table_alias) + + if key in group_by: + key.replace(nested) + parent_predicate = _replace( + parent_predicate, f"({parent_predicate} AND NOT {nested} IS NULL)" + ) + elif isinstance(predicate, exp.EQ): + parent_predicate = _replace( + parent_predicate, + f"({parent_predicate} AND ARRAY_CONTAINS({nested}, {column}))", + ) + else: + key.replace(exp.to_identifier("_x")) + parent_predicate = _replace( + parent_predicate, + f'({parent_predicate} AND ARRAY_ANY({nested}, "_x" -> {predicate}))', + ) + + parent_select.join( + select.group_by(*group_by, copy=False), + on=[predicate for *_, predicate in keys if isinstance(predicate, exp.EQ)], + join_type="LEFT", + join_alias=table_alias, + copy=False, + ) + + +def _alias(sequence): + return f"_u_{next(sequence)}" + + +def _replace(expression, condition): + return expression.replace(exp.condition(condition)) + + +def _other_operand(expression): + if isinstance(expression, exp.In): + return expression.this + + if isinstance(expression, exp.Binary): + return expression.right if expression.arg_key == "this" else expression.left + + return None diff --git a/sqlglot/parser.py b/sqlglot/parser.py new file mode 100644 index 0000000..9396c50 --- /dev/null +++ b/sqlglot/parser.py @@ -0,0 +1,2190 @@ +import logging + +from sqlglot import exp +from sqlglot.errors import ErrorLevel, ParseError, concat_errors +from sqlglot.helper import apply_index_offset, ensure_list, list_get +from sqlglot.tokens import Token, Tokenizer, TokenType + +logger = logging.getLogger("sqlglot") + + +class Parser: + """ + Parser consumes a list of tokens produced by the :class:`~sqlglot.tokens.Tokenizer` + and produces a parsed syntax tree. 
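+
+    Example (typically driven through the top-level sqlglot.parse_one helper):
+        >>> import sqlglot
+        >>> sqlglot.parse_one("SELECT a FROM tbl WHERE a > 0").sql()
+        'SELECT a FROM tbl WHERE a > 0'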
+ + Args + error_level (ErrorLevel): the desired error level. Default: ErrorLevel.RAISE. + error_message_context (int): determines the amount of context to capture from + a query string when displaying the error message (in number of characters). + Default: 50. + index_offset (int): Index offset for arrays eg ARRAY[0] vs ARRAY[1] as the head of a list + Default: 0 + alias_post_tablesample (bool): If the table alias comes after tablesample + Default: False + max_errors (int): Maximum number of error messages to include in a raised ParseError. + This is only relevant if error_level is ErrorLevel.RAISE. + Default: 3 + null_ordering (str): Indicates the default null ordering method to use if not explicitly set. + Options are "nulls_are_small", "nulls_are_large", "nulls_are_last". + Default: "nulls_are_small" + """ + + FUNCTIONS = { + **{name: f.from_arg_list for f in exp.ALL_FUNCTIONS for name in f.sql_names()}, + "DATE_TO_DATE_STR": lambda args: exp.Cast( + this=list_get(args, 0), + to=exp.DataType(this=exp.DataType.Type.TEXT), + ), + "TIME_TO_TIME_STR": lambda args: exp.Cast( + this=list_get(args, 0), + to=exp.DataType(this=exp.DataType.Type.TEXT), + ), + "TS_OR_DS_TO_DATE_STR": lambda args: exp.Substring( + this=exp.Cast( + this=list_get(args, 0), + to=exp.DataType(this=exp.DataType.Type.TEXT), + ), + start=exp.Literal.number(1), + length=exp.Literal.number(10), + ), + } + + NO_PAREN_FUNCTIONS = { + TokenType.CURRENT_DATE: exp.CurrentDate, + TokenType.CURRENT_DATETIME: exp.CurrentDate, + TokenType.CURRENT_TIMESTAMP: exp.CurrentTimestamp, + } + + NESTED_TYPE_TOKENS = { + TokenType.ARRAY, + TokenType.MAP, + TokenType.STRUCT, + TokenType.NULLABLE, + } + + TYPE_TOKENS = { + TokenType.BOOLEAN, + TokenType.TINYINT, + TokenType.SMALLINT, + TokenType.INT, + TokenType.BIGINT, + TokenType.FLOAT, + TokenType.DOUBLE, + TokenType.CHAR, + TokenType.NCHAR, + TokenType.VARCHAR, + TokenType.NVARCHAR, + TokenType.TEXT, + TokenType.BINARY, + TokenType.JSON, + TokenType.TIMESTAMP, + TokenType.TIMESTAMPTZ, + TokenType.DATETIME, + TokenType.DATE, + TokenType.DECIMAL, + TokenType.UUID, + TokenType.GEOGRAPHY, + *NESTED_TYPE_TOKENS, + } + + SUBQUERY_PREDICATES = { + TokenType.ANY: exp.Any, + TokenType.ALL: exp.All, + TokenType.EXISTS: exp.Exists, + TokenType.SOME: exp.Any, + } + + RESERVED_KEYWORDS = {*Tokenizer.SINGLE_TOKENS.values(), TokenType.SELECT} + + ID_VAR_TOKENS = { + TokenType.VAR, + TokenType.ALTER, + TokenType.BEGIN, + TokenType.BUCKET, + TokenType.CACHE, + TokenType.COLLATE, + TokenType.COMMIT, + TokenType.CONSTRAINT, + TokenType.CONVERT, + TokenType.DEFAULT, + TokenType.DELETE, + TokenType.ENGINE, + TokenType.ESCAPE, + TokenType.EXPLAIN, + TokenType.FALSE, + TokenType.FIRST, + TokenType.FOLLOWING, + TokenType.FORMAT, + TokenType.FUNCTION, + TokenType.IF, + TokenType.INDEX, + TokenType.ISNULL, + TokenType.INTERVAL, + TokenType.LAZY, + TokenType.LOCATION, + TokenType.NEXT, + TokenType.ONLY, + TokenType.OPTIMIZE, + TokenType.OPTIONS, + TokenType.ORDINALITY, + TokenType.PERCENT, + TokenType.PRECEDING, + TokenType.RANGE, + TokenType.REFERENCES, + TokenType.ROWS, + TokenType.SCHEMA_COMMENT, + TokenType.SET, + TokenType.SHOW, + TokenType.STORED, + TokenType.TABLE, + TokenType.TABLE_FORMAT, + TokenType.TEMPORARY, + TokenType.TOP, + TokenType.TRUNCATE, + TokenType.TRUE, + TokenType.UNBOUNDED, + TokenType.UNIQUE, + TokenType.PROPERTIES, + *SUBQUERY_PREDICATES, + *TYPE_TOKENS, + } + + CASTS = { + TokenType.CAST, + TokenType.TRY_CAST, + } + + FUNC_TOKENS = { + TokenType.CONVERT, + TokenType.CURRENT_DATE, + 
TokenType.CURRENT_DATETIME, + TokenType.CURRENT_TIMESTAMP, + TokenType.CURRENT_TIME, + TokenType.EXTRACT, + TokenType.FILTER, + TokenType.FIRST, + TokenType.FORMAT, + TokenType.ISNULL, + TokenType.OFFSET, + TokenType.PRIMARY_KEY, + TokenType.REPLACE, + TokenType.ROW, + TokenType.UNNEST, + TokenType.VAR, + TokenType.LEFT, + TokenType.RIGHT, + TokenType.DATE, + TokenType.DATETIME, + TokenType.TIMESTAMP, + TokenType.TIMESTAMPTZ, + *CASTS, + *NESTED_TYPE_TOKENS, + *SUBQUERY_PREDICATES, + } + + CONJUNCTION = { + TokenType.AND: exp.And, + TokenType.OR: exp.Or, + } + + EQUALITY = { + TokenType.EQ: exp.EQ, + TokenType.NEQ: exp.NEQ, + } + + COMPARISON = { + TokenType.GT: exp.GT, + TokenType.GTE: exp.GTE, + TokenType.LT: exp.LT, + TokenType.LTE: exp.LTE, + } + + BITWISE = { + TokenType.AMP: exp.BitwiseAnd, + TokenType.CARET: exp.BitwiseXor, + TokenType.PIPE: exp.BitwiseOr, + TokenType.DPIPE: exp.DPipe, + } + + TERM = { + TokenType.DASH: exp.Sub, + TokenType.PLUS: exp.Add, + TokenType.MOD: exp.Mod, + } + + FACTOR = { + TokenType.DIV: exp.IntDiv, + TokenType.SLASH: exp.Div, + TokenType.STAR: exp.Mul, + } + + TIMESTAMPS = { + TokenType.TIMESTAMP, + TokenType.TIMESTAMPTZ, + } + + SET_OPERATIONS = { + TokenType.UNION, + TokenType.INTERSECT, + TokenType.EXCEPT, + } + + JOIN_SIDES = { + TokenType.LEFT, + TokenType.RIGHT, + TokenType.FULL, + } + + JOIN_KINDS = { + TokenType.INNER, + TokenType.OUTER, + TokenType.CROSS, + } + + COLUMN_OPERATORS = { + TokenType.DOT: None, + TokenType.ARROW: lambda self, this, path: self.expression( + exp.JSONExtract, + this=this, + path=path, + ), + TokenType.DARROW: lambda self, this, path: self.expression( + exp.JSONExtractScalar, + this=this, + path=path, + ), + TokenType.HASH_ARROW: lambda self, this, path: self.expression( + exp.JSONBExtract, + this=this, + path=path, + ), + TokenType.DHASH_ARROW: lambda self, this, path: self.expression( + exp.JSONBExtractScalar, + this=this, + path=path, + ), + } + + EXPRESSION_PARSERS = { + exp.DataType: lambda self: self._parse_types(), + exp.From: lambda self: self._parse_from(), + exp.Group: lambda self: self._parse_group(), + exp.Lateral: lambda self: self._parse_lateral(), + exp.Join: lambda self: self._parse_join(), + exp.Order: lambda self: self._parse_order(), + exp.Cluster: lambda self: self._parse_sort(TokenType.CLUSTER_BY, exp.Cluster), + exp.Sort: lambda self: self._parse_sort(TokenType.SORT_BY, exp.Sort), + exp.Lambda: lambda self: self._parse_lambda(), + exp.Limit: lambda self: self._parse_limit(), + exp.Offset: lambda self: self._parse_offset(), + exp.TableAlias: lambda self: self._parse_table_alias(), + exp.Table: lambda self: self._parse_table(), + exp.Condition: lambda self: self._parse_conjunction(), + exp.Expression: lambda self: self._parse_statement(), + exp.Properties: lambda self: self._parse_properties(), + "JOIN_TYPE": lambda self: self._parse_join_side_and_kind(), + } + + STATEMENT_PARSERS = { + TokenType.CREATE: lambda self: self._parse_create(), + TokenType.DROP: lambda self: self._parse_drop(), + TokenType.INSERT: lambda self: self._parse_insert(), + TokenType.UPDATE: lambda self: self._parse_update(), + TokenType.DELETE: lambda self: self._parse_delete(), + TokenType.CACHE: lambda self: self._parse_cache(), + TokenType.UNCACHE: lambda self: self._parse_uncache(), + } + + PRIMARY_PARSERS = { + TokenType.STRING: lambda _, token: exp.Literal.string(token.text), + TokenType.NUMBER: lambda _, token: exp.Literal.number(token.text), + TokenType.STAR: lambda self, _: exp.Star( + **{"except": self._parse_except(), 
"replace": self._parse_replace()} + ), + TokenType.NULL: lambda *_: exp.Null(), + TokenType.TRUE: lambda *_: exp.Boolean(this=True), + TokenType.FALSE: lambda *_: exp.Boolean(this=False), + TokenType.PLACEHOLDER: lambda *_: exp.Placeholder(), + TokenType.BIT_STRING: lambda _, token: exp.BitString(this=token.text), + TokenType.INTRODUCER: lambda self, token: self.expression( + exp.Introducer, + this=token.text, + expression=self._parse_var_or_string(), + ), + } + + RANGE_PARSERS = { + TokenType.BETWEEN: lambda self, this: self._parse_between(this), + TokenType.IN: lambda self, this: self._parse_in(this), + TokenType.IS: lambda self, this: self._parse_is(this), + TokenType.LIKE: lambda self, this: self._parse_escape( + self.expression(exp.Like, this=this, expression=self._parse_type()) + ), + TokenType.ILIKE: lambda self, this: self._parse_escape( + self.expression(exp.ILike, this=this, expression=self._parse_type()) + ), + TokenType.RLIKE: lambda self, this: self.expression( + exp.RegexpLike, this=this, expression=self._parse_type() + ), + } + + PROPERTY_PARSERS = { + TokenType.AUTO_INCREMENT: lambda self: self._parse_auto_increment(), + TokenType.CHARACTER_SET: lambda self: self._parse_character_set(), + TokenType.COLLATE: lambda self: self._parse_collate(), + TokenType.ENGINE: lambda self: self._parse_engine(), + TokenType.FORMAT: lambda self: self._parse_format(), + TokenType.LOCATION: lambda self: self.expression( + exp.LocationProperty, + this=exp.Literal.string("LOCATION"), + value=self._parse_string(), + ), + TokenType.PARTITIONED_BY: lambda self: self.expression( + exp.PartitionedByProperty, + this=exp.Literal.string("PARTITIONED_BY"), + value=self._parse_schema(), + ), + TokenType.SCHEMA_COMMENT: lambda self: self._parse_schema_comment(), + TokenType.STORED: lambda self: self._parse_stored(), + TokenType.TABLE_FORMAT: lambda self: self._parse_table_format(), + TokenType.USING: lambda self: self._parse_table_format(), + } + + CONSTRAINT_PARSERS = { + TokenType.CHECK: lambda self: self._parse_check(), + TokenType.FOREIGN_KEY: lambda self: self._parse_foreign_key(), + TokenType.UNIQUE: lambda self: self._parse_unique(), + } + + NO_PAREN_FUNCTION_PARSERS = { + TokenType.CASE: lambda self: self._parse_case(), + TokenType.IF: lambda self: self._parse_if(), + } + + FUNCTION_PARSERS = { + TokenType.CONVERT: lambda self, _: self._parse_convert(), + TokenType.EXTRACT: lambda self, _: self._parse_extract(), + **{ + token_type: lambda self, token_type: self._parse_cast( + self.STRICT_CAST and token_type == TokenType.CAST + ) + for token_type in CASTS + }, + } + + QUERY_MODIFIER_PARSERS = { + "laterals": lambda self: self._parse_laterals(), + "joins": lambda self: self._parse_joins(), + "where": lambda self: self._parse_where(), + "group": lambda self: self._parse_group(), + "having": lambda self: self._parse_having(), + "qualify": lambda self: self._parse_qualify(), + "window": lambda self: self._match(TokenType.WINDOW) + and self._parse_window(self._parse_id_var(), alias=True), + "distribute": lambda self: self._parse_sort( + TokenType.DISTRIBUTE_BY, exp.Distribute + ), + "sort": lambda self: self._parse_sort(TokenType.SORT_BY, exp.Sort), + "cluster": lambda self: self._parse_sort(TokenType.CLUSTER_BY, exp.Cluster), + "order": lambda self: self._parse_order(), + "limit": lambda self: self._parse_limit(), + "offset": lambda self: self._parse_offset(), + } + + CREATABLES = {TokenType.TABLE, TokenType.VIEW, TokenType.FUNCTION, TokenType.INDEX} + + STRICT_CAST = True + + __slots__ = ( + 
"error_level", + "error_message_context", + "sql", + "errors", + "index_offset", + "unnest_column_only", + "alias_post_tablesample", + "max_errors", + "null_ordering", + "_tokens", + "_chunks", + "_index", + "_curr", + "_next", + "_prev", + "_greedy_subqueries", + ) + + def __init__( + self, + error_level=None, + error_message_context=100, + index_offset=0, + unnest_column_only=False, + alias_post_tablesample=False, + max_errors=3, + null_ordering=None, + ): + self.error_level = error_level or ErrorLevel.RAISE + self.error_message_context = error_message_context + self.index_offset = index_offset + self.unnest_column_only = unnest_column_only + self.alias_post_tablesample = alias_post_tablesample + self.max_errors = max_errors + self.null_ordering = null_ordering + self.reset() + + def reset(self): + self.sql = "" + self.errors = [] + self._tokens = [] + self._chunks = [[]] + self._index = 0 + self._curr = None + self._next = None + self._prev = None + self._greedy_subqueries = False + + def parse(self, raw_tokens, sql=None): + """ + Parses the given list of tokens and returns a list of syntax trees, one tree + per parsed SQL statement. + + Args + raw_tokens (list): the list of tokens (:class:`~sqlglot.tokens.Token`). + sql (str): the original SQL string. Used to produce helpful debug messages. + + Returns + the list of syntax trees (:class:`~sqlglot.expressions.Expression`). + """ + return self._parse( + parse_method=self.__class__._parse_statement, raw_tokens=raw_tokens, sql=sql + ) + + def parse_into(self, expression_types, raw_tokens, sql=None): + for expression_type in ensure_list(expression_types): + parser = self.EXPRESSION_PARSERS.get(expression_type) + if not parser: + raise TypeError(f"No parser registered for {expression_type}") + try: + return self._parse(parser, raw_tokens, sql) + except ParseError as e: + error = e + raise ParseError(f"Failed to parse into {expression_types}") from error + + def _parse(self, parse_method, raw_tokens, sql=None): + self.reset() + self.sql = sql or "" + total = len(raw_tokens) + + for i, token in enumerate(raw_tokens): + if token.token_type == TokenType.SEMICOLON: + if i < total - 1: + self._chunks.append([]) + else: + self._chunks[-1].append(token) + + expressions = [] + + for tokens in self._chunks: + self._index = -1 + self._tokens = tokens + self._advance() + expressions.append(parse_method(self)) + + if self._index < len(self._tokens): + self.raise_error("Invalid expression / Unexpected token") + + self.check_errors() + + return expressions + + def check_errors(self): + if self.error_level == ErrorLevel.WARN: + for error in self.errors: + logger.error(str(error)) + elif self.error_level == ErrorLevel.RAISE and self.errors: + raise ParseError(concat_errors(self.errors, self.max_errors)) + + def raise_error(self, message, token=None): + token = token or self._curr or self._prev or Token.string("") + start = self._find_token(token, self.sql) + end = start + len(token.text) + start_context = self.sql[max(start - self.error_message_context, 0) : start] + highlight = self.sql[start:end] + end_context = self.sql[end : end + self.error_message_context] + error = ParseError( + f"{message}. 
Line {token.line}, Col: {token.col}.\n" + f" {start_context}\033[4m{highlight}\033[0m{end_context}" + ) + if self.error_level == ErrorLevel.IMMEDIATE: + raise error + self.errors.append(error) + + def expression(self, exp_class, **kwargs): + instance = exp_class(**kwargs) + self.validate_expression(instance) + return instance + + def validate_expression(self, expression, args=None): + if self.error_level == ErrorLevel.IGNORE: + return + + for k in expression.args: + if k not in expression.arg_types: + self.raise_error( + f"Unexpected keyword: '{k}' for {expression.__class__}" + ) + for k, mandatory in expression.arg_types.items(): + v = expression.args.get(k) + if mandatory and (v is None or (isinstance(v, list) and not v)): + self.raise_error( + f"Required keyword: '{k}' missing for {expression.__class__}" + ) + + if ( + args + and len(args) > len(expression.arg_types) + and not expression.is_var_len_args + ): + self.raise_error( + f"The number of provided arguments ({len(args)}) is greater than " + f"the maximum number of supported arguments ({len(expression.arg_types)})" + ) + + def _find_token(self, token, sql): + line = 1 + col = 1 + index = 0 + + while line < token.line or col < token.col: + if Tokenizer.WHITE_SPACE.get(sql[index]) == TokenType.BREAK: + line += 1 + col = 1 + else: + col += 1 + index += 1 + + return index + + def _get_token(self, index): + return list_get(self._tokens, index) + + def _advance(self, times=1): + self._index += times + self._curr = self._get_token(self._index) + self._next = self._get_token(self._index + 1) + self._prev = self._get_token(self._index - 1) if self._index > 0 else None + + def _retreat(self, index): + self._advance(index - self._index) + + def _parse_statement(self): + if self._curr is None: + return None + + if self._match_set(self.STATEMENT_PARSERS): + return self.STATEMENT_PARSERS[self._prev.token_type](self) + + if self._match_set(Tokenizer.COMMANDS): + return self.expression( + exp.Command, + this=self._prev.text, + expression=self._parse_string(), + ) + + expression = self._parse_expression() + expression = ( + self._parse_set_operations(expression) + if expression + else self._parse_select() + ) + self._parse_query_modifiers(expression) + return expression + + def _parse_drop(self): + if self._match(TokenType.TABLE): + kind = "TABLE" + elif self._match(TokenType.VIEW): + kind = "VIEW" + else: + self.raise_error("Expected TABLE or View") + + return self.expression( + exp.Drop, + exists=self._parse_exists(), + this=self._parse_table(schema=True), + kind=kind, + ) + + def _parse_exists(self, not_=False): + return ( + self._match(TokenType.IF) + and (not not_ or self._match(TokenType.NOT)) + and self._match(TokenType.EXISTS) + ) + + def _parse_create(self): + replace = self._match(TokenType.OR) and self._match(TokenType.REPLACE) + temporary = self._match(TokenType.TEMPORARY) + unique = self._match(TokenType.UNIQUE) + + create_token = self._match_set(self.CREATABLES) and self._prev + + if not create_token: + self.raise_error("Expected TABLE, VIEW, INDEX, or FUNCTION") + + exists = self._parse_exists(not_=True) + this = None + expression = None + properties = None + + if create_token.token_type == TokenType.FUNCTION: + this = self._parse_var() + if self._match(TokenType.ALIAS): + expression = self._parse_string() + elif create_token.token_type == TokenType.INDEX: + this = self._parse_index() + elif create_token.token_type in (TokenType.TABLE, TokenType.VIEW): + this = self._parse_table(schema=True) + properties = self._parse_properties( + 
this if isinstance(this, exp.Schema) else None + ) + if self._match(TokenType.ALIAS): + expression = self._parse_select() + + return self.expression( + exp.Create, + this=this, + kind=create_token.text, + expression=expression, + exists=exists, + properties=properties, + temporary=temporary, + replace=replace, + unique=unique, + ) + + def _parse_property(self, schema): + if self._match_set(self.PROPERTY_PARSERS): + return self.PROPERTY_PARSERS[self._prev.token_type](self) + if self._match_pair(TokenType.DEFAULT, TokenType.CHARACTER_SET): + return self._parse_character_set(True) + + if self._match_pair(TokenType.VAR, TokenType.EQ, advance=False): + key = self._parse_var().this + self._match(TokenType.EQ) + + if key.upper() == "PARTITIONED_BY": + expression = exp.PartitionedByProperty + value = self._parse_schema() or self._parse_bracket(self._parse_field()) + + if schema and not isinstance(value, exp.Schema): + columns = {v.name.upper() for v in value.expressions} + partitions = [ + expression + for expression in schema.expressions + if expression.this.name.upper() in columns + ] + schema.set( + "expressions", + [e for e in schema.expressions if e not in partitions], + ) + value = self.expression(exp.Schema, expressions=partitions) + else: + value = self._parse_column() + expression = exp.AnonymousProperty + + return self.expression( + expression, + this=exp.Literal.string(key), + value=value, + ) + return None + + def _parse_stored(self): + self._match(TokenType.ALIAS) + self._match(TokenType.EQ) + return self.expression( + exp.FileFormatProperty, + this=exp.Literal.string("FORMAT"), + value=exp.Literal.string(self._parse_var().name), + ) + + def _parse_format(self): + self._match(TokenType.EQ) + return self.expression( + exp.FileFormatProperty, + this=exp.Literal.string("FORMAT"), + value=self._parse_string() or self._parse_var(), + ) + + def _parse_engine(self): + self._match(TokenType.EQ) + return self.expression( + exp.EngineProperty, + this=exp.Literal.string("ENGINE"), + value=self._parse_var_or_string(), + ) + + def _parse_auto_increment(self): + self._match(TokenType.EQ) + return self.expression( + exp.AutoIncrementProperty, + this=exp.Literal.string("AUTO_INCREMENT"), + value=self._parse_var() or self._parse_number(), + ) + + def _parse_collate(self): + self._match(TokenType.EQ) + return self.expression( + exp.CollateProperty, + this=exp.Literal.string("COLLATE"), + value=self._parse_var_or_string(), + ) + + def _parse_schema_comment(self): + self._match(TokenType.EQ) + return self.expression( + exp.SchemaCommentProperty, + this=exp.Literal.string("COMMENT"), + value=self._parse_string(), + ) + + def _parse_character_set(self, default=False): + self._match(TokenType.EQ) + return self.expression( + exp.CharacterSetProperty, + this=exp.Literal.string("CHARACTER_SET"), + value=self._parse_var_or_string(), + default=default, + ) + + def _parse_table_format(self): + self._match(TokenType.EQ) + return self.expression( + exp.TableFormatProperty, + this=exp.Literal.string("TABLE_FORMAT"), + value=self._parse_var_or_string(), + ) + + def _parse_properties(self, schema=None): + """ + Schema is included since if the table schema is defined and we later get a partition by expression + then we will define those columns in the partition by section and not in with the rest of the + columns + """ + properties = [] + + while True: + if self._match(TokenType.WITH): + self._match_l_paren() + properties.extend(self._parse_csv(lambda: self._parse_property(schema))) + self._match_r_paren() + elif 
self._match(TokenType.PROPERTIES): + self._match_l_paren() + properties.extend( + self._parse_csv( + lambda: self.expression( + exp.AnonymousProperty, + this=self._parse_string(), + value=self._match(TokenType.EQ) and self._parse_string(), + ) + ) + ) + self._match_r_paren() + else: + identified_property = self._parse_property(schema) + if not identified_property: + break + properties.append(identified_property) + if properties: + return self.expression(exp.Properties, expressions=properties) + return None + + def _parse_insert(self): + overwrite = self._match(TokenType.OVERWRITE) + self._match(TokenType.INTO) + self._match(TokenType.TABLE) + return self.expression( + exp.Insert, + this=self._parse_table(schema=True), + exists=self._parse_exists(), + partition=self._parse_partition(), + expression=self._parse_select(), + overwrite=overwrite, + ) + + def _parse_delete(self): + self._match(TokenType.FROM) + + return self.expression( + exp.Delete, + this=self._parse_table(schema=True), + where=self._parse_where(), + ) + + def _parse_update(self): + return self.expression( + exp.Update, + **{ + "this": self._parse_table(schema=True), + "expressions": self._match(TokenType.SET) + and self._parse_csv(self._parse_equality), + "from": self._parse_from(), + "where": self._parse_where(), + }, + ) + + def _parse_uncache(self): + if not self._match(TokenType.TABLE): + self.raise_error("Expecting TABLE after UNCACHE") + return self.expression( + exp.Uncache, + exists=self._parse_exists(), + this=self._parse_table(schema=True), + ) + + def _parse_cache(self): + lazy = self._match(TokenType.LAZY) + self._match(TokenType.TABLE) + table = self._parse_table(schema=True) + options = [] + + if self._match(TokenType.OPTIONS): + self._match_l_paren() + k = self._parse_string() + self._match(TokenType.EQ) + v = self._parse_string() + options = [k, v] + self._match_r_paren() + + self._match(TokenType.ALIAS) + return self.expression( + exp.Cache, + this=table, + lazy=lazy, + options=options, + expression=self._parse_select(), + ) + + def _parse_partition(self): + if not self._match(TokenType.PARTITION): + return None + + def parse_values(): + k = self._parse_var() + if self._match(TokenType.EQ): + v = self._parse_string() + return (k, v) + return (k, None) + + self._match_l_paren() + values = self._parse_csv(parse_values) + self._match_r_paren() + + return self.expression( + exp.Partition, + this=values, + ) + + def _parse_value(self): + self._match_l_paren() + expressions = self._parse_csv(self._parse_conjunction) + self._match_r_paren() + return self.expression(exp.Tuple, expressions=expressions) + + def _parse_select(self, table=None): + index = self._index + + if self._match(TokenType.SELECT): + hint = self._parse_hint() + all_ = self._match(TokenType.ALL) + distinct = self._match(TokenType.DISTINCT) + + if distinct: + distinct = self.expression( + exp.Distinct, + on=self._parse_value() if self._match(TokenType.ON) else None, + ) + + if all_ and distinct: + self.raise_error("Cannot specify both ALL and DISTINCT after SELECT") + + limit = self._parse_limit(top=True) + expressions = self._parse_csv( + lambda: self._parse_annotation(self._parse_expression()) + ) + + this = self.expression( + exp.Select, + hint=hint, + distinct=distinct, + expressions=expressions, + limit=limit, + ) + from_ = self._parse_from() + if from_: + this.set("from", from_) + self._parse_query_modifiers(this) + elif self._match(TokenType.WITH): + recursive = self._match(TokenType.RECURSIVE) + + expressions = [] + + while True: + 
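+                # Parse CTEs one at a time; a trailing comma means another CTE follows.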
expressions.append(self._parse_cte()) + + if not self._match(TokenType.COMMA): + break + + cte = self.expression( + exp.With, + expressions=expressions, + recursive=recursive, + ) + this = self._parse_statement() + + if not this: + self.raise_error("Failed to parse any statement following CTE") + return cte + + if "with" in this.arg_types: + this.set( + "with", + self.expression( + exp.With, + expressions=expressions, + recursive=recursive, + ), + ) + else: + self.raise_error(f"{this.key} does not support CTE") + elif self._match(TokenType.L_PAREN): + this = self._parse_table() if table else self._parse_select() + + if this: + self._parse_query_modifiers(this) + self._match_r_paren() + this = self._parse_subquery(this) + else: + self._retreat(index) + elif self._match(TokenType.VALUES): + this = self.expression( + exp.Values, expressions=self._parse_csv(self._parse_value) + ) + alias = self._parse_table_alias() + if alias: + this = self.expression(exp.Subquery, this=this, alias=alias) + else: + this = None + + return self._parse_set_operations(this) if this else None + + def _parse_cte(self): + alias = self._parse_table_alias() + if not alias or not alias.this: + self.raise_error("Expected CTE to have alias") + + if not self._match(TokenType.ALIAS): + self.raise_error("Expected AS in CTE") + + self._match_l_paren() + expression = self._parse_statement() + self._match_r_paren() + + return self.expression( + exp.CTE, + this=expression, + alias=alias, + ) + + def _parse_table_alias(self): + any_token = self._match(TokenType.ALIAS) + alias = self._parse_id_var(any_token) + columns = None + + if self._match(TokenType.L_PAREN): + columns = self._parse_csv(lambda: self._parse_id_var(any_token)) + self._match_r_paren() + + if not alias and not columns: + return None + + return self.expression( + exp.TableAlias, + this=alias, + columns=columns, + ) + + def _parse_subquery(self, this): + return self.expression(exp.Subquery, this=this, alias=self._parse_table_alias()) + + def _parse_query_modifiers(self, this): + if not isinstance(this, (exp.Subquery, exp.Subqueryable)): + return + + for key, parser in self.QUERY_MODIFIER_PARSERS.items(): + expression = parser(self) + + if expression: + this.set(key, expression) + + def _parse_annotation(self, expression): + if self._match(TokenType.ANNOTATION): + return self.expression( + exp.Annotation, this=self._prev.text, expression=expression + ) + + return expression + + def _parse_hint(self): + if self._match(TokenType.HINT): + hints = self._parse_csv(self._parse_function) + if not self._match(TokenType.HINT): + self.raise_error("Expected */ after HINT") + return self.expression(exp.Hint, expressions=hints) + return None + + def _parse_from(self): + if not self._match(TokenType.FROM): + return None + + return self.expression(exp.From, expressions=self._parse_csv(self._parse_table)) + + def _parse_laterals(self): + return self._parse_all(self._parse_lateral) + + def _parse_lateral(self): + if not self._match(TokenType.LATERAL): + return None + + if not self._match(TokenType.VIEW): + self.raise_error("Expected VIEW after LATERAL") + + outer = self._match(TokenType.OUTER) + + return self.expression( + exp.Lateral, + this=self._parse_function(), + outer=outer, + alias=self.expression( + exp.TableAlias, + this=self._parse_id_var(any_token=False), + columns=( + self._parse_csv(self._parse_id_var) + if self._match(TokenType.ALIAS) + else None + ), + ), + ) + + def _parse_joins(self): + return self._parse_all(self._parse_join) + + def _parse_join_side_and_kind(self): 
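+        # e.g. LEFT OUTER JOIN matches side=LEFT and kind=OUTER; either token may be absent.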
+
+        return (
+            self._match_set(self.JOIN_SIDES) and self._prev,
+            self._match_set(self.JOIN_KINDS) and self._prev,
+        )
+
+    def _parse_join(self):
+        side, kind = self._parse_join_side_and_kind()
+
+        if not self._match(TokenType.JOIN):
+            return None
+
+        kwargs = {"this": self._parse_table()}
+
+        if side:
+            kwargs["side"] = side.text
+        if kind:
+            kwargs["kind"] = kind.text
+
+        if self._match(TokenType.ON):
+            kwargs["on"] = self._parse_conjunction()
+        elif self._match(TokenType.USING):
+            kwargs["using"] = self._parse_wrapped_id_vars()
+
+        return self.expression(exp.Join, **kwargs)
+
+    def _parse_index(self):
+        index = self._parse_id_var()
+        self._match(TokenType.ON)
+        self._match(TokenType.TABLE)  # hive
+        return self.expression(
+            exp.Index,
+            this=index,
+            table=self.expression(exp.Table, this=self._parse_id_var()),
+            columns=self._parse_expression(),
+        )
+
+    def _parse_table(self, schema=False):
+        unnest = self._parse_unnest()
+
+        if unnest:
+            return unnest
+
+        subquery = self._parse_select(table=True)
+
+        if subquery:
+            return subquery
+
+        catalog = None
+        db = None
+        table = (not schema and self._parse_function()) or self._parse_id_var(False)
+
+        while self._match(TokenType.DOT):
+            catalog = db
+            db = table
+            table = self._parse_id_var()
+
+        if not table:
+            self.raise_error("Expected table name")
+
+        this = self.expression(exp.Table, this=table, db=db, catalog=catalog)
+
+        if schema:
+            return self._parse_schema(this=this)
+
+        if self.alias_post_tablesample:
+            table_sample = self._parse_table_sample()
+
+        alias = self._parse_table_alias()
+
+        if alias:
+            this = self.expression(exp.Alias, this=this, alias=alias)
+
+        if not self.alias_post_tablesample:
+            table_sample = self._parse_table_sample()
+
+        if table_sample:
+            table_sample.set("this", this)
+            this = table_sample
+
+        return this
+
+    def _parse_unnest(self):
+        if not self._match(TokenType.UNNEST):
+            return None
+
+        self._match_l_paren()
+        expressions = self._parse_csv(self._parse_column)
+        self._match_r_paren()
+
+        ordinality = bool(
+            self._match(TokenType.WITH) and self._match(TokenType.ORDINALITY)
+        )
+
+        alias = self._parse_table_alias()
+
+        if alias and self.unnest_column_only:
+            if alias.args.get("columns"):
+                self.raise_error("Unexpected extra column alias in unnest.")
+            alias.set("columns", [alias.this])
+            alias.set("this", None)
+
+        return self.expression(
+            exp.Unnest,
+            expressions=expressions,
+            ordinality=ordinality,
+            alias=alias,
+        )
+
+    def _parse_table_sample(self):
+        if not self._match(TokenType.TABLE_SAMPLE):
+            return None
+
+        method = self._parse_var()
+        bucket_numerator = None
+        bucket_denominator = None
+        bucket_field = None
+        percent = None
+        rows = None
+        size = None
+
+        self._match_l_paren()
+
+        if self._match(TokenType.BUCKET):
+            bucket_numerator = self._parse_number()
+            self._match(TokenType.OUT_OF)
+            bucket_denominator = self._parse_number()
+            self._match(TokenType.ON)
+            bucket_field = self._parse_field()
+        else:
+            num = self._parse_number()
+
+            if self._match(TokenType.PERCENT):
+                percent = num
+            elif self._match(TokenType.ROWS):
+                rows = num
+            else:
+                size = num
+
+        self._match_r_paren()
+
+        return self.expression(
+            exp.TableSample,
+            method=method,
+            bucket_numerator=bucket_numerator,
+            bucket_denominator=bucket_denominator,
+            bucket_field=bucket_field,
+            percent=percent,
+            rows=rows,
+            size=size,
+        )
+
+    def _parse_where(self):
+        if not self._match(TokenType.WHERE):
+            return None
+        return self.expression(exp.Where, this=self._parse_conjunction())
+
+    def _parse_group(self):
+        if not 
self._match(TokenType.GROUP_BY): + return None + return self.expression( + exp.Group, + expressions=self._parse_csv(self._parse_conjunction), + grouping_sets=self._parse_grouping_sets(), + cube=self._match(TokenType.CUBE) and self._parse_wrapped_id_vars(), + rollup=self._match(TokenType.ROLLUP) and self._parse_wrapped_id_vars(), + ) + + def _parse_grouping_sets(self): + if not self._match(TokenType.GROUPING_SETS): + return None + + self._match_l_paren() + grouping_sets = self._parse_csv(self._parse_grouping_set) + self._match_r_paren() + return grouping_sets + + def _parse_grouping_set(self): + if self._match(TokenType.L_PAREN): + grouping_set = self._parse_csv(self._parse_id_var) + self._match_r_paren() + return self.expression(exp.Tuple, expressions=grouping_set) + return self._parse_id_var() + + def _parse_having(self): + if not self._match(TokenType.HAVING): + return None + return self.expression(exp.Having, this=self._parse_conjunction()) + + def _parse_qualify(self): + if not self._match(TokenType.QUALIFY): + return None + return self.expression(exp.Qualify, this=self._parse_conjunction()) + + def _parse_order(self, this=None): + if not self._match(TokenType.ORDER_BY): + return this + + return self.expression( + exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered) + ) + + def _parse_sort(self, token_type, exp_class): + if not self._match(token_type): + return None + + return self.expression( + exp_class, expressions=self._parse_csv(self._parse_ordered) + ) + + def _parse_ordered(self): + this = self._parse_conjunction() + self._match(TokenType.ASC) + is_desc = self._match(TokenType.DESC) + is_nulls_first = self._match(TokenType.NULLS_FIRST) + is_nulls_last = self._match(TokenType.NULLS_LAST) + desc = is_desc or False + asc = not desc + nulls_first = is_nulls_first or False + explicitly_null_ordered = is_nulls_first or is_nulls_last + if ( + not explicitly_null_ordered + and ( + (asc and self.null_ordering == "nulls_are_small") + or (desc and self.null_ordering != "nulls_are_small") + ) + and self.null_ordering != "nulls_are_last" + ): + nulls_first = True + + return self.expression( + exp.Ordered, this=this, desc=desc, nulls_first=nulls_first + ) + + def _parse_limit(self, this=None, top=False): + if self._match(TokenType.TOP if top else TokenType.LIMIT): + return self.expression( + exp.Limit, this=this, expression=self._parse_number() + ) + if self._match(TokenType.FETCH): + direction = self._match_set((TokenType.FIRST, TokenType.NEXT)) + direction = self._prev.text if direction else "FIRST" + count = self._parse_number() + self._match_set((TokenType.ROW, TokenType.ROWS)) + self._match(TokenType.ONLY) + return self.expression(exp.Fetch, direction=direction, count=count) + return this + + def _parse_offset(self, this=None): + if not self._match(TokenType.OFFSET): + return this + count = self._parse_number() + self._match_set((TokenType.ROW, TokenType.ROWS)) + return self.expression(exp.Offset, this=this, expression=count) + + def _parse_set_operations(self, this): + if not self._match_set(self.SET_OPERATIONS): + return this + + token_type = self._prev.token_type + + if token_type == TokenType.UNION: + expression = exp.Union + elif token_type == TokenType.EXCEPT: + expression = exp.Except + else: + expression = exp.Intersect + + return self.expression( + expression, + this=this, + distinct=self._match(TokenType.DISTINCT) or not self._match(TokenType.ALL), + expression=self._parse_select(), + ) + + def _parse_expression(self): + return 
self._parse_alias(self._parse_conjunction()) + + def _parse_conjunction(self): + return self._parse_tokens(self._parse_equality, self.CONJUNCTION) + + def _parse_equality(self): + return self._parse_tokens(self._parse_comparison, self.EQUALITY) + + def _parse_comparison(self): + return self._parse_tokens(self._parse_range, self.COMPARISON) + + def _parse_range(self): + this = self._parse_bitwise() + negate = self._match(TokenType.NOT) + + if self._match_set(self.RANGE_PARSERS): + this = self.RANGE_PARSERS[self._prev.token_type](self, this) + + if negate: + this = self.expression(exp.Not, this=this) + + return this + + def _parse_is(self, this): + negate = self._match(TokenType.NOT) + this = self.expression( + exp.Is, + this=this, + expression=self._parse_null() or self._parse_boolean(), + ) + return self.expression(exp.Not, this=this) if negate else this + + def _parse_in(self, this): + unnest = self._parse_unnest() + if unnest: + this = self.expression(exp.In, this=this, unnest=unnest) + else: + self._match_l_paren() + expressions = self._parse_csv( + lambda: self._parse_select() or self._parse_expression() + ) + + if len(expressions) == 1 and isinstance(expressions[0], exp.Subqueryable): + this = self.expression(exp.In, this=this, query=expressions[0]) + else: + this = self.expression(exp.In, this=this, expressions=expressions) + + self._match_r_paren() + return this + + def _parse_between(self, this): + low = self._parse_bitwise() + self._match(TokenType.AND) + high = self._parse_bitwise() + return self.expression(exp.Between, this=this, low=low, high=high) + + def _parse_escape(self, this): + if not self._match(TokenType.ESCAPE): + return this + return self.expression(exp.Escape, this=this, expression=self._parse_string()) + + def _parse_bitwise(self): + this = self._parse_term() + + while True: + if self._match_set(self.BITWISE): + this = self.expression( + self.BITWISE[self._prev.token_type], + this=this, + expression=self._parse_term(), + ) + elif self._match_pair(TokenType.LT, TokenType.LT): + this = self.expression( + exp.BitwiseLeftShift, this=this, expression=self._parse_term() + ) + elif self._match_pair(TokenType.GT, TokenType.GT): + this = self.expression( + exp.BitwiseRightShift, this=this, expression=self._parse_term() + ) + else: + break + + return this + + def _parse_term(self): + return self._parse_tokens(self._parse_factor, self.TERM) + + def _parse_factor(self): + return self._parse_tokens(self._parse_unary, self.FACTOR) + + def _parse_unary(self): + if self._match(TokenType.NOT): + return self.expression(exp.Not, this=self._parse_equality()) + if self._match(TokenType.TILDA): + return self.expression(exp.BitwiseNot, this=self._parse_unary()) + if self._match(TokenType.DASH): + return self.expression(exp.Neg, this=self._parse_unary()) + return self._parse_at_time_zone(self._parse_type()) + + def _parse_type(self): + if self._match(TokenType.INTERVAL): + return self.expression( + exp.Interval, + this=self._parse_term(), + unit=self._parse_var(), + ) + + index = self._index + type_token = self._parse_types() + this = self._parse_column() + + if type_token: + if this: + return self.expression(exp.Cast, this=this, to=type_token) + if not type_token.args.get("expressions"): + self._retreat(index) + return self._parse_column() + return type_token + + while self._match(TokenType.DCOLON): + type_token = self._parse_types() + if not type_token: + self.raise_error("Expected type") + this = self.expression(exp.Cast, this=this, to=type_token) + + return this + + def 
_parse_types(self): + index = self._index + + if not self._match_set(self.TYPE_TOKENS): + return None + + type_token = self._prev.token_type + nested = type_token in self.NESTED_TYPE_TOKENS + is_struct = type_token == TokenType.STRUCT + expressions = None + + if self._match(TokenType.L_BRACKET): + self._retreat(index) + return None + + if self._match(TokenType.L_PAREN): + if is_struct: + expressions = self._parse_csv(self._parse_struct_kwargs) + elif nested: + expressions = self._parse_csv(self._parse_types) + else: + expressions = self._parse_csv(self._parse_number) + + if not expressions: + self._retreat(index) + return None + + self._match_r_paren() + + if nested and self._match(TokenType.LT): + if is_struct: + expressions = self._parse_csv(self._parse_struct_kwargs) + else: + expressions = self._parse_csv(self._parse_types) + + if not self._match(TokenType.GT): + self.raise_error("Expecting >") + + if type_token in self.TIMESTAMPS: + tz = self._match(TokenType.WITH_TIME_ZONE) + self._match(TokenType.WITHOUT_TIME_ZONE) + if tz: + return exp.DataType( + this=exp.DataType.Type.TIMESTAMPTZ, + expressions=expressions, + ) + return exp.DataType( + this=exp.DataType.Type.TIMESTAMP, + expressions=expressions, + ) + + return exp.DataType( + this=exp.DataType.Type[type_token.value.upper()], + expressions=expressions, + nested=nested, + ) + + def _parse_struct_kwargs(self): + this = self._parse_id_var() + self._match(TokenType.COLON) + data_type = self._parse_types() + if not data_type: + return None + return self.expression(exp.StructKwarg, this=this, expression=data_type) + + def _parse_at_time_zone(self, this): + if not self._match(TokenType.AT_TIME_ZONE): + return this + + return self.expression(exp.AtTimeZone, this=this, zone=self._parse_unary()) + + def _parse_column(self): + this = self._parse_field() + if isinstance(this, exp.Identifier): + this = self.expression(exp.Column, this=this) + elif not this: + return self._parse_bracket(this) + this = self._parse_bracket(this) + + while self._match_set(self.COLUMN_OPERATORS): + op = self.COLUMN_OPERATORS.get(self._prev.token_type) + field = self._parse_star() or self._parse_function() or self._parse_id_var() + + if isinstance(field, exp.Func): + # bigquery allows function calls like x.y.count(...) + # SAFE.SUBSTR(...) 
+ # https://cloud.google.com/bigquery/docs/reference/standard-sql/functions-reference#function_call_rules + this = self._replace_columns_with_dots(this) + + if op: + this = op(self, this, exp.Literal.string(field.name)) + elif isinstance(this, exp.Column) and not this.table: + this = self.expression(exp.Column, this=field, table=this.this) + else: + this = self.expression(exp.Dot, this=this, expression=field) + this = self._parse_bracket(this) + + return this + + def _parse_primary(self): + if self._match_set(self.PRIMARY_PARSERS): + return self.PRIMARY_PARSERS[self._prev.token_type](self, self._prev) + + if self._match(TokenType.L_PAREN): + query = self._parse_select() + + if query: + expressions = [query] + else: + expressions = self._parse_csv( + lambda: self._parse_alias(self._parse_conjunction(), explicit=True) + ) + + this = list_get(expressions, 0) + self._parse_query_modifiers(this) + self._match_r_paren() + + if isinstance(this, exp.Subqueryable): + return self._parse_subquery(this) + if len(expressions) > 1: + return self.expression(exp.Tuple, expressions=expressions) + return self.expression(exp.Paren, this=this) + + return None + + def _parse_field(self, any_token=False): + return ( + self._parse_primary() + or self._parse_function() + or self._parse_id_var(any_token) + ) + + def _parse_function(self): + if not self._curr: + return None + + token_type = self._curr.token_type + + if self._match_set(self.NO_PAREN_FUNCTION_PARSERS): + return self.NO_PAREN_FUNCTION_PARSERS[token_type](self) + + if not self._next or self._next.token_type != TokenType.L_PAREN: + if token_type in self.NO_PAREN_FUNCTIONS: + return self.expression( + self._advance() or self.NO_PAREN_FUNCTIONS[token_type] + ) + return None + + if token_type not in self.FUNC_TOKENS: + return None + + if self._match_set(self.FUNCTION_PARSERS): + self._advance() + this = self.FUNCTION_PARSERS[token_type](self, token_type) + else: + subquery_predicate = self.SUBQUERY_PREDICATES.get(token_type) + this = self._curr.text + self._advance(2) + + if subquery_predicate and self._curr.token_type in ( + TokenType.SELECT, + TokenType.WITH, + ): + this = self.expression(subquery_predicate, this=self._parse_select()) + self._match_r_paren() + return this + + function = self.FUNCTIONS.get(this.upper()) + args = self._parse_csv(self._parse_lambda) + + if function: + this = function(args) + self.validate_expression(this, args) + else: + this = self.expression(exp.Anonymous, this=this, expressions=args) + self._match_r_paren() + return self._parse_window(this) + + def _parse_lambda(self): + index = self._index + + if self._match(TokenType.L_PAREN): + expressions = self._parse_csv(self._parse_id_var) + self._match(TokenType.R_PAREN) + else: + expressions = [self._parse_id_var()] + + if not self._match(TokenType.ARROW): + self._retreat(index) + + distinct = self._match(TokenType.DISTINCT) + this = self._parse_conjunction() + + if distinct: + this = self.expression(exp.Distinct, this=this) + + if self._match(TokenType.IGNORE_NULLS): + this = self.expression(exp.IgnoreNulls, this=this) + else: + self._match(TokenType.RESPECT_NULLS) + + return self._parse_alias(self._parse_limit(self._parse_order(this))) + + return self.expression( + exp.Lambda, + this=self._parse_conjunction(), + expressions=expressions, + ) + + def _parse_schema(self, this=None): + index = self._index + if not self._match(TokenType.L_PAREN) or self._match(TokenType.SELECT): + self._retreat(index) + return this + + args = self._parse_csv( + lambda: self._parse_constraint() + or 
self._parse_column_def(self._parse_field()) + ) + self._match_r_paren() + return self.expression(exp.Schema, this=this, expressions=args) + + def _parse_column_def(self, this): + kind = self._parse_types() + + if not kind: + return this + + constraints = [] + while True: + constraint = self._parse_column_constraint() + if not constraint: + break + constraints.append(constraint) + + return self.expression( + exp.ColumnDef, this=this, kind=kind, constraints=constraints + ) + + def _parse_column_constraint(self): + kind = None + this = None + + if self._match(TokenType.CONSTRAINT): + this = self._parse_id_var() + + if self._match(TokenType.AUTO_INCREMENT): + kind = exp.AutoIncrementColumnConstraint() + elif self._match(TokenType.CHECK): + self._match_l_paren() + kind = self.expression( + exp.CheckColumnConstraint, this=self._parse_conjunction() + ) + self._match_r_paren() + elif self._match(TokenType.COLLATE): + kind = self.expression(exp.CollateColumnConstraint, this=self._parse_var()) + elif self._match(TokenType.DEFAULT): + kind = self.expression( + exp.DefaultColumnConstraint, this=self._parse_field() + ) + elif self._match(TokenType.NOT) and self._match(TokenType.NULL): + kind = exp.NotNullColumnConstraint() + elif self._match(TokenType.SCHEMA_COMMENT): + kind = self.expression( + exp.CommentColumnConstraint, this=self._parse_string() + ) + elif self._match(TokenType.PRIMARY_KEY): + kind = exp.PrimaryKeyColumnConstraint() + elif self._match(TokenType.UNIQUE): + kind = exp.UniqueColumnConstraint() + + if kind is None: + return None + + return self.expression(exp.ColumnConstraint, this=this, kind=kind) + + def _parse_constraint(self): + if not self._match(TokenType.CONSTRAINT): + return self._parse_unnamed_constraint() + + this = self._parse_id_var() + expressions = [] + + while True: + constraint = self._parse_unnamed_constraint() or self._parse_function() + if not constraint: + break + expressions.append(constraint) + + return self.expression(exp.Constraint, this=this, expressions=expressions) + + def _parse_unnamed_constraint(self): + if not self._match_set(self.CONSTRAINT_PARSERS): + return None + + return self.CONSTRAINT_PARSERS[self._prev.token_type](self) + + def _parse_check(self): + self._match(TokenType.CHECK) + self._match_l_paren() + expression = self._parse_conjunction() + self._match_r_paren() + + return self.expression(exp.Check, this=expression) + + def _parse_unique(self): + self._match(TokenType.UNIQUE) + columns = self._parse_wrapped_id_vars() + + return self.expression(exp.Unique, expressions=columns) + + def _parse_foreign_key(self): + self._match(TokenType.FOREIGN_KEY) + + expressions = self._parse_wrapped_id_vars() + reference = self._match(TokenType.REFERENCES) and self.expression( + exp.Reference, + this=self._parse_id_var(), + expressions=self._parse_wrapped_id_vars(), + ) + options = {} + + while self._match(TokenType.ON): + if not self._match_set((TokenType.DELETE, TokenType.UPDATE)): + self.raise_error("Expected DELETE or UPDATE") + kind = self._prev.text.lower() + + if self._match(TokenType.NO_ACTION): + action = "NO ACTION" + elif self._match(TokenType.SET): + self._match_set((TokenType.NULL, TokenType.DEFAULT)) + action = "SET " + self._prev.text.upper() + else: + self._advance() + action = self._prev.text.upper() + options[kind] = action + + return self.expression( + exp.ForeignKey, + expressions=expressions, + reference=reference, + **options, + ) + + def _parse_bracket(self, this): + if not self._match(TokenType.L_BRACKET): + return this + + expressions = 
self._parse_csv(self._parse_conjunction) + + if not this or this.name.upper() == "ARRAY": + this = self.expression(exp.Array, expressions=expressions) + else: + expressions = apply_index_offset(expressions, -self.index_offset) + this = self.expression(exp.Bracket, this=this, expressions=expressions) + + if not self._match(TokenType.R_BRACKET): + self.raise_error("Expected ]") + + return self._parse_bracket(this) + + def _parse_case(self): + ifs = [] + default = None + + expression = self._parse_conjunction() + + while self._match(TokenType.WHEN): + this = self._parse_conjunction() + self._match(TokenType.THEN) + then = self._parse_conjunction() + ifs.append(self.expression(exp.If, this=this, true=then)) + + if self._match(TokenType.ELSE): + default = self._parse_conjunction() + + if not self._match(TokenType.END): + self.raise_error("Expected END after CASE", self._prev) + + return self._parse_window( + self.expression(exp.Case, this=expression, ifs=ifs, default=default) + ) + + def _parse_if(self): + if self._match(TokenType.L_PAREN): + args = self._parse_csv(self._parse_conjunction) + this = exp.If.from_arg_list(args) + self.validate_expression(this, args) + self._match_r_paren() + else: + condition = self._parse_conjunction() + self._match(TokenType.THEN) + true = self._parse_conjunction() + false = self._parse_conjunction() if self._match(TokenType.ELSE) else None + self._match(TokenType.END) + this = self.expression(exp.If, this=condition, true=true, false=false) + return self._parse_window(this) + + def _parse_extract(self): + this = self._parse_var() or self._parse_type() + + if not self._match(TokenType.FROM): + self.raise_error("Expected FROM after EXTRACT", self._prev) + + return self.expression(exp.Extract, this=this, expression=self._parse_type()) + + def _parse_cast(self, strict): + this = self._parse_conjunction() + + if not self._match(TokenType.ALIAS): + self.raise_error("Expected AS after CAST") + + to = self._parse_types() + + if not to: + self.raise_error("Expected TYPE after CAST") + elif to.this == exp.DataType.Type.CHAR: + if self._match(TokenType.CHARACTER_SET): + to = self.expression(exp.CharacterSet, this=self._parse_var_or_string()) + + return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to) + + def _parse_convert(self): + this = self._parse_field() + if self._match(TokenType.USING): + to = self.expression(exp.CharacterSet, this=self._parse_var()) + elif self._match(TokenType.COMMA): + to = self._parse_types() + else: + to = None + return self.expression(exp.Cast, this=this, to=to) + + def _parse_window(self, this, alias=False): + if self._match(TokenType.FILTER): + self._match_l_paren() + this = self.expression( + exp.Filter, this=this, expression=self._parse_where() + ) + self._match_r_paren() + + if self._match(TokenType.WITHIN_GROUP): + self._match_l_paren() + this = self.expression( + exp.WithinGroup, + this=this, + expression=self._parse_order(), + ) + self._match_r_paren() + return this + + # bigquery select from window x AS (partition by ...) 
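+ # e.g. SELECT SUM(a) OVER w FROM t WINDOW w AS (PARTITION BY b ORDER BY c):
+ # with alias=True the AS is consumed here and the parenthesized definition is
+ # parsed below, so named and inline (OVER ...) windows share this method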
+ if alias: + self._match(TokenType.ALIAS) + elif not self._match(TokenType.OVER): + return this + + if not self._match(TokenType.L_PAREN): + alias = self._parse_id_var(False) + + return self.expression( + exp.Window, + this=this, + alias=alias, + ) + + partition = None + + alias = self._parse_id_var(False) + + if self._match(TokenType.PARTITION_BY): + partition = self._parse_csv(self._parse_conjunction) + + order = self._parse_order() + + spec = None + kind = self._match_set((TokenType.ROWS, TokenType.RANGE)) and self._prev.text + + if kind: + self._match(TokenType.BETWEEN) + start = self._parse_window_spec() + self._match(TokenType.AND) + end = self._parse_window_spec() + + spec = self.expression( + exp.WindowSpec, + kind=kind, + start=start["value"], + start_side=start["side"], + end=end["value"], + end_side=end["side"], + ) + + self._match_r_paren() + + return self.expression( + exp.Window, + this=this, + partition_by=partition, + order=order, + spec=spec, + alias=alias, + ) + + def _parse_window_spec(self): + self._match(TokenType.BETWEEN) + + return { + "value": ( + self._match_set((TokenType.UNBOUNDED, TokenType.CURRENT_ROW)) + and self._prev.text + ) + or self._parse_bitwise(), + "side": self._match_set((TokenType.PRECEDING, TokenType.FOLLOWING)) + and self._prev.text, + } + + def _parse_alias(self, this, explicit=False): + any_token = self._match(TokenType.ALIAS) + + if explicit and not any_token: + return this + + if self._match(TokenType.L_PAREN): + aliases = self.expression( + exp.Aliases, + this=this, + expressions=self._parse_csv(lambda: self._parse_id_var(any_token)), + ) + self._match_r_paren() + return aliases + + alias = self._parse_id_var(any_token) + + if alias: + return self.expression(exp.Alias, this=this, alias=alias) + + return this + + def _parse_id_var(self, any_token=True): + identifier = self._parse_identifier() + + if identifier: + return identifier + + if ( + any_token + and self._curr + and self._curr.token_type not in self.RESERVED_KEYWORDS + ): + return self._advance() or exp.Identifier(this=self._prev.text, quoted=False) + + return self._match_set(self.ID_VAR_TOKENS) and exp.Identifier( + this=self._prev.text, quoted=False + ) + + def _parse_string(self): + if self._match(TokenType.STRING): + return exp.Literal.string(self._prev.text) + return self._parse_placeholder() + + def _parse_number(self): + if self._match(TokenType.NUMBER): + return exp.Literal.number(self._prev.text) + return self._parse_placeholder() + + def _parse_identifier(self): + if self._match(TokenType.IDENTIFIER): + return exp.Identifier(this=self._prev.text, quoted=True) + return self._parse_placeholder() + + def _parse_var(self): + if self._match(TokenType.VAR): + return exp.Var(this=self._prev.text) + return self._parse_placeholder() + + def _parse_var_or_string(self): + return self._parse_var() or self._parse_string() + + def _parse_null(self): + if self._match(TokenType.NULL): + return exp.Null() + return None + + def _parse_boolean(self): + if self._match(TokenType.TRUE): + return exp.Boolean(this=True) + if self._match(TokenType.FALSE): + return exp.Boolean(this=False) + return None + + def _parse_star(self): + if self._match(TokenType.STAR): + return exp.Star( + **{"except": self._parse_except(), "replace": self._parse_replace()} + ) + return None + + def _parse_placeholder(self): + if self._match(TokenType.PLACEHOLDER): + return exp.Placeholder() + return None + + def _parse_except(self): + if not self._match(TokenType.EXCEPT): + return None + + return 
self._parse_wrapped_id_vars() + + def _parse_replace(self): + if not self._match(TokenType.REPLACE): + return None + + self._match_l_paren() + columns = self._parse_csv(lambda: self._parse_alias(self._parse_expression())) + self._match_r_paren() + return columns + + def _parse_csv(self, parse): + parse_result = parse() + items = [parse_result] if parse_result is not None else [] + + while self._match(TokenType.COMMA): + parse_result = parse() + if parse_result is not None: + items.append(parse_result) + + return items + + def _parse_tokens(self, parse, expressions): + this = parse() + + while self._match_set(expressions): + this = self.expression( + expressions[self._prev.token_type], this=this, expression=parse() + ) + + return this + + def _parse_all(self, parse): + return list(iter(parse, None)) + + def _parse_wrapped_id_vars(self): + self._match_l_paren() + expressions = self._parse_csv(self._parse_id_var) + self._match_r_paren() + return expressions + + def _match(self, token_type): + if not self._curr: + return None + + if self._curr.token_type == token_type: + self._advance() + return True + + return None + + def _match_set(self, types): + if not self._curr: + return None + + if self._curr.token_type in types: + self._advance() + return True + + return None + + def _match_pair(self, token_type_a, token_type_b, advance=True): + if not self._curr or not self._next: + return None + + if ( + self._curr.token_type == token_type_a + and self._next.token_type == token_type_b + ): + if advance: + self._advance(2) + return True + + return None + + def _match_l_paren(self): + if not self._match(TokenType.L_PAREN): + self.raise_error("Expecting (") + + def _match_r_paren(self): + if not self._match(TokenType.R_PAREN): + self.raise_error("Expecting )") + + def _replace_columns_with_dots(self, this): + if isinstance(this, exp.Dot): + exp.replace_children(this, self._replace_columns_with_dots) + elif isinstance(this, exp.Column): + exp.replace_children(this, self._replace_columns_with_dots) + table = this.args.get("table") + this = ( + self.expression(exp.Dot, this=table, expression=this.this) + if table + else self.expression(exp.Var, this=this.name) + ) + elif isinstance(this, exp.Identifier): + this = self.expression(exp.Var, this=this.name) + return this diff --git a/sqlglot/planner.py b/sqlglot/planner.py new file mode 100644 index 0000000..2006a75 --- /dev/null +++ b/sqlglot/planner.py @@ -0,0 +1,340 @@ +import itertools +import math + +from sqlglot import alias, exp +from sqlglot.errors import UnsupportedError +from sqlglot.optimizer.simplify import simplify + + +class Plan: + def __init__(self, expression): + self.expression = expression + self.root = Step.from_expression(self.expression) + self._dag = {} + + @property + def dag(self): + if not self._dag: + dag = {} + nodes = {self.root} + + while nodes: + node = nodes.pop() + dag[node] = set() + for dep in node.dependencies: + dag[node].add(dep) + nodes.add(dep) + self._dag = dag + + return self._dag + + @property + def leaves(self): + return (node for node, deps in self.dag.items() if not deps) + + +class Step: + @classmethod + def from_expression(cls, expression, ctes=None): + """ + Build a DAG of Steps from a SQL expression. + + Giving an expression like: + + SELECT x.a, SUM(x.b) + FROM x + JOIN y + ON x.a = y.a + GROUP BY x.a + + Transform it into a DAG of the form: + + Aggregate(x.a, SUM(x.b)) + Join(y) + Scan(x) + Scan(y) + + This can then more easily be executed on by an engine. 
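+
+ A rough usage sketch (illustrative only; assumes `import sqlglot` and an
+ aliased, optimizer-qualified query, since unaliased tables raise
+ UnsupportedError):
+
+ plan = Plan(sqlglot.parse_one("SELECT x.a, SUM(x.b) FROM x AS x GROUP BY x.a"))
+ plan.root # the final step (an Aggregate here)
+ list(plan.leaves) # steps with no dependencies (the Scan of x)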
+ """ + ctes = ctes or {} + with_ = expression.args.get("with") + + # CTEs break the mold of scope and introduce themselves to all in the context. + if with_: + ctes = ctes.copy() + for cte in with_.expressions: + step = Step.from_expression(cte.this, ctes) + step.name = cte.alias + ctes[step.name] = step + + from_ = expression.args.get("from") + + if from_: + from_ = from_.expressions + if len(from_) > 1: + raise UnsupportedError( + "Multi-from statements are unsupported. Run it through the optimizer" + ) + + step = Scan.from_expression(from_[0], ctes) + else: + raise UnsupportedError("Static selects are unsupported.") + + joins = expression.args.get("joins") + + if joins: + join = Join.from_joins(joins, ctes) + join.name = step.name + join.add_dependency(step) + step = join + + projections = [] # final selects in this chain of steps representing a select + operands = {} # intermediate computations of agg funcs eg x + 1 in SUM(x + 1) + aggregations = [] + sequence = itertools.count() + + for e in expression.expressions: + aggregation = e.find(exp.AggFunc) + + if aggregation: + projections.append(exp.column(e.alias_or_name, step.name, quoted=True)) + aggregations.append(e) + for operand in aggregation.unnest_operands(): + if isinstance(operand, exp.Column): + continue + if operand not in operands: + operands[operand] = f"_a_{next(sequence)}" + operand.replace( + exp.column(operands[operand], step.name, quoted=True) + ) + else: + projections.append(e) + + where = expression.args.get("where") + + if where: + step.condition = where.this + + group = expression.args.get("group") + + if group: + aggregate = Aggregate() + aggregate.source = step.name + aggregate.name = step.name + aggregate.operands = tuple( + alias(operand, alias_) for operand, alias_ in operands.items() + ) + aggregate.aggregations = aggregations + aggregate.group = [ + exp.column(e.alias_or_name, step.name, quoted=True) + for e in group.expressions + ] + aggregate.add_dependency(step) + step = aggregate + + having = expression.args.get("having") + + if having: + step.condition = having.this + + order = expression.args.get("order") + + if order: + sort = Sort() + sort.name = step.name + sort.key = order.expressions + sort.add_dependency(step) + step = sort + for k in sort.key + projections: + for column in k.find_all(exp.Column): + column.set("table", exp.to_identifier(step.name, quoted=True)) + + step.projections = projections + + limit = expression.args.get("limit") + + if limit: + step.limit = int(limit.text("expression")) + + return step + + def __init__(self): + self.name = None + self.dependencies = set() + self.dependents = set() + self.projections = [] + self.limit = math.inf + self.condition = None + + def add_dependency(self, dependency): + self.dependencies.add(dependency) + dependency.dependents.add(self) + + def __repr__(self): + return self.to_s() + + def to_s(self, level=0): + indent = " " * level + nested = f"{indent} " + + context = self._to_s(f"{nested} ") + + if context: + context = [f"{nested}Context:"] + context + + lines = [ + f"{indent}- {self.__class__.__name__}: {self.name}", + *context, + f"{nested}Projections:", + ] + + for expression in self.projections: + lines.append(f"{nested} - {expression.sql()}") + + if self.condition: + lines.append(f"{nested}Condition: {self.condition.sql()}") + + if self.dependencies: + lines.append(f"{nested}Dependencies:") + for dependency in self.dependencies: + lines.append(" " + dependency.to_s(level + 1)) + + return "\n".join(lines) + + def _to_s(self, _indent): + 
return [] + + +class Scan(Step): + @classmethod + def from_expression(cls, expression, ctes=None): + table = expression.this + alias_ = expression.alias + + if not alias_: + raise UnsupportedError( + "Tables/Subqueries must be aliased. Run it through the optimizer" + ) + + if isinstance(expression, exp.Subquery): + step = Step.from_expression(table, ctes) + step.name = alias_ + return step + + step = Scan() + step.name = alias_ + step.source = expression + if table.name in ctes: + step.add_dependency(ctes[table.name]) + + return step + + def __init__(self): + super().__init__() + self.source = None + + def _to_s(self, indent): + return [f"{indent}Source: {self.source.sql()}"] + + +class Write(Step): + pass + + +class Join(Step): + @classmethod + def from_joins(cls, joins, ctes=None): + step = Join() + + for join in joins: + name = join.this.alias + on = join.args.get("on") or exp.TRUE + source_key = [] + join_key = [] + + # find the join keys + # SELECT + # FROM x + # JOIN y + # ON x.a = y.b AND y.b > 1 + # + # should pull y.b as the join key and x.a as the source key + for condition in on.flatten() if isinstance(on, exp.And) else [on]: + if isinstance(condition, exp.EQ): + left, right = condition.unnest_operands() + left_tables = exp.column_table_names(left) + right_tables = exp.column_table_names(right) + + if name in left_tables and name not in right_tables: + join_key.append(left) + source_key.append(right) + condition.replace(exp.TRUE) + elif name in right_tables and name not in left_tables: + join_key.append(right) + source_key.append(left) + condition.replace(exp.TRUE) + + on = simplify(on) + + step.joins[name] = { + "side": join.side, + "join_key": join_key, + "source_key": source_key, + "condition": None if on == exp.TRUE else on, + } + + step.add_dependency(Scan.from_expression(join.this, ctes)) + + return step + + def __init__(self): + super().__init__() + self.joins = {} + + def _to_s(self, indent): + lines = [] + for name, join in self.joins.items(): + lines.append(f"{indent}{name}: {join['side']}") + if join.get("condition"): + lines.append(f"{indent}On: {join['condition'].sql()}") + return lines + + +class Aggregate(Step): + def __init__(self): + super().__init__() + self.aggregations = [] + self.operands = [] + self.group = [] + self.source = None + + def _to_s(self, indent): + lines = [f"{indent}Aggregations:"] + + for expression in self.aggregations: + lines.append(f"{indent} - {expression.sql()}") + + if self.group: + lines.append(f"{indent}Group:") + for expression in self.group: + lines.append(f"{indent} - {expression.sql()}") + if self.operands: + lines.append(f"{indent}Operands:") + for expression in self.operands: + lines.append(f"{indent} - {expression.sql()}") + + return lines + + +class Sort(Step): + def __init__(self): + super().__init__() + self.key = None + + def _to_s(self, indent): + lines = [f"{indent}Key:"] + + for expression in self.key: + lines.append(f"{indent} - {expression.sql()}") + + return lines diff --git a/sqlglot/time.py b/sqlglot/time.py new file mode 100644 index 0000000..16314c5 --- /dev/null +++ b/sqlglot/time.py @@ -0,0 +1,45 @@ +# the generic time format is based on python time.strftime +# https://docs.python.org/3/library/time.html#time.strftime +from sqlglot.trie import in_trie, new_trie + + +def format_time(string, mapping, trie=None): + """ + Converts a time string given a mapping. 
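+
+ Keys from the mapping are matched against the string with a trie, and the
+ longest matching key wins, so multi-character tokens such as "%Y" are
+ replaced as a unit rather than character by character.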
+ + Examples: + >>> format_time("%Y", {"%Y": "YYYY"}) + 'YYYY' + + mapping: Dictionary of time format to target time format + trie: Optional trie, can be passed in for performance + """ + start = 0 + end = 1 + size = len(string) + trie = trie or new_trie(mapping) + current = trie + chunks = [] + sym = None + + while end <= size: + chars = string[start:end] + result, current = in_trie(current, chars[-1]) + + if result == 0: + if sym: + end -= 1 + chars = sym + sym = None + start += len(chars) + chunks.append(chars) + current = trie + elif result == 2: + sym = chars + + end += 1 + + if result and end > size: + chunks.append(chars) + + return "".join(mapping.get(chars, chars) for chars in chunks) diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py new file mode 100644 index 0000000..e4b754d --- /dev/null +++ b/sqlglot/tokens.py @@ -0,0 +1,853 @@ +from enum import auto + +from sqlglot.helper import AutoName +from sqlglot.trie import in_trie, new_trie + + +class TokenType(AutoName): + L_PAREN = auto() + R_PAREN = auto() + L_BRACKET = auto() + R_BRACKET = auto() + L_BRACE = auto() + R_BRACE = auto() + COMMA = auto() + DOT = auto() + DASH = auto() + PLUS = auto() + COLON = auto() + DCOLON = auto() + SEMICOLON = auto() + STAR = auto() + SLASH = auto() + LT = auto() + LTE = auto() + GT = auto() + GTE = auto() + NOT = auto() + EQ = auto() + NEQ = auto() + AND = auto() + OR = auto() + AMP = auto() + DPIPE = auto() + PIPE = auto() + CARET = auto() + TILDA = auto() + ARROW = auto() + DARROW = auto() + HASH_ARROW = auto() + DHASH_ARROW = auto() + ANNOTATION = auto() + DOLLAR = auto() + + SPACE = auto() + BREAK = auto() + + STRING = auto() + NUMBER = auto() + IDENTIFIER = auto() + COLUMN = auto() + COLUMN_DEF = auto() + SCHEMA = auto() + TABLE = auto() + VAR = auto() + BIT_STRING = auto() + + # types + BOOLEAN = auto() + TINYINT = auto() + SMALLINT = auto() + INT = auto() + BIGINT = auto() + FLOAT = auto() + DOUBLE = auto() + DECIMAL = auto() + CHAR = auto() + NCHAR = auto() + VARCHAR = auto() + NVARCHAR = auto() + TEXT = auto() + BINARY = auto() + BYTEA = auto() + JSON = auto() + TIMESTAMP = auto() + TIMESTAMPTZ = auto() + DATETIME = auto() + DATE = auto() + UUID = auto() + GEOGRAPHY = auto() + NULLABLE = auto() + + # keywords + ADD_FILE = auto() + ALIAS = auto() + ALL = auto() + ALTER = auto() + ANALYZE = auto() + ANY = auto() + ARRAY = auto() + ASC = auto() + AT_TIME_ZONE = auto() + AUTO_INCREMENT = auto() + BEGIN = auto() + BETWEEN = auto() + BUCKET = auto() + CACHE = auto() + CALL = auto() + CASE = auto() + CAST = auto() + CHARACTER_SET = auto() + CHECK = auto() + CLUSTER_BY = auto() + COLLATE = auto() + COMMENT = auto() + COMMIT = auto() + CONSTRAINT = auto() + CONVERT = auto() + CREATE = auto() + CROSS = auto() + CUBE = auto() + CURRENT_DATE = auto() + CURRENT_DATETIME = auto() + CURRENT_ROW = auto() + CURRENT_TIME = auto() + CURRENT_TIMESTAMP = auto() + DIV = auto() + DEFAULT = auto() + DELETE = auto() + DESC = auto() + DISTINCT = auto() + DISTRIBUTE_BY = auto() + DROP = auto() + ELSE = auto() + END = auto() + ENGINE = auto() + ESCAPE = auto() + EXCEPT = auto() + EXISTS = auto() + EXPLAIN = auto() + EXTRACT = auto() + FALSE = auto() + FETCH = auto() + FILTER = auto() + FINAL = auto() + FIRST = auto() + FOLLOWING = auto() + FOREIGN_KEY = auto() + FORMAT = auto() + FULL = auto() + FUNCTION = auto() + FROM = auto() + GROUP_BY = auto() + GROUPING_SETS = auto() + HAVING = auto() + HINT = auto() + IF = auto() + IGNORE_NULLS = auto() + ILIKE = auto() + IN = auto() + INDEX = auto() + INNER = auto() + 
INSERT = auto() + INTERSECT = auto() + INTERVAL = auto() + INTO = auto() + INTRODUCER = auto() + IS = auto() + ISNULL = auto() + JOIN = auto() + LATERAL = auto() + LAZY = auto() + LEFT = auto() + LIKE = auto() + LIMIT = auto() + LOCATION = auto() + MAP = auto() + MOD = auto() + NEXT = auto() + NO_ACTION = auto() + NULL = auto() + NULLS_FIRST = auto() + NULLS_LAST = auto() + OFFSET = auto() + ON = auto() + ONLY = auto() + OPTIMIZE = auto() + OPTIONS = auto() + ORDER_BY = auto() + ORDERED = auto() + ORDINALITY = auto() + OUTER = auto() + OUT_OF = auto() + OVER = auto() + OVERWRITE = auto() + PARTITION = auto() + PARTITION_BY = auto() + PARTITIONED_BY = auto() + PERCENT = auto() + PLACEHOLDER = auto() + PRECEDING = auto() + PRIMARY_KEY = auto() + PROPERTIES = auto() + QUALIFY = auto() + QUOTE = auto() + RANGE = auto() + RECURSIVE = auto() + REPLACE = auto() + RESPECT_NULLS = auto() + REFERENCES = auto() + RIGHT = auto() + RLIKE = auto() + ROLLUP = auto() + ROW = auto() + ROWS = auto() + SCHEMA_COMMENT = auto() + SELECT = auto() + SET = auto() + SHOW = auto() + SOME = auto() + SORT_BY = auto() + STORED = auto() + STRUCT = auto() + TABLE_FORMAT = auto() + TABLE_SAMPLE = auto() + TEMPORARY = auto() + TIME = auto() + TOP = auto() + THEN = auto() + TRUE = auto() + TRUNCATE = auto() + TRY_CAST = auto() + UNBOUNDED = auto() + UNCACHE = auto() + UNION = auto() + UNNEST = auto() + UPDATE = auto() + USE = auto() + USING = auto() + VALUES = auto() + VIEW = auto() + WHEN = auto() + WHERE = auto() + WINDOW = auto() + WITH = auto() + WITH_TIME_ZONE = auto() + WITHIN_GROUP = auto() + WITHOUT_TIME_ZONE = auto() + UNIQUE = auto() + + +class Token: + __slots__ = ("token_type", "text", "line", "col") + + @classmethod + def number(cls, number): + return cls(TokenType.NUMBER, str(number)) + + @classmethod + def string(cls, string): + return cls(TokenType.STRING, string) + + @classmethod + def identifier(cls, identifier): + return cls(TokenType.IDENTIFIER, identifier) + + @classmethod + def var(cls, var): + return cls(TokenType.VAR, var) + + def __init__(self, token_type, text, line=1, col=1): + self.token_type = token_type + self.text = text + self.line = line + self.col = max(col - len(text), 1) + + def __repr__(self): + attributes = ", ".join(f"{k}: {getattr(self, k)}" for k in self.__slots__) + return f"<Token {attributes}>" + + +class _Tokenizer(type): + def __new__(cls, clsname, bases, attrs): + klass = super().__new__(cls, clsname, bases, attrs) + + klass.QUOTES = dict( + (quote, quote) if isinstance(quote, str) else (quote[0], quote[1]) + for quote in klass.QUOTES + ) + + klass.IDENTIFIERS = dict( + (identifier, identifier) + if isinstance(identifier, str) + else (identifier[0], identifier[1]) + for identifier in klass.IDENTIFIERS + ) + + klass.COMMENTS = dict( + (comment, None) if isinstance(comment, str) else (comment[0], comment[1]) + for comment in klass.COMMENTS + ) + + klass.KEYWORD_TRIE = new_trie( + key.upper() + for key, value in { + **klass.KEYWORDS, + **{comment: TokenType.COMMENT for comment in klass.COMMENTS}, + **{quote: TokenType.QUOTE for quote in klass.QUOTES}, + }.items() + if " " in key or any(single in key for single in klass.SINGLE_TOKENS) + ) + + return klass + + +class Tokenizer(metaclass=_Tokenizer): + SINGLE_TOKENS = { + "(": TokenType.L_PAREN, + ")": TokenType.R_PAREN, + "[": TokenType.L_BRACKET, + "]": TokenType.R_BRACKET, + "{": TokenType.L_BRACE, + "}": TokenType.R_BRACE, + "&": TokenType.AMP, + "^": TokenType.CARET, + ":": TokenType.COLON, + ",": TokenType.COMMA, + ".": 
TokenType.DOT, + "-": TokenType.DASH, + "=": TokenType.EQ, + ">": TokenType.GT, + "<": TokenType.LT, + "%": TokenType.MOD, + "!": TokenType.NOT, + "|": TokenType.PIPE, + "+": TokenType.PLUS, + ";": TokenType.SEMICOLON, + "/": TokenType.SLASH, + "*": TokenType.STAR, + "~": TokenType.TILDA, + "?": TokenType.PLACEHOLDER, + "#": TokenType.ANNOTATION, + "$": TokenType.DOLLAR, + # used for breaking a var like x'y' but nothing else + # the token type doesn't matter + "'": TokenType.QUOTE, + "`": TokenType.IDENTIFIER, + '"': TokenType.IDENTIFIER, + } + + QUOTES = ["'"] + + IDENTIFIERS = ['"'] + + ESCAPE = "'" + + KEYWORDS = { + "/*+": TokenType.HINT, + "*/": TokenType.HINT, + "==": TokenType.EQ, + "::": TokenType.DCOLON, + "||": TokenType.DPIPE, + ">=": TokenType.GTE, + "<=": TokenType.LTE, + "<>": TokenType.NEQ, + "!=": TokenType.NEQ, + "->": TokenType.ARROW, + "->>": TokenType.DARROW, + "#>": TokenType.HASH_ARROW, + "#>>": TokenType.DHASH_ARROW, + "ADD ARCHIVE": TokenType.ADD_FILE, + "ADD ARCHIVES": TokenType.ADD_FILE, + "ADD FILE": TokenType.ADD_FILE, + "ADD FILES": TokenType.ADD_FILE, + "ADD JAR": TokenType.ADD_FILE, + "ADD JARS": TokenType.ADD_FILE, + "ALL": TokenType.ALL, + "ALTER": TokenType.ALTER, + "ANALYZE": TokenType.ANALYZE, + "AND": TokenType.AND, + "ANY": TokenType.ANY, + "ASC": TokenType.ASC, + "AS": TokenType.ALIAS, + "AT TIME ZONE": TokenType.AT_TIME_ZONE, + "AUTO_INCREMENT": TokenType.AUTO_INCREMENT, + "BEGIN": TokenType.BEGIN, + "BETWEEN": TokenType.BETWEEN, + "BUCKET": TokenType.BUCKET, + "CALL": TokenType.CALL, + "CACHE": TokenType.CACHE, + "UNCACHE": TokenType.UNCACHE, + "CASE": TokenType.CASE, + "CAST": TokenType.CAST, + "CHARACTER SET": TokenType.CHARACTER_SET, + "CHECK": TokenType.CHECK, + "CLUSTER BY": TokenType.CLUSTER_BY, + "COLLATE": TokenType.COLLATE, + "COMMENT": TokenType.SCHEMA_COMMENT, + "COMMIT": TokenType.COMMIT, + "CONSTRAINT": TokenType.CONSTRAINT, + "CONVERT": TokenType.CONVERT, + "CREATE": TokenType.CREATE, + "CROSS": TokenType.CROSS, + "CUBE": TokenType.CUBE, + "CURRENT_DATE": TokenType.CURRENT_DATE, + "CURRENT ROW": TokenType.CURRENT_ROW, + "CURRENT_TIMESTAMP": TokenType.CURRENT_TIMESTAMP, + "DIV": TokenType.DIV, + "DEFAULT": TokenType.DEFAULT, + "DELETE": TokenType.DELETE, + "DESC": TokenType.DESC, + "DISTINCT": TokenType.DISTINCT, + "DISTRIBUTE BY": TokenType.DISTRIBUTE_BY, + "DROP": TokenType.DROP, + "ELSE": TokenType.ELSE, + "END": TokenType.END, + "ENGINE": TokenType.ENGINE, + "ESCAPE": TokenType.ESCAPE, + "EXCEPT": TokenType.EXCEPT, + "EXISTS": TokenType.EXISTS, + "EXPLAIN": TokenType.EXPLAIN, + "EXTRACT": TokenType.EXTRACT, + "FALSE": TokenType.FALSE, + "FETCH": TokenType.FETCH, + "FILTER": TokenType.FILTER, + "FIRST": TokenType.FIRST, + "FULL": TokenType.FULL, + "FUNCTION": TokenType.FUNCTION, + "FOLLOWING": TokenType.FOLLOWING, + "FOREIGN KEY": TokenType.FOREIGN_KEY, + "FORMAT": TokenType.FORMAT, + "FROM": TokenType.FROM, + "GROUP BY": TokenType.GROUP_BY, + "GROUPING SETS": TokenType.GROUPING_SETS, + "HAVING": TokenType.HAVING, + "IF": TokenType.IF, + "ILIKE": TokenType.ILIKE, + "IGNORE NULLS": TokenType.IGNORE_NULLS, + "IN": TokenType.IN, + "INDEX": TokenType.INDEX, + "INNER": TokenType.INNER, + "INSERT": TokenType.INSERT, + "INTERVAL": TokenType.INTERVAL, + "INTERSECT": TokenType.INTERSECT, + "INTO": TokenType.INTO, + "IS": TokenType.IS, + "ISNULL": TokenType.ISNULL, + "JOIN": TokenType.JOIN, + "LATERAL": TokenType.LATERAL, + "LAZY": TokenType.LAZY, + "LEFT": TokenType.LEFT, + "LIKE": TokenType.LIKE, + "LIMIT": TokenType.LIMIT, + "LOCATION": 
TokenType.LOCATION, + "NEXT": TokenType.NEXT, + "NO ACTION": TokenType.NO_ACTION, + "NOT": TokenType.NOT, + "NULL": TokenType.NULL, + "NULLS FIRST": TokenType.NULLS_FIRST, + "NULLS LAST": TokenType.NULLS_LAST, + "OFFSET": TokenType.OFFSET, + "ON": TokenType.ON, + "ONLY": TokenType.ONLY, + "OPTIMIZE": TokenType.OPTIMIZE, + "OPTIONS": TokenType.OPTIONS, + "OR": TokenType.OR, + "ORDER BY": TokenType.ORDER_BY, + "ORDINALITY": TokenType.ORDINALITY, + "OUTER": TokenType.OUTER, + "OUT OF": TokenType.OUT_OF, + "OVER": TokenType.OVER, + "OVERWRITE": TokenType.OVERWRITE, + "PARTITION": TokenType.PARTITION, + "PARTITION BY": TokenType.PARTITION_BY, + "PARTITIONED BY": TokenType.PARTITIONED_BY, + "PERCENT": TokenType.PERCENT, + "PRECEDING": TokenType.PRECEDING, + "PRIMARY KEY": TokenType.PRIMARY_KEY, + "RANGE": TokenType.RANGE, + "RECURSIVE": TokenType.RECURSIVE, + "REGEXP": TokenType.RLIKE, + "REPLACE": TokenType.REPLACE, + "RESPECT NULLS": TokenType.RESPECT_NULLS, + "REFERENCES": TokenType.REFERENCES, + "RIGHT": TokenType.RIGHT, + "RLIKE": TokenType.RLIKE, + "ROLLUP": TokenType.ROLLUP, + "ROW": TokenType.ROW, + "ROWS": TokenType.ROWS, + "SELECT": TokenType.SELECT, + "SET": TokenType.SET, + "SHOW": TokenType.SHOW, + "SOME": TokenType.SOME, + "SORT BY": TokenType.SORT_BY, + "STORED": TokenType.STORED, + "TABLE": TokenType.TABLE, + "TABLE_FORMAT": TokenType.TABLE_FORMAT, + "TBLPROPERTIES": TokenType.PROPERTIES, + "TABLESAMPLE": TokenType.TABLE_SAMPLE, + "TEMP": TokenType.TEMPORARY, + "TEMPORARY": TokenType.TEMPORARY, + "THEN": TokenType.THEN, + "TRUE": TokenType.TRUE, + "TRUNCATE": TokenType.TRUNCATE, + "TRY_CAST": TokenType.TRY_CAST, + "UNBOUNDED": TokenType.UNBOUNDED, + "UNION": TokenType.UNION, + "UNNEST": TokenType.UNNEST, + "UPDATE": TokenType.UPDATE, + "USE": TokenType.USE, + "USING": TokenType.USING, + "VALUES": TokenType.VALUES, + "VIEW": TokenType.VIEW, + "WHEN": TokenType.WHEN, + "WHERE": TokenType.WHERE, + "WITH": TokenType.WITH, + "WITH TIME ZONE": TokenType.WITH_TIME_ZONE, + "WITHIN GROUP": TokenType.WITHIN_GROUP, + "WITHOUT TIME ZONE": TokenType.WITHOUT_TIME_ZONE, + "ARRAY": TokenType.ARRAY, + "BOOL": TokenType.BOOLEAN, + "BOOLEAN": TokenType.BOOLEAN, + "BYTE": TokenType.TINYINT, + "TINYINT": TokenType.TINYINT, + "SHORT": TokenType.SMALLINT, + "SMALLINT": TokenType.SMALLINT, + "INT2": TokenType.SMALLINT, + "INTEGER": TokenType.INT, + "INT": TokenType.INT, + "INT4": TokenType.INT, + "LONG": TokenType.BIGINT, + "BIGINT": TokenType.BIGINT, + "INT8": TokenType.BIGINT, + "DECIMAL": TokenType.DECIMAL, + "MAP": TokenType.MAP, + "NUMBER": TokenType.DECIMAL, + "NUMERIC": TokenType.DECIMAL, + "FIXED": TokenType.DECIMAL, + "REAL": TokenType.FLOAT, + "FLOAT": TokenType.FLOAT, + "FLOAT4": TokenType.FLOAT, + "FLOAT8": TokenType.DOUBLE, + "DOUBLE": TokenType.DOUBLE, + "JSON": TokenType.JSON, + "CHAR": TokenType.CHAR, + "NCHAR": TokenType.NCHAR, + "VARCHAR": TokenType.VARCHAR, + "VARCHAR2": TokenType.VARCHAR, + "NVARCHAR": TokenType.NVARCHAR, + "NVARCHAR2": TokenType.NVARCHAR, + "STRING": TokenType.TEXT, + "TEXT": TokenType.TEXT, + "CLOB": TokenType.TEXT, + "BINARY": TokenType.BINARY, + "BLOB": TokenType.BINARY, + "BYTEA": TokenType.BINARY, + "TIMESTAMP": TokenType.TIMESTAMP, + "TIMESTAMPTZ": TokenType.TIMESTAMPTZ, + "DATE": TokenType.DATE, + "DATETIME": TokenType.DATETIME, + "UNIQUE": TokenType.UNIQUE, + "STRUCT": TokenType.STRUCT, + } + + WHITE_SPACE = { + " ": TokenType.SPACE, + "\t": TokenType.SPACE, + "\n": TokenType.BREAK, + "\r": TokenType.BREAK, + "\r\n": TokenType.BREAK, + } + + COMMANDS = { + 
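+ # token types that introduce "command" statements: when one of these starts a
+ # statement, everything up to the next ";" is captured as a single STRING
+ # token by _add below rather than being tokenized further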
TokenType.ALTER, + TokenType.ADD_FILE, + TokenType.ANALYZE, + TokenType.BEGIN, + TokenType.CALL, + TokenType.COMMIT, + TokenType.EXPLAIN, + TokenType.OPTIMIZE, + TokenType.SET, + TokenType.SHOW, + TokenType.TRUNCATE, + TokenType.USE, + } + + # handle numeric literals like in hive (3L = BIGINT) + NUMERIC_LITERALS = {} + ENCODE = None + + COMMENTS = ["--", ("/*", "*/")] + KEYWORD_TRIE = None # autofilled + + __slots__ = ( + "sql", + "size", + "tokens", + "_start", + "_current", + "_line", + "_col", + "_char", + "_end", + "_peek", + ) + + def __init__(self): + """ + Tokenizer consumes a sql string and produces an array of :class:`~sqlglot.tokens.Token` + """ + self.reset() + + def reset(self): + self.sql = "" + self.size = 0 + self.tokens = [] + self._start = 0 + self._current = 0 + self._line = 1 + self._col = 1 + + self._char = None + self._end = None + self._peek = None + + def tokenize(self, sql): + self.reset() + self.sql = sql + self.size = len(sql) + + while self.size and not self._end: + self._start = self._current + self._advance() + + if not self._char: + break + + white_space = self.WHITE_SPACE.get(self._char) + identifier_end = self.IDENTIFIERS.get(self._char) + + if white_space: + if white_space == TokenType.BREAK: + self._col = 1 + self._line += 1 + elif self._char == "0" and self._peek == "x": + self._scan_hex() + elif self._char.isdigit(): + self._scan_number() + elif identifier_end: + self._scan_identifier(identifier_end) + else: + self._scan_keywords() + return self.tokens + + def _chars(self, size): + if size == 1: + return self._char + start = self._current - 1 + end = start + size + if end <= self.size: + return self.sql[start:end] + return "" + + def _advance(self, i=1): + self._col += i + self._current += i + self._end = self._current >= self.size + self._char = self.sql[self._current - 1] + self._peek = self.sql[self._current] if self._current < self.size else "" + + @property + def _text(self): + return self.sql[self._start : self._current] + + def _add(self, token_type, text=None): + text = self._text if text is None else text + self.tokens.append(Token(token_type, text, self._line, self._col)) + + if token_type in self.COMMANDS and ( + len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON + ): + self._start = self._current + while not self._end and self._peek != ";": + self._advance() + if self._start < self._current: + self._add(TokenType.STRING) + + def _scan_keywords(self): + size = 0 + word = None + chars = self._text + char = chars + prev_space = False + skip = False + trie = self.KEYWORD_TRIE + + while chars: + if skip: + result = 1 + else: + result, trie = in_trie(trie, char.upper()) + + if result == 0: + break + if result == 2: + word = chars + size += 1 + end = self._current - 1 + size + + if end < self.size: + char = self.sql[end] + is_space = char in self.WHITE_SPACE + + if not is_space or not prev_space: + if is_space: + char = " " + chars += char + prev_space = is_space + skip = False + else: + skip = True + else: + chars = None + + if not word: + if self._char in self.SINGLE_TOKENS: + token = self.SINGLE_TOKENS[self._char] + if token == TokenType.ANNOTATION: + self._scan_annotation() + return + self._add(token) + return + self._scan_var() + return + + if self._scan_string(word): + return + if self._scan_comment(word): + return + + self._advance(size - 1) + self._add(self.KEYWORDS[word.upper()]) + + def _scan_comment(self, comment_start): + if comment_start not in self.COMMENTS: + return False + + comment_end = 
self.COMMENTS[comment_start] + + if comment_end: + comment_end_size = len(comment_end) + + while not self._end and self._chars(comment_end_size) != comment_end: + self._advance() + self._advance(comment_end_size - 1) + else: + while not self._end and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK: + self._advance() + return True + + def _scan_annotation(self): + while ( + not self._end + and self.WHITE_SPACE.get(self._peek) != TokenType.BREAK + and self._peek != "," + ): + self._advance() + self._add(TokenType.ANNOTATION, self._text[1:]) + + def _scan_number(self): + decimal = False + scientific = 0 + + while True: + if self._peek.isdigit(): + self._advance() + elif self._peek == "." and not decimal: + decimal = True + self._advance() + elif self._peek in ("-", "+") and scientific == 1: + scientific += 1 + self._advance() + elif self._peek.upper() == "E" and not scientific: + scientific += 1 + self._advance() + elif self._peek.isalpha(): + self._add(TokenType.NUMBER) + literal = [] + while self._peek.isalpha(): + literal.append(self._peek.upper()) + self._advance() + literal = "".join(literal) + token_type = self.KEYWORDS.get(self.NUMERIC_LITERALS.get(literal)) + if token_type: + self._add(TokenType.DCOLON, "::") + return self._add(token_type, literal) + return self._advance(-len(literal)) + else: + return self._add(TokenType.NUMBER) + + def _scan_hex(self): + self._advance() + + while True: + char = self._peek.strip() + if char and char not in self.SINGLE_TOKENS: + self._advance() + else: + break + try: + self._add(TokenType.BIT_STRING, f"{int(self._text, 16):b}") + except ValueError: + self._add(TokenType.IDENTIFIER) + + def _scan_string(self, quote): + quote_end = self.QUOTES.get(quote) + if quote_end is None: + return False + + text = "" + self._advance(len(quote)) + quote_end_size = len(quote_end) + + while True: + if self._char == self.ESCAPE and self._peek == quote_end: + text += quote + self._advance(2) + else: + if self._chars(quote_end_size) == quote_end: + if quote_end_size > 1: + self._advance(quote_end_size - 1) + break + + if self._end: + raise RuntimeError( + f"Missing {quote} from {self._line}:{self._start}" + ) + text += self._char + self._advance() + + text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text + text = text.replace("\\\\", "\\") if self.ESCAPE == "\\" else text + self._add(TokenType.STRING, text) + return True + + def _scan_identifier(self, identifier_end): + while self._peek != identifier_end: + if self._end: + raise RuntimeError( + f"Missing {identifier_end} from {self._line}:{self._start}" + ) + self._advance() + self._advance() + self._add(TokenType.IDENTIFIER, self._text[1:-1]) + + def _scan_var(self): + while True: + char = self._peek.strip() + if char and char not in self.SINGLE_TOKENS: + self._advance() + else: + break + self._add(self.KEYWORDS.get(self._text.upper(), TokenType.VAR)) diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py new file mode 100644 index 0000000..e7ccb8e --- /dev/null +++ b/sqlglot/transforms.py @@ -0,0 +1,68 @@ +from sqlglot import expressions as exp + + +def unalias_group(expression): + """ + Replace references to select aliases in GROUP BY clauses. 
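+
+ Each matching alias reference is replaced with the 1-based ordinal position
+ of the corresponding select expression (hence GROUP BY 1 in the example
+ below).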
+ + Example: + >>> import sqlglot + >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql() + 'SELECT a AS b FROM x GROUP BY 1' + """ + if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select): + aliased_selects = { + e.alias: i + for i, e in enumerate(expression.parent.expressions, start=1) + if isinstance(e, exp.Alias) + } + + expression = expression.copy() + + for col in expression.find_all(exp.Column): + alias_index = aliased_selects.get(col.name) + if not col.table and alias_index: + col.replace(exp.Literal.number(alias_index)) + + return expression + + +def preprocess(transforms, to_sql): + """ + Create a new transform function that can be used a value in `Generator.TRANSFORMS` + to convert expressions to SQL. + + Args: + transforms (list[(exp.Expression) -> exp.Expression]): + Sequence of transform functions. These will be called in order. + to_sql ((sqlglot.generator.Generator, exp.Expression) -> str): + Final transform that converts the resulting expression to a SQL string. + Returns: + (sqlglot.generator.Generator, exp.Expression) -> str: + Function that can be used as a generator transform. + """ + + def _to_sql(self, expression): + expression = transforms[0](expression) + for t in transforms[1:]: + expression = t(expression) + return to_sql(self, expression) + + return _to_sql + + +def delegate(attr): + """ + Create a new method that delegates to `attr`. + + This is useful for creating `Generator.TRANSFORMS` functions that delegate + to existing generator methods. + """ + + def _transform(self, *args, **kwargs): + return getattr(self, attr)(*args, **kwargs) + + return _transform + + +UNALIAS_GROUP = {exp.Group: preprocess([unalias_group], delegate("group_sql"))} diff --git a/sqlglot/trie.py b/sqlglot/trie.py new file mode 100644 index 0000000..a234107 --- /dev/null +++ b/sqlglot/trie.py @@ -0,0 +1,27 @@ +def new_trie(keywords): + trie = {} + + for key in keywords: + current = trie + + for char in key: + current = current.setdefault(char, {}) + current[0] = True + + return trie + + +def in_trie(trie, key): + if not key: + return (0, trie) + + current = trie + + for char in key: + if char not in current: + return (0, current) + current = current[char] + + if 0 in current: + return (2, current) + return (1, current) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/__init__.py diff --git a/tests/dialects/__init__.py b/tests/dialects/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/dialects/__init__.py diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py new file mode 100644 index 0000000..1337c3d --- /dev/null +++ b/tests/dialects/test_bigquery.py @@ -0,0 +1,238 @@ +from sqlglot import ErrorLevel, ParseError, UnsupportedError, transpile +from tests.dialects.test_dialect import Validator + + +class TestBigQuery(Validator): + dialect = "bigquery" + + def test_bigquery(self): + self.validate_all( + '"""x"""', + write={ + "bigquery": "'x'", + "duckdb": "'x'", + "presto": "'x'", + "hive": "'x'", + "spark": "'x'", + }, + ) + self.validate_all( + '"""x\'"""', + write={ + "bigquery": "'x\\''", + "duckdb": "'x'''", + "presto": "'x'''", + "hive": "'x\\''", + "spark": "'x\\''", + }, + ) + self.validate_all( + r'r"""/\*.*\*/"""', + write={ + "bigquery": r"'/\\*.*\\*/'", + "duckdb": r"'/\*.*\*/'", + "presto": r"'/\*.*\*/'", + "hive": r"'/\\*.*\\*/'", + "spark": r"'/\\*.*\\*/'", + }, + ) + self.validate_all( + 
R'R"""/\*.*\*/"""', + write={ + "bigquery": R"'/\\*.*\\*/'", + "duckdb": R"'/\*.*\*/'", + "presto": R"'/\*.*\*/'", + "hive": R"'/\\*.*\\*/'", + "spark": R"'/\\*.*\\*/'", + }, + ) + self.validate_all( + "CAST(a AS INT64)", + write={ + "bigquery": "CAST(a AS INT64)", + "duckdb": "CAST(a AS BIGINT)", + "presto": "CAST(a AS BIGINT)", + "hive": "CAST(a AS BIGINT)", + "spark": "CAST(a AS LONG)", + }, + ) + self.validate_all( + "CAST(a AS NUMERIC)", + write={ + "bigquery": "CAST(a AS NUMERIC)", + "duckdb": "CAST(a AS DECIMAL)", + "presto": "CAST(a AS DECIMAL)", + "hive": "CAST(a AS DECIMAL)", + "spark": "CAST(a AS DECIMAL)", + }, + ) + self.validate_all( + "[1, 2, 3]", + read={ + "duckdb": "LIST_VALUE(1, 2, 3)", + "presto": "ARRAY[1, 2, 3]", + "hive": "ARRAY(1, 2, 3)", + "spark": "ARRAY(1, 2, 3)", + }, + write={ + "bigquery": "[1, 2, 3]", + "duckdb": "LIST_VALUE(1, 2, 3)", + "presto": "ARRAY[1, 2, 3]", + "hive": "ARRAY(1, 2, 3)", + "spark": "ARRAY(1, 2, 3)", + }, + ) + self.validate_all( + "SELECT * FROM UNNEST(['7', '14']) AS x", + read={ + "spark": "SELECT * FROM UNNEST(ARRAY('7', '14')) AS (x)", + }, + write={ + "bigquery": "SELECT * FROM UNNEST(['7', '14']) AS x", + "presto": "SELECT * FROM UNNEST(ARRAY['7', '14']) AS (x)", + "hive": "SELECT * FROM UNNEST(ARRAY('7', '14')) AS (x)", + "spark": "SELECT * FROM UNNEST(ARRAY('7', '14')) AS (x)", + }, + ) + + self.validate_all( + "x IS unknown", + write={ + "bigquery": "x IS NULL", + "duckdb": "x IS NULL", + "presto": "x IS NULL", + "hive": "x IS NULL", + "spark": "x IS NULL", + }, + ) + self.validate_all( + "current_datetime", + write={ + "bigquery": "CURRENT_DATETIME()", + "duckdb": "CURRENT_DATETIME()", + "presto": "CURRENT_DATETIME()", + "hive": "CURRENT_DATETIME()", + "spark": "CURRENT_DATETIME()", + }, + ) + self.validate_all( + "current_time", + write={ + "bigquery": "CURRENT_TIME()", + "duckdb": "CURRENT_TIME()", + "presto": "CURRENT_TIME()", + "hive": "CURRENT_TIME()", + "spark": "CURRENT_TIME()", + }, + ) + self.validate_all( + "current_timestamp", + write={ + "bigquery": "CURRENT_TIMESTAMP()", + "duckdb": "CURRENT_TIMESTAMP()", + "postgres": "CURRENT_TIMESTAMP", + "presto": "CURRENT_TIMESTAMP()", + "hive": "CURRENT_TIMESTAMP()", + "spark": "CURRENT_TIMESTAMP()", + }, + ) + self.validate_all( + "current_timestamp()", + write={ + "bigquery": "CURRENT_TIMESTAMP()", + "duckdb": "CURRENT_TIMESTAMP()", + "postgres": "CURRENT_TIMESTAMP", + "presto": "CURRENT_TIMESTAMP()", + "hive": "CURRENT_TIMESTAMP()", + "spark": "CURRENT_TIMESTAMP()", + }, + ) + + self.validate_identity( + "SELECT ROW() OVER (y ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM x WINDOW y AS (PARTITION BY CATEGORY)" + ) + + self.validate_identity( + "SELECT LAST_VALUE(a IGNORE NULLS) OVER y FROM x WINDOW y AS (PARTITION BY CATEGORY)", + ) + + self.validate_all( + "CREATE TABLE db.example_table (col_a struct<struct_col_a:int, struct_col_b:string>)", + write={ + "bigquery": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a INT64, struct_col_b STRING>)", + "duckdb": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a INT, struct_col_b TEXT>)", + "presto": "CREATE TABLE db.example_table (col_a ROW(struct_col_a INTEGER, struct_col_b VARCHAR))", + "hive": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a INT, struct_col_b STRING>)", + "spark": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a: INT, struct_col_b: STRING>)", + }, + ) + self.validate_all( + "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a INT64, struct_col_b STRUCT<nested_col_a 
STRING, nested_col_b STRING>>)", + write={ + "bigquery": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a INT64, struct_col_b STRUCT<nested_col_a STRING, nested_col_b STRING>>)", + "presto": "CREATE TABLE db.example_table (col_a ROW(struct_col_a BIGINT, struct_col_b ROW(nested_col_a VARCHAR, nested_col_b VARCHAR)))", + "hive": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a BIGINT, struct_col_b STRUCT<nested_col_a STRING, nested_col_b STRING>>)", + "spark": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a: LONG, struct_col_b: STRUCT<nested_col_a: STRING, nested_col_b: STRING>>)", + }, + ) + self.validate_all( + "SELECT * FROM a WHERE b IN UNNEST([1, 2, 3])", + write={ + "bigquery": "SELECT * FROM a WHERE b IN UNNEST([1, 2, 3])", + "mysql": "SELECT * FROM a WHERE b IN (SELECT UNNEST(ARRAY(1, 2, 3)))", + "presto": "SELECT * FROM a WHERE b IN (SELECT UNNEST(ARRAY[1, 2, 3]))", + "hive": "SELECT * FROM a WHERE b IN (SELECT UNNEST(ARRAY(1, 2, 3)))", + "spark": "SELECT * FROM a WHERE b IN (SELECT UNNEST(ARRAY(1, 2, 3)))", + }, + ) + + # Reference: https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax#set_operators + with self.assertRaises(UnsupportedError): + transpile( + "SELECT * FROM a INTERSECT ALL SELECT * FROM b", + write="bigquery", + unsupported_level=ErrorLevel.RAISE, + ) + + with self.assertRaises(UnsupportedError): + transpile( + "SELECT * FROM a EXCEPT ALL SELECT * FROM b", + write="bigquery", + unsupported_level=ErrorLevel.RAISE, + ) + + with self.assertRaises(ParseError): + transpile("SELECT * FROM UNNEST(x) AS x(y)", read="bigquery") + + self.validate_all( + "DATE_SUB(CURRENT_DATE(), INTERVAL 1 DAY)", + write={ + "postgres": "CURRENT_DATE - INTERVAL '1' DAY", + }, + ) + self.validate_all( + "DATE_ADD(CURRENT_DATE(), INTERVAL 1 DAY)", + write={ + "bigquery": "DATE_ADD(CURRENT_DATE, INTERVAL 1 DAY)", + "duckdb": "CURRENT_DATE + INTERVAL 1 DAY", + "mysql": "DATE_ADD(CURRENT_DATE, INTERVAL 1 DAY)", + "postgres": "CURRENT_DATE + INTERVAL '1' DAY", + "presto": "DATE_ADD(DAY, 1, CURRENT_DATE)", + "hive": "DATE_ADD(CURRENT_DATE, 1)", + "spark": "DATE_ADD(CURRENT_DATE, 1)", + }, + ) + self.validate_all( + "CURRENT_DATE('UTC')", + write={ + "mysql": "CURRENT_DATE AT TIME ZONE 'UTC'", + "postgres": "CURRENT_DATE AT TIME ZONE 'UTC'", + }, + ) + self.validate_all( + "SELECT a FROM test WHERE a = 1 GROUP BY a HAVING a = 2 QUALIFY z ORDER BY a LIMIT 10", + write={ + "bigquery": "SELECT a FROM test WHERE a = 1 GROUP BY a HAVING a = 2 QUALIFY z ORDER BY a LIMIT 10", + "snowflake": "SELECT a FROM test WHERE a = 1 GROUP BY a HAVING a = 2 QUALIFY z ORDER BY a NULLS FIRST LIMIT 10", + }, + ) diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py new file mode 100644 index 0000000..e5b1516 --- /dev/null +++ b/tests/dialects/test_clickhouse.py @@ -0,0 +1,25 @@ +from tests.dialects.test_dialect import Validator + + +class TestClickhouse(Validator): + dialect = "clickhouse" + + def test_clickhouse(self): + self.validate_identity("dictGet(x, 'y')") + self.validate_identity("SELECT * FROM x FINAL") + self.validate_identity("SELECT * FROM x AS y FINAL") + + self.validate_all( + "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", + write={ + "clickhouse": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname", + "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST", + }, + ) + + self.validate_all( + 
"CAST(1 AS NULLABLE(Int64))", + write={ + "clickhouse": "CAST(1 AS Nullable(BIGINT))", + }, + ) diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py new file mode 100644 index 0000000..3993565 --- /dev/null +++ b/tests/dialects/test_dialect.py @@ -0,0 +1,981 @@ +import unittest + +from sqlglot import ( + Dialect, + Dialects, + ErrorLevel, + UnsupportedError, + parse_one, + transpile, +) + + +class Validator(unittest.TestCase): + dialect = None + + def validate(self, sql, target, **kwargs): + self.assertEqual(transpile(sql, **kwargs)[0], target) + + def validate_identity(self, sql): + self.assertEqual(transpile(sql, read=self.dialect, write=self.dialect)[0], sql) + + def validate_all(self, sql, read=None, write=None, pretty=False): + """ + Validate that: + 1. Everything in `read` transpiles to `sql` + 2. `sql` transpiles to everything in `write` + + Args: + sql (str): Main SQL expression + dialect (str): dialect of `sql` + read (dict): Mapping of dialect -> SQL + write (dict): Mapping of dialect -> SQL + """ + expression = parse_one(sql, read=self.dialect) + + for read_dialect, read_sql in (read or {}).items(): + with self.subTest(f"{read_dialect} -> {sql}"): + self.assertEqual( + parse_one(read_sql, read_dialect).sql( + self.dialect, unsupported_level=ErrorLevel.IGNORE + ), + sql, + ) + + for write_dialect, write_sql in (write or {}).items(): + with self.subTest(f"{sql} -> {write_dialect}"): + if write_sql is UnsupportedError: + with self.assertRaises(UnsupportedError): + expression.sql( + write_dialect, unsupported_level=ErrorLevel.RAISE + ) + else: + self.assertEqual( + expression.sql( + write_dialect, + unsupported_level=ErrorLevel.IGNORE, + pretty=pretty, + ), + write_sql, + ) + + +class TestDialect(Validator): + maxDiff = None + + def test_enum(self): + for dialect in Dialects: + self.assertIsNotNone(Dialect[dialect]) + self.assertIsNotNone(Dialect.get(dialect)) + self.assertIsNotNone(Dialect.get_or_raise(dialect)) + self.assertIsNotNone(Dialect[dialect.value]) + + def test_cast(self): + self.validate_all( + "CAST(a AS TEXT)", + write={ + "bigquery": "CAST(a AS STRING)", + "clickhouse": "CAST(a AS TEXT)", + "duckdb": "CAST(a AS TEXT)", + "mysql": "CAST(a AS TEXT)", + "hive": "CAST(a AS STRING)", + "oracle": "CAST(a AS CLOB)", + "postgres": "CAST(a AS TEXT)", + "presto": "CAST(a AS VARCHAR)", + "snowflake": "CAST(a AS TEXT)", + "spark": "CAST(a AS STRING)", + "starrocks": "CAST(a AS STRING)", + }, + ) + self.validate_all( + "CAST(a AS STRING)", + write={ + "bigquery": "CAST(a AS STRING)", + "duckdb": "CAST(a AS TEXT)", + "mysql": "CAST(a AS TEXT)", + "hive": "CAST(a AS STRING)", + "oracle": "CAST(a AS CLOB)", + "postgres": "CAST(a AS TEXT)", + "presto": "CAST(a AS VARCHAR)", + "snowflake": "CAST(a AS TEXT)", + "spark": "CAST(a AS STRING)", + "starrocks": "CAST(a AS STRING)", + }, + ) + self.validate_all( + "CAST(a AS VARCHAR)", + write={ + "bigquery": "CAST(a AS STRING)", + "duckdb": "CAST(a AS TEXT)", + "mysql": "CAST(a AS VARCHAR)", + "hive": "CAST(a AS STRING)", + "oracle": "CAST(a AS VARCHAR2)", + "postgres": "CAST(a AS VARCHAR)", + "presto": "CAST(a AS VARCHAR)", + "snowflake": "CAST(a AS VARCHAR)", + "spark": "CAST(a AS STRING)", + "starrocks": "CAST(a AS VARCHAR)", + }, + ) + self.validate_all( + "CAST(a AS VARCHAR(3))", + write={ + "bigquery": "CAST(a AS STRING(3))", + "duckdb": "CAST(a AS TEXT(3))", + "mysql": "CAST(a AS VARCHAR(3))", + "hive": "CAST(a AS VARCHAR(3))", + "oracle": "CAST(a AS VARCHAR2(3))", + "postgres": "CAST(a AS VARCHAR(3))", + 
"presto": "CAST(a AS VARCHAR(3))", + "snowflake": "CAST(a AS VARCHAR(3))", + "spark": "CAST(a AS VARCHAR(3))", + "starrocks": "CAST(a AS VARCHAR(3))", + }, + ) + self.validate_all( + "CAST(a AS SMALLINT)", + write={ + "bigquery": "CAST(a AS INT64)", + "duckdb": "CAST(a AS SMALLINT)", + "mysql": "CAST(a AS SMALLINT)", + "hive": "CAST(a AS SMALLINT)", + "oracle": "CAST(a AS NUMBER)", + "postgres": "CAST(a AS SMALLINT)", + "presto": "CAST(a AS SMALLINT)", + "snowflake": "CAST(a AS SMALLINT)", + "spark": "CAST(a AS SHORT)", + "sqlite": "CAST(a AS INTEGER)", + "starrocks": "CAST(a AS SMALLINT)", + }, + ) + self.validate_all( + "CAST(a AS DOUBLE)", + write={ + "bigquery": "CAST(a AS FLOAT64)", + "clickhouse": "CAST(a AS DOUBLE)", + "duckdb": "CAST(a AS DOUBLE)", + "mysql": "CAST(a AS DOUBLE)", + "hive": "CAST(a AS DOUBLE)", + "oracle": "CAST(a AS DOUBLE PRECISION)", + "postgres": "CAST(a AS DOUBLE PRECISION)", + "presto": "CAST(a AS DOUBLE)", + "snowflake": "CAST(a AS DOUBLE)", + "spark": "CAST(a AS DOUBLE)", + "starrocks": "CAST(a AS DOUBLE)", + }, + ) + self.validate_all( + "CAST(a AS TIMESTAMP)", write={"starrocks": "CAST(a AS DATETIME)"} + ) + self.validate_all( + "CAST(a AS TIMESTAMPTZ)", write={"starrocks": "CAST(a AS DATETIME)"} + ) + self.validate_all("CAST(a AS TINYINT)", write={"oracle": "CAST(a AS NUMBER)"}) + self.validate_all("CAST(a AS SMALLINT)", write={"oracle": "CAST(a AS NUMBER)"}) + self.validate_all("CAST(a AS BIGINT)", write={"oracle": "CAST(a AS NUMBER)"}) + self.validate_all("CAST(a AS INT)", write={"oracle": "CAST(a AS NUMBER)"}) + self.validate_all( + "CAST(a AS DECIMAL)", + read={"oracle": "CAST(a AS NUMBER)"}, + write={"oracle": "CAST(a AS NUMBER)"}, + ) + + def test_time(self): + self.validate_all( + "STR_TO_TIME(x, '%Y-%m-%dT%H:%M:%S')", + read={ + "duckdb": "STRPTIME(x, '%Y-%m-%dT%H:%M:%S')", + }, + write={ + "mysql": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')", + "duckdb": "STRPTIME(x, '%Y-%m-%dT%H:%M:%S')", + "hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')) AS TIMESTAMP)", + "presto": "DATE_PARSE(x, '%Y-%m-%dT%H:%i:%S')", + "spark": "TO_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')", + }, + ) + self.validate_all( + "STR_TO_TIME('2020-01-01', '%Y-%m-%d')", + write={ + "duckdb": "STRPTIME('2020-01-01', '%Y-%m-%d')", + "hive": "CAST('2020-01-01' AS TIMESTAMP)", + "presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d')", + "spark": "TO_TIMESTAMP('2020-01-01', 'yyyy-MM-dd')", + }, + ) + self.validate_all( + "STR_TO_TIME(x, '%y')", + write={ + "duckdb": "STRPTIME(x, '%y')", + "hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yy')) AS TIMESTAMP)", + "presto": "DATE_PARSE(x, '%y')", + "spark": "TO_TIMESTAMP(x, 'yy')", + }, + ) + self.validate_all( + "STR_TO_UNIX('2020-01-01', '%Y-%M-%d')", + write={ + "duckdb": "EPOCH(STRPTIME('2020-01-01', '%Y-%M-%d'))", + "hive": "UNIX_TIMESTAMP('2020-01-01', 'yyyy-mm-dd')", + "presto": "TO_UNIXTIME(DATE_PARSE('2020-01-01', '%Y-%i-%d'))", + }, + ) + self.validate_all( + "TIME_STR_TO_DATE('2020-01-01')", + write={ + "duckdb": "CAST('2020-01-01' AS DATE)", + "hive": "TO_DATE('2020-01-01')", + "presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d %H:%i:%s')", + }, + ) + self.validate_all( + "TIME_STR_TO_TIME('2020-01-01')", + write={ + "duckdb": "CAST('2020-01-01' AS TIMESTAMP)", + "hive": "CAST('2020-01-01' AS TIMESTAMP)", + "presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d %H:%i:%s')", + }, + ) + self.validate_all( + "TIME_STR_TO_UNIX('2020-01-01')", + write={ + "duckdb": "EPOCH(CAST('2020-01-01' AS TIMESTAMP))", + "hive": "UNIX_TIMESTAMP('2020-01-01')", 
+ "presto": "TO_UNIXTIME(DATE_PARSE('2020-01-01', '%Y-%m-%d %H:%i:%S'))", + }, + ) + self.validate_all( + "TIME_TO_STR(x, '%Y-%m-%d')", + write={ + "duckdb": "STRFTIME(x, '%Y-%m-%d')", + "hive": "DATE_FORMAT(x, 'yyyy-MM-dd')", + "presto": "DATE_FORMAT(x, '%Y-%m-%d')", + }, + ) + self.validate_all( + "TIME_TO_TIME_STR(x)", + write={ + "duckdb": "CAST(x AS TEXT)", + "hive": "CAST(x AS STRING)", + "presto": "CAST(x AS VARCHAR)", + }, + ) + self.validate_all( + "TIME_TO_UNIX(x)", + write={ + "duckdb": "EPOCH(x)", + "hive": "UNIX_TIMESTAMP(x)", + "presto": "TO_UNIXTIME(x)", + }, + ) + self.validate_all( + "TS_OR_DS_TO_DATE_STR(x)", + write={ + "duckdb": "SUBSTRING(CAST(x AS TEXT), 1, 10)", + "hive": "SUBSTRING(CAST(x AS STRING), 1, 10)", + "presto": "SUBSTRING(CAST(x AS VARCHAR), 1, 10)", + }, + ) + self.validate_all( + "TS_OR_DS_TO_DATE(x)", + write={ + "duckdb": "CAST(x AS DATE)", + "hive": "TO_DATE(x)", + "presto": "CAST(SUBSTR(CAST(x AS VARCHAR), 1, 10) AS DATE)", + }, + ) + self.validate_all( + "TS_OR_DS_TO_DATE(x, '%-d')", + write={ + "duckdb": "CAST(STRPTIME(x, '%-d') AS DATE)", + "hive": "TO_DATE(x, 'd')", + "presto": "CAST(DATE_PARSE(x, '%e') AS DATE)", + "spark": "TO_DATE(x, 'd')", + }, + ) + self.validate_all( + "UNIX_TO_STR(x, y)", + write={ + "duckdb": "STRFTIME(TO_TIMESTAMP(CAST(x AS BIGINT)), y)", + "hive": "FROM_UNIXTIME(x, y)", + "presto": "DATE_FORMAT(FROM_UNIXTIME(x), y)", + }, + ) + self.validate_all( + "UNIX_TO_TIME(x)", + write={ + "duckdb": "TO_TIMESTAMP(CAST(x AS BIGINT))", + "hive": "FROM_UNIXTIME(x)", + "presto": "FROM_UNIXTIME(x)", + }, + ) + self.validate_all( + "UNIX_TO_TIME_STR(x)", + write={ + "duckdb": "CAST(TO_TIMESTAMP(CAST(x AS BIGINT)) AS TEXT)", + "hive": "FROM_UNIXTIME(x)", + "presto": "CAST(FROM_UNIXTIME(x) AS VARCHAR)", + }, + ) + self.validate_all( + "DATE_TO_DATE_STR(x)", + write={ + "duckdb": "CAST(x AS TEXT)", + "hive": "CAST(x AS STRING)", + "presto": "CAST(x AS VARCHAR)", + }, + ) + self.validate_all( + "DATE_TO_DI(x)", + write={ + "duckdb": "CAST(STRFTIME(x, '%Y%m%d') AS INT)", + "hive": "CAST(DATE_FORMAT(x, 'yyyyMMdd') AS INT)", + "presto": "CAST(DATE_FORMAT(x, '%Y%m%d') AS INT)", + }, + ) + self.validate_all( + "DI_TO_DATE(x)", + write={ + "duckdb": "CAST(STRPTIME(CAST(x AS TEXT), '%Y%m%d') AS DATE)", + "hive": "TO_DATE(CAST(x AS STRING), 'yyyyMMdd')", + "presto": "CAST(DATE_PARSE(CAST(x AS VARCHAR), '%Y%m%d') AS DATE)", + }, + ) + self.validate_all( + "TS_OR_DI_TO_DI(x)", + write={ + "duckdb": "CAST(SUBSTR(REPLACE(CAST(x AS TEXT), '-', ''), 1, 8) AS INT)", + "hive": "CAST(SUBSTR(REPLACE(CAST(x AS STRING), '-', ''), 1, 8) AS INT)", + "presto": "CAST(SUBSTR(REPLACE(CAST(x AS VARCHAR), '-', ''), 1, 8) AS INT)", + "spark": "CAST(SUBSTR(REPLACE(CAST(x AS STRING), '-', ''), 1, 8) AS INT)", + }, + ) + self.validate_all( + "DATE_ADD(x, 1, 'day')", + read={ + "mysql": "DATE_ADD(x, INTERVAL 1 DAY)", + "starrocks": "DATE_ADD(x, INTERVAL 1 DAY)", + }, + write={ + "bigquery": "DATE_ADD(x, INTERVAL 1 'day')", + "duckdb": "x + INTERVAL 1 day", + "hive": "DATE_ADD(x, 1)", + "mysql": "DATE_ADD(x, INTERVAL 1 DAY)", + "postgres": "x + INTERVAL '1' 'day'", + "presto": "DATE_ADD('day', 1, x)", + "spark": "DATE_ADD(x, 1)", + "starrocks": "DATE_ADD(x, INTERVAL 1 DAY)", + }, + ) + self.validate_all( + "DATE_ADD(x, y, 'day')", + write={ + "postgres": UnsupportedError, + }, + ) + self.validate_all( + "DATE_ADD(x, 1)", + write={ + "bigquery": "DATE_ADD(x, INTERVAL 1 'day')", + "duckdb": "x + INTERVAL 1 DAY", + "hive": "DATE_ADD(x, 1)", + "mysql": "DATE_ADD(x, INTERVAL 1 
DAY)", + "presto": "DATE_ADD('day', 1, x)", + "spark": "DATE_ADD(x, 1)", + "starrocks": "DATE_ADD(x, INTERVAL 1 DAY)", + }, + ) + self.validate_all( + "DATE_TRUNC(x, 'day')", + write={ + "mysql": "DATE(x)", + "starrocks": "DATE(x)", + }, + ) + self.validate_all( + "DATE_TRUNC(x, 'week')", + write={ + "mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', WEEK(x, 1), ' 1'), '%Y %u %w')", + "starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' ', WEEK(x, 1), ' 1'), '%Y %u %w')", + }, + ) + self.validate_all( + "DATE_TRUNC(x, 'month')", + write={ + "mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', MONTH(x), ' 1'), '%Y %c %e')", + "starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' ', MONTH(x), ' 1'), '%Y %c %e')", + }, + ) + self.validate_all( + "DATE_TRUNC(x, 'quarter')", + write={ + "mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' ', QUARTER(x) * 3 - 2, ' 1'), '%Y %c %e')", + "starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' ', QUARTER(x) * 3 - 2, ' 1'), '%Y %c %e')", + }, + ) + self.validate_all( + "DATE_TRUNC(x, 'year')", + write={ + "mysql": "STR_TO_DATE(CONCAT(YEAR(x), ' 1 1'), '%Y %c %e')", + "starrocks": "STR_TO_DATE(CONCAT(YEAR(x), ' 1 1'), '%Y %c %e')", + }, + ) + self.validate_all( + "DATE_TRUNC(x, 'millenium')", + write={ + "mysql": UnsupportedError, + "starrocks": UnsupportedError, + }, + ) + self.validate_all( + "STR_TO_DATE(x, '%Y-%m-%dT%H:%M:%S')", + read={ + "mysql": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')", + "starrocks": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')", + }, + write={ + "mysql": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')", + "starrocks": "STR_TO_DATE(x, '%Y-%m-%dT%H:%i:%S')", + "hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yyyy-MM-ddTHH:mm:ss')) AS DATE)", + "presto": "CAST(DATE_PARSE(x, '%Y-%m-%dT%H:%i:%S') AS DATE)", + "spark": "TO_DATE(x, 'yyyy-MM-ddTHH:mm:ss')", + }, + ) + self.validate_all( + "STR_TO_DATE(x, '%Y-%m-%d')", + write={ + "mysql": "STR_TO_DATE(x, '%Y-%m-%d')", + "starrocks": "STR_TO_DATE(x, '%Y-%m-%d')", + "hive": "CAST(x AS DATE)", + "presto": "CAST(DATE_PARSE(x, '%Y-%m-%d') AS DATE)", + "spark": "TO_DATE(x)", + }, + ) + self.validate_all( + "DATE_STR_TO_DATE(x)", + write={ + "duckdb": "CAST(x AS DATE)", + "hive": "TO_DATE(x)", + "presto": "CAST(DATE_PARSE(x, '%Y-%m-%d') AS DATE)", + "spark": "TO_DATE(x)", + }, + ) + self.validate_all( + "TS_OR_DS_ADD('2021-02-01', 1, 'DAY')", + write={ + "duckdb": "CAST('2021-02-01' AS DATE) + INTERVAL 1 DAY", + "hive": "DATE_ADD('2021-02-01', 1)", + "presto": "DATE_ADD('DAY', 1, DATE_PARSE(SUBSTR('2021-02-01', 1, 10), '%Y-%m-%d'))", + "spark": "DATE_ADD('2021-02-01', 1)", + }, + ) + self.validate_all( + "DATE_ADD(CAST('2020-01-01' AS DATE), 1)", + write={ + "duckdb": "CAST('2020-01-01' AS DATE) + INTERVAL 1 DAY", + "hive": "DATE_ADD(CAST('2020-01-01' AS DATE), 1)", + "presto": "DATE_ADD('day', 1, CAST('2020-01-01' AS DATE))", + "spark": "DATE_ADD(CAST('2020-01-01' AS DATE), 1)", + }, + ) + + for unit in ("DAY", "MONTH", "YEAR"): + self.validate_all( + f"{unit}(x)", + read={ + dialect: f"{unit}(x)" + for dialect in ( + "bigquery", + "duckdb", + "mysql", + "presto", + "starrocks", + ) + }, + write={ + dialect: f"{unit}(x)" + for dialect in ( + "bigquery", + "duckdb", + "mysql", + "presto", + "hive", + "spark", + "starrocks", + ) + }, + ) + + def test_array(self): + self.validate_all( + "ARRAY(0, 1, 2)", + write={ + "bigquery": "[0, 1, 2]", + "duckdb": "LIST_VALUE(0, 1, 2)", + "presto": "ARRAY[0, 1, 2]", + "spark": "ARRAY(0, 1, 2)", + }, + ) + self.validate_all( + "ARRAY_SIZE(x)", + write={ + "bigquery": "ARRAY_LENGTH(x)", + "duckdb": "ARRAY_LENGTH(x)", + "presto": 
"CARDINALITY(x)", + "spark": "SIZE(x)", + }, + ) + self.validate_all( + "ARRAY_SUM(ARRAY(1, 2))", + write={ + "trino": "REDUCE(ARRAY[1, 2], 0, (acc, x) -> acc + x, acc -> acc)", + "duckdb": "LIST_SUM(LIST_VALUE(1, 2))", + "hive": "ARRAY_SUM(ARRAY(1, 2))", + "presto": "ARRAY_SUM(ARRAY[1, 2])", + "spark": "AGGREGATE(ARRAY(1, 2), 0, (acc, x) -> acc + x, acc -> acc)", + }, + ) + self.validate_all( + "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)", + write={ + "trino": "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)", + "duckdb": "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)", + "hive": "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)", + "presto": "REDUCE(x, 0, (acc, x) -> acc + x, acc -> acc)", + "spark": "AGGREGATE(x, 0, (acc, x) -> acc + x, acc -> acc)", + }, + ) + + def test_order_by(self): + self.validate_all( + "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", + write={ + "bigquery": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", + "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", + "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname NULLS FIRST", + "hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", + "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", + }, + ) + + def test_json(self): + self.validate_all( + "JSON_EXTRACT(x, 'y')", + read={ + "postgres": "x->'y'", + "presto": "JSON_EXTRACT(x, 'y')", + }, + write={ + "postgres": "x->'y'", + "presto": "JSON_EXTRACT(x, 'y')", + }, + ) + self.validate_all( + "JSON_EXTRACT_SCALAR(x, 'y')", + read={ + "postgres": "x->>'y'", + "presto": "JSON_EXTRACT_SCALAR(x, 'y')", + }, + write={ + "postgres": "x->>'y'", + "presto": "JSON_EXTRACT_SCALAR(x, 'y')", + }, + ) + self.validate_all( + "JSONB_EXTRACT(x, 'y')", + read={ + "postgres": "x#>'y'", + }, + write={ + "postgres": "x#>'y'", + }, + ) + self.validate_all( + "JSONB_EXTRACT_SCALAR(x, 'y')", + read={ + "postgres": "x#>>'y'", + }, + write={ + "postgres": "x#>>'y'", + }, + ) + + def test_cross_join(self): + self.validate_all( + "SELECT a FROM x CROSS JOIN UNNEST(y) AS t (a)", + write={ + "presto": "SELECT a FROM x CROSS JOIN UNNEST(y) AS t(a)", + "spark": "SELECT a FROM x LATERAL VIEW EXPLODE(y) t AS a", + }, + ) + self.validate_all( + "SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t (a, b)", + write={ + "presto": "SELECT a, b FROM x CROSS JOIN UNNEST(y, z) AS t(a, b)", + "spark": "SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) t AS b", + }, + ) + self.validate_all( + "SELECT a FROM x CROSS JOIN UNNEST(y) WITH ORDINALITY AS t (a)", + write={ + "presto": "SELECT a FROM x CROSS JOIN UNNEST(y) WITH ORDINALITY AS t(a)", + "spark": "SELECT a FROM x LATERAL VIEW POSEXPLODE(y) t AS a", + }, + ) + + def test_set_operators(self): + self.validate_all( + "SELECT * FROM a UNION SELECT * FROM b", + read={ + "bigquery": "SELECT * FROM a UNION DISTINCT SELECT * FROM b", + "duckdb": "SELECT * FROM a UNION SELECT * FROM b", + "presto": "SELECT * FROM a UNION SELECT * FROM b", + "spark": "SELECT * FROM a UNION SELECT * FROM b", + }, + write={ + "bigquery": "SELECT * FROM a UNION DISTINCT SELECT * FROM b", + "duckdb": "SELECT * FROM a UNION SELECT * FROM b", + "presto": "SELECT * FROM a UNION SELECT * FROM b", + "spark": "SELECT * FROM a UNION SELECT * FROM b", + }, + ) + self.validate_all( + "SELECT * 
FROM a UNION ALL SELECT * FROM b", + read={ + "bigquery": "SELECT * FROM a UNION ALL SELECT * FROM b", + "duckdb": "SELECT * FROM a UNION ALL SELECT * FROM b", + "presto": "SELECT * FROM a UNION ALL SELECT * FROM b", + "spark": "SELECT * FROM a UNION ALL SELECT * FROM b", + }, + write={ + "bigquery": "SELECT * FROM a UNION ALL SELECT * FROM b", + "duckdb": "SELECT * FROM a UNION ALL SELECT * FROM b", + "presto": "SELECT * FROM a UNION ALL SELECT * FROM b", + "spark": "SELECT * FROM a UNION ALL SELECT * FROM b", + }, + ) + self.validate_all( + "SELECT * FROM a INTERSECT SELECT * FROM b", + read={ + "bigquery": "SELECT * FROM a INTERSECT DISTINCT SELECT * FROM b", + "duckdb": "SELECT * FROM a INTERSECT SELECT * FROM b", + "presto": "SELECT * FROM a INTERSECT SELECT * FROM b", + "spark": "SELECT * FROM a INTERSECT SELECT * FROM b", + }, + write={ + "bigquery": "SELECT * FROM a INTERSECT DISTINCT SELECT * FROM b", + "duckdb": "SELECT * FROM a INTERSECT SELECT * FROM b", + "presto": "SELECT * FROM a INTERSECT SELECT * FROM b", + "spark": "SELECT * FROM a INTERSECT SELECT * FROM b", + }, + ) + self.validate_all( + "SELECT * FROM a EXCEPT SELECT * FROM b", + read={ + "bigquery": "SELECT * FROM a EXCEPT DISTINCT SELECT * FROM b", + "duckdb": "SELECT * FROM a EXCEPT SELECT * FROM b", + "presto": "SELECT * FROM a EXCEPT SELECT * FROM b", + "spark": "SELECT * FROM a EXCEPT SELECT * FROM b", + }, + write={ + "bigquery": "SELECT * FROM a EXCEPT DISTINCT SELECT * FROM b", + "duckdb": "SELECT * FROM a EXCEPT SELECT * FROM b", + "presto": "SELECT * FROM a EXCEPT SELECT * FROM b", + "spark": "SELECT * FROM a EXCEPT SELECT * FROM b", + }, + ) + self.validate_all( + "SELECT * FROM a UNION DISTINCT SELECT * FROM b", + write={ + "bigquery": "SELECT * FROM a UNION DISTINCT SELECT * FROM b", + "duckdb": "SELECT * FROM a UNION SELECT * FROM b", + "presto": "SELECT * FROM a UNION SELECT * FROM b", + "spark": "SELECT * FROM a UNION SELECT * FROM b", + }, + ) + self.validate_all( + "SELECT * FROM a INTERSECT DISTINCT SELECT * FROM b", + write={ + "bigquery": "SELECT * FROM a INTERSECT DISTINCT SELECT * FROM b", + "duckdb": "SELECT * FROM a INTERSECT SELECT * FROM b", + "presto": "SELECT * FROM a INTERSECT SELECT * FROM b", + "spark": "SELECT * FROM a INTERSECT SELECT * FROM b", + }, + ) + self.validate_all( + "SELECT * FROM a INTERSECT ALL SELECT * FROM b", + write={ + "bigquery": "SELECT * FROM a INTERSECT ALL SELECT * FROM b", + "duckdb": "SELECT * FROM a INTERSECT ALL SELECT * FROM b", + "presto": "SELECT * FROM a INTERSECT ALL SELECT * FROM b", + "spark": "SELECT * FROM a INTERSECT ALL SELECT * FROM b", + }, + ) + self.validate_all( + "SELECT * FROM a EXCEPT DISTINCT SELECT * FROM b", + write={ + "bigquery": "SELECT * FROM a EXCEPT DISTINCT SELECT * FROM b", + "duckdb": "SELECT * FROM a EXCEPT SELECT * FROM b", + "presto": "SELECT * FROM a EXCEPT SELECT * FROM b", + "spark": "SELECT * FROM a EXCEPT SELECT * FROM b", + }, + ) + self.validate_all( + "SELECT * FROM a EXCEPT ALL SELECT * FROM b", + read={ + "bigquery": "SELECT * FROM a EXCEPT ALL SELECT * FROM b", + "duckdb": "SELECT * FROM a EXCEPT ALL SELECT * FROM b", + "presto": "SELECT * FROM a EXCEPT ALL SELECT * FROM b", + "spark": "SELECT * FROM a EXCEPT ALL SELECT * FROM b", + }, + ) + + def test_operators(self): + self.validate_all( + "x ILIKE '%y'", + read={ + "clickhouse": "x ILIKE '%y'", + "duckdb": "x ILIKE '%y'", + "postgres": "x ILIKE '%y'", + "snowflake": "x ILIKE '%y'", + }, + write={ + "bigquery": "LOWER(x) LIKE '%y'", + "clickhouse": "x ILIKE 
'%y'", + "duckdb": "x ILIKE '%y'", + "hive": "LOWER(x) LIKE '%y'", + "mysql": "LOWER(x) LIKE '%y'", + "oracle": "LOWER(x) LIKE '%y'", + "postgres": "x ILIKE '%y'", + "presto": "LOWER(x) LIKE '%y'", + "snowflake": "x ILIKE '%y'", + "spark": "LOWER(x) LIKE '%y'", + "sqlite": "LOWER(x) LIKE '%y'", + "starrocks": "LOWER(x) LIKE '%y'", + "trino": "LOWER(x) LIKE '%y'", + }, + ) + self.validate_all( + "SELECT * FROM a ORDER BY col_a NULLS LAST", + write={ + "mysql": UnsupportedError, + "starrocks": UnsupportedError, + }, + ) + self.validate_all( + "STR_POSITION(x, 'a')", + write={ + "duckdb": "STRPOS(x, 'a')", + "presto": "STRPOS(x, 'a')", + "spark": "LOCATE('a', x)", + }, + ) + self.validate_all( + "CONCAT_WS('-', 'a', 'b')", + write={ + "duckdb": "CONCAT_WS('-', 'a', 'b')", + "presto": "ARRAY_JOIN(ARRAY['a', 'b'], '-')", + "hive": "CONCAT_WS('-', 'a', 'b')", + "spark": "CONCAT_WS('-', 'a', 'b')", + }, + ) + + self.validate_all( + "CONCAT_WS('-', x)", + write={ + "duckdb": "CONCAT_WS('-', x)", + "presto": "ARRAY_JOIN(x, '-')", + "hive": "CONCAT_WS('-', x)", + "spark": "CONCAT_WS('-', x)", + }, + ) + self.validate_all( + "IF(x > 1, 1, 0)", + write={ + "duckdb": "CASE WHEN x > 1 THEN 1 ELSE 0 END", + "presto": "IF(x > 1, 1, 0)", + "hive": "IF(x > 1, 1, 0)", + "spark": "IF(x > 1, 1, 0)", + "tableau": "IF x > 1 THEN 1 ELSE 0 END", + }, + ) + self.validate_all( + "CASE WHEN 1 THEN x ELSE 0 END", + write={ + "duckdb": "CASE WHEN 1 THEN x ELSE 0 END", + "presto": "CASE WHEN 1 THEN x ELSE 0 END", + "hive": "CASE WHEN 1 THEN x ELSE 0 END", + "spark": "CASE WHEN 1 THEN x ELSE 0 END", + "tableau": "CASE WHEN 1 THEN x ELSE 0 END", + }, + ) + self.validate_all( + "x[y]", + write={ + "duckdb": "x[y]", + "presto": "x[y]", + "hive": "x[y]", + "spark": "x[y]", + }, + ) + self.validate_all( + """'["x"]'""", + write={ + "duckdb": """'["x"]'""", + "presto": """'["x"]'""", + "hive": """'["x"]'""", + "spark": """'["x"]'""", + }, + ) + + self.validate_all( + 'true or null as "foo"', + write={ + "bigquery": "TRUE OR NULL AS `foo`", + "duckdb": 'TRUE OR NULL AS "foo"', + "presto": 'TRUE OR NULL AS "foo"', + "hive": "TRUE OR NULL AS `foo`", + "spark": "TRUE OR NULL AS `foo`", + }, + ) + self.validate_all( + "SELECT IF(COALESCE(bar, 0) = 1, TRUE, FALSE) as foo FROM baz", + write={ + "bigquery": "SELECT CASE WHEN COALESCE(bar, 0) = 1 THEN TRUE ELSE FALSE END AS foo FROM baz", + "duckdb": "SELECT CASE WHEN COALESCE(bar, 0) = 1 THEN TRUE ELSE FALSE END AS foo FROM baz", + "presto": "SELECT IF(COALESCE(bar, 0) = 1, TRUE, FALSE) AS foo FROM baz", + "hive": "SELECT IF(COALESCE(bar, 0) = 1, TRUE, FALSE) AS foo FROM baz", + "spark": "SELECT IF(COALESCE(bar, 0) = 1, TRUE, FALSE) AS foo FROM baz", + }, + ) + self.validate_all( + "LEVENSHTEIN(col1, col2)", + write={ + "duckdb": "LEVENSHTEIN(col1, col2)", + "presto": "LEVENSHTEIN_DISTANCE(col1, col2)", + "hive": "LEVENSHTEIN(col1, col2)", + "spark": "LEVENSHTEIN(col1, col2)", + }, + ) + self.validate_all( + "LEVENSHTEIN(coalesce(col1, col2), coalesce(col2, col1))", + write={ + "duckdb": "LEVENSHTEIN(COALESCE(col1, col2), COALESCE(col2, col1))", + "presto": "LEVENSHTEIN_DISTANCE(COALESCE(col1, col2), COALESCE(col2, col1))", + "hive": "LEVENSHTEIN(COALESCE(col1, col2), COALESCE(col2, col1))", + "spark": "LEVENSHTEIN(COALESCE(col1, col2), COALESCE(col2, col1))", + }, + ) + self.validate_all( + "ARRAY_FILTER(the_array, x -> x > 0)", + write={ + "presto": "FILTER(the_array, x -> x > 0)", + "hive": "FILTER(the_array, x -> x > 0)", + "spark": "FILTER(the_array, x -> x > 0)", + }, + ) + 
self.validate_all( + "SELECT a AS b FROM x GROUP BY b", + write={ + "duckdb": "SELECT a AS b FROM x GROUP BY b", + "presto": "SELECT a AS b FROM x GROUP BY 1", + "hive": "SELECT a AS b FROM x GROUP BY 1", + "oracle": "SELECT a AS b FROM x GROUP BY 1", + "spark": "SELECT a AS b FROM x GROUP BY 1", + }, + ) + self.validate_all( + "SELECT x FROM y LIMIT 10", + write={ + "sqlite": "SELECT x FROM y LIMIT 10", + "oracle": "SELECT x FROM y FETCH FIRST 10 ROWS ONLY", + }, + ) + self.validate_all( + "SELECT x FROM y LIMIT 10 OFFSET 5", + write={ + "sqlite": "SELECT x FROM y LIMIT 10 OFFSET 5", + "oracle": "SELECT x FROM y OFFSET 5 ROWS FETCH FIRST 10 ROWS ONLY", + }, + ) + self.validate_all( + "SELECT x FROM y OFFSET 10 FETCH FIRST 3 ROWS ONLY", + write={ + "oracle": "SELECT x FROM y OFFSET 10 ROWS FETCH FIRST 3 ROWS ONLY", + }, + ) + self.validate_all( + "SELECT x FROM y OFFSET 10 ROWS FETCH FIRST 3 ROWS ONLY", + write={ + "oracle": "SELECT x FROM y OFFSET 10 ROWS FETCH FIRST 3 ROWS ONLY", + }, + ) + self.validate_all( + '"x" + "y"', + read={ + "clickhouse": '`x` + "y"', + "sqlite": '`x` + "y"', + }, + ) + self.validate_all( + "[1, 2]", + write={ + "bigquery": "[1, 2]", + "clickhouse": "[1, 2]", + }, + ) + self.validate_all( + "SELECT * FROM VALUES ('x'), ('y') AS t(z)", + write={ + "spark": "SELECT * FROM (VALUES ('x'), ('y')) AS t(z)", + }, + ) + self.validate_all( + "CREATE TABLE t (c CHAR, nc NCHAR, v1 VARCHAR, v2 VARCHAR2, nv NVARCHAR, nv2 NVARCHAR2)", + write={ + "hive": "CREATE TABLE t (c CHAR, nc CHAR, v1 STRING, v2 STRING, nv STRING, nv2 STRING)", + "oracle": "CREATE TABLE t (c CHAR, nc CHAR, v1 VARCHAR2, v2 VARCHAR2, nv NVARCHAR2, nv2 NVARCHAR2)", + "postgres": "CREATE TABLE t (c CHAR, nc CHAR, v1 VARCHAR, v2 VARCHAR, nv VARCHAR, nv2 VARCHAR)", + "sqlite": "CREATE TABLE t (c TEXT, nc TEXT, v1 TEXT, v2 TEXT, nv TEXT, nv2 TEXT)", + }, + ) + self.validate_all( + "POWER(1.2, 3.4)", + read={ + "hive": "pow(1.2, 3.4)", + "postgres": "power(1.2, 3.4)", + }, + ) + self.validate_all( + "CREATE INDEX my_idx ON tbl (a, b)", + read={ + "hive": "CREATE INDEX my_idx ON TABLE tbl (a, b)", + "sqlite": "CREATE INDEX my_idx ON tbl (a, b)", + }, + write={ + "hive": "CREATE INDEX my_idx ON TABLE tbl (a, b)", + "postgres": "CREATE INDEX my_idx ON tbl (a, b)", + "sqlite": "CREATE INDEX my_idx ON tbl (a, b)", + }, + ) + self.validate_all( + "CREATE UNIQUE INDEX my_idx ON tbl (a, b)", + read={ + "hive": "CREATE UNIQUE INDEX my_idx ON TABLE tbl (a, b)", + "sqlite": "CREATE UNIQUE INDEX my_idx ON tbl (a, b)", + }, + write={ + "hive": "CREATE UNIQUE INDEX my_idx ON TABLE tbl (a, b)", + "postgres": "CREATE UNIQUE INDEX my_idx ON tbl (a, b)", + "sqlite": "CREATE UNIQUE INDEX my_idx ON tbl (a, b)", + }, + ) + self.validate_all( + "CREATE TABLE t (b1 BINARY, b2 BINARY(1024), c1 TEXT, c2 TEXT(1024))", + write={ + "hive": "CREATE TABLE t (b1 BINARY, b2 BINARY(1024), c1 STRING, c2 STRING(1024))", + "oracle": "CREATE TABLE t (b1 BLOB, b2 BLOB(1024), c1 CLOB, c2 CLOB(1024))", + "postgres": "CREATE TABLE t (b1 BYTEA, b2 BYTEA(1024), c1 TEXT, c2 TEXT(1024))", + "sqlite": "CREATE TABLE t (b1 BLOB, b2 BLOB(1024), c1 TEXT, c2 TEXT(1024))", + }, + ) diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py new file mode 100644 index 0000000..501301f --- /dev/null +++ b/tests/dialects/test_duckdb.py @@ -0,0 +1,249 @@ +from tests.dialects.test_dialect import Validator + + +class TestDuckDB(Validator): + dialect = "duckdb" + + def test_time(self): + self.validate_all( + "EPOCH(x)", + read={ + "presto": 
"TO_UNIXTIME(x)", + }, + write={ + "bigquery": "TIME_TO_UNIX(x)", + "duckdb": "EPOCH(x)", + "presto": "TO_UNIXTIME(x)", + "spark": "UNIX_TIMESTAMP(x)", + }, + ) + self.validate_all( + "EPOCH_MS(x)", + write={ + "bigquery": "UNIX_TO_TIME(x / 1000)", + "duckdb": "TO_TIMESTAMP(CAST(x / 1000 AS BIGINT))", + "presto": "FROM_UNIXTIME(x / 1000)", + "spark": "FROM_UNIXTIME(x / 1000)", + }, + ) + self.validate_all( + "STRFTIME(x, '%y-%-m-%S')", + write={ + "bigquery": "TIME_TO_STR(x, '%y-%-m-%S')", + "duckdb": "STRFTIME(x, '%y-%-m-%S')", + "postgres": "TO_CHAR(x, 'YY-FMMM-SS')", + "presto": "DATE_FORMAT(x, '%y-%c-%S')", + "spark": "DATE_FORMAT(x, 'yy-M-ss')", + }, + ) + self.validate_all( + "STRFTIME(x, '%Y-%m-%d %H:%M:%S')", + write={ + "duckdb": "STRFTIME(x, '%Y-%m-%d %H:%M:%S')", + "presto": "DATE_FORMAT(x, '%Y-%m-%d %H:%i:%S')", + "hive": "DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss')", + }, + ) + self.validate_all( + "STRPTIME(x, '%y-%-m')", + write={ + "bigquery": "STR_TO_TIME(x, '%y-%-m')", + "duckdb": "STRPTIME(x, '%y-%-m')", + "presto": "DATE_PARSE(x, '%y-%c')", + "hive": "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'yy-M')) AS TIMESTAMP)", + "spark": "TO_TIMESTAMP(x, 'yy-M')", + }, + ) + self.validate_all( + "TO_TIMESTAMP(x)", + write={ + "duckdb": "CAST(x AS TIMESTAMP)", + "presto": "DATE_PARSE(x, '%Y-%m-%d %H:%i:%s')", + "hive": "CAST(x AS TIMESTAMP)", + }, + ) + + def test_duckdb(self): + self.validate_all( + "LIST_VALUE(0, 1, 2)", + write={ + "bigquery": "[0, 1, 2]", + "duckdb": "LIST_VALUE(0, 1, 2)", + "presto": "ARRAY[0, 1, 2]", + "spark": "ARRAY(0, 1, 2)", + }, + ) + self.validate_all( + "REGEXP_MATCHES(x, y)", + write={ + "duckdb": "REGEXP_MATCHES(x, y)", + "presto": "REGEXP_LIKE(x, y)", + "hive": "x RLIKE y", + "spark": "x RLIKE y", + }, + ) + self.validate_all( + "STR_SPLIT(x, 'a')", + write={ + "duckdb": "STR_SPLIT(x, 'a')", + "presto": "SPLIT(x, 'a')", + "hive": "SPLIT(x, CONCAT('\\\\Q', 'a'))", + "spark": "SPLIT(x, CONCAT('\\\\Q', 'a'))", + }, + ) + self.validate_all( + "STRING_TO_ARRAY(x, 'a')", + write={ + "duckdb": "STR_SPLIT(x, 'a')", + "presto": "SPLIT(x, 'a')", + "hive": "SPLIT(x, CONCAT('\\\\Q', 'a'))", + "spark": "SPLIT(x, CONCAT('\\\\Q', 'a'))", + }, + ) + self.validate_all( + "STR_SPLIT_REGEX(x, 'a')", + write={ + "duckdb": "STR_SPLIT_REGEX(x, 'a')", + "presto": "REGEXP_SPLIT(x, 'a')", + "hive": "SPLIT(x, 'a')", + "spark": "SPLIT(x, 'a')", + }, + ) + self.validate_all( + "STRUCT_EXTRACT(x, 'abc')", + write={ + "duckdb": "STRUCT_EXTRACT(x, 'abc')", + "presto": 'x."abc"', + "hive": "x.`abc`", + "spark": "x.`abc`", + }, + ) + self.validate_all( + "STRUCT_EXTRACT(STRUCT_EXTRACT(x, 'y'), 'abc')", + write={ + "duckdb": "STRUCT_EXTRACT(STRUCT_EXTRACT(x, 'y'), 'abc')", + "presto": 'x."y"."abc"', + "hive": "x.`y`.`abc`", + "spark": "x.`y`.`abc`", + }, + ) + + self.validate_all( + "QUANTILE(x, 0.5)", + write={ + "duckdb": "QUANTILE(x, 0.5)", + "presto": "APPROX_PERCENTILE(x, 0.5)", + "hive": "PERCENTILE(x, 0.5)", + "spark": "PERCENTILE(x, 0.5)", + }, + ) + + self.validate_all( + "CAST(x AS DATE)", + write={ + "duckdb": "CAST(x AS DATE)", + "": "CAST(x AS DATE)", + }, + ) + self.validate_all( + "UNNEST(x)", + read={ + "spark": "EXPLODE(x)", + }, + write={ + "duckdb": "UNNEST(x)", + "spark": "EXPLODE(x)", + }, + ) + + self.validate_all( + "1d", + write={ + "duckdb": "1 AS d", + "spark": "1 AS d", + }, + ) + self.validate_all( + "CAST(1 AS DOUBLE)", + read={ + "hive": "1d", + "spark": "1d", + }, + ) + self.validate_all( + "POWER(CAST(2 AS SMALLINT), 3)", + read={ + "hive": "POW(2S, 3)", + 
"spark": "POW(2S, 3)", + }, + ) + self.validate_all( + "LIST_SUM(LIST_VALUE(1, 2))", + read={ + "spark": "ARRAY_SUM(ARRAY(1, 2))", + }, + ) + self.validate_all( + "IF(y <> 0, x / y, NULL)", + read={ + "bigquery": "SAFE_DIVIDE(x, y)", + }, + ) + self.validate_all( + "STRUCT_PACK(x := 1, y := '2')", + write={ + "duckdb": "STRUCT_PACK(x := 1, y := '2')", + "spark": "STRUCT(x = 1, y = '2')", + }, + ) + self.validate_all( + "ARRAY_SORT(x)", + write={ + "duckdb": "ARRAY_SORT(x)", + "presto": "ARRAY_SORT(x)", + "hive": "SORT_ARRAY(x)", + "spark": "SORT_ARRAY(x)", + }, + ) + self.validate_all( + "ARRAY_REVERSE_SORT(x)", + write={ + "duckdb": "ARRAY_REVERSE_SORT(x)", + "presto": "ARRAY_SORT(x, (a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END)", + "hive": "SORT_ARRAY(x, FALSE)", + "spark": "SORT_ARRAY(x, FALSE)", + }, + ) + self.validate_all( + "LIST_REVERSE_SORT(x)", + write={ + "duckdb": "ARRAY_REVERSE_SORT(x)", + "presto": "ARRAY_SORT(x, (a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END)", + "hive": "SORT_ARRAY(x, FALSE)", + "spark": "SORT_ARRAY(x, FALSE)", + }, + ) + self.validate_all( + "LIST_SORT(x)", + write={ + "duckdb": "ARRAY_SORT(x)", + "presto": "ARRAY_SORT(x)", + "hive": "SORT_ARRAY(x)", + "spark": "SORT_ARRAY(x)", + }, + ) + self.validate_all( + "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", + write={ + "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", + }, + ) + self.validate_all( + "MONTH('2021-03-01')", + write={ + "duckdb": "MONTH('2021-03-01')", + "presto": "MONTH('2021-03-01')", + "hive": "MONTH('2021-03-01')", + "spark": "MONTH('2021-03-01')", + }, + ) diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py new file mode 100644 index 0000000..eccd75a --- /dev/null +++ b/tests/dialects/test_hive.py @@ -0,0 +1,541 @@ +from tests.dialects.test_dialect import Validator + + +class TestHive(Validator): + dialect = "hive" + + def test_bits(self): + self.validate_all( + "x & 1", + write={ + "duckdb": "x & 1", + "presto": "BITWISE_AND(x, 1)", + "hive": "x & 1", + "spark": "x & 1", + }, + ) + self.validate_all( + "~x", + write={ + "duckdb": "~x", + "presto": "BITWISE_NOT(x)", + "hive": "~x", + "spark": "~x", + }, + ) + self.validate_all( + "x | 1", + write={ + "duckdb": "x | 1", + "presto": "BITWISE_OR(x, 1)", + "hive": "x | 1", + "spark": "x | 1", + }, + ) + self.validate_all( + "x << 1", + read={ + "spark": "SHIFTLEFT(x, 1)", + }, + write={ + "duckdb": "x << 1", + "presto": "BITWISE_ARITHMETIC_SHIFT_LEFT(x, 1)", + "hive": "x << 1", + "spark": "SHIFTLEFT(x, 1)", + }, + ) + self.validate_all( + "x >> 1", + read={ + "spark": "SHIFTRIGHT(x, 1)", + }, + write={ + "duckdb": "x >> 1", + "presto": "BITWISE_ARITHMETIC_SHIFT_RIGHT(x, 1)", + "hive": "x >> 1", + "spark": "SHIFTRIGHT(x, 1)", + }, + ) + self.validate_all( + "x & 1 > 0", + write={ + "duckdb": "x & 1 > 0", + "presto": "BITWISE_AND(x, 1) > 0", + "hive": "x & 1 > 0", + "spark": "x & 1 > 0", + }, + ) + + def test_cast(self): + self.validate_all( + "1s", + write={ + "duckdb": "CAST(1 AS SMALLINT)", + "presto": "CAST(1 AS SMALLINT)", + "hive": "CAST(1 AS SMALLINT)", + "spark": "CAST(1 AS SHORT)", + }, + ) + self.validate_all( + "1S", + write={ + "duckdb": "CAST(1 AS SMALLINT)", + "presto": "CAST(1 AS SMALLINT)", + "hive": "CAST(1 AS SMALLINT)", + "spark": "CAST(1 AS SHORT)", + }, + ) + self.validate_all( + "1Y", + write={ + "duckdb": "CAST(1 AS TINYINT)", + "presto": "CAST(1 AS TINYINT)", + 
"hive": "CAST(1 AS TINYINT)", + "spark": "CAST(1 AS BYTE)", + }, + ) + self.validate_all( + "1L", + write={ + "duckdb": "CAST(1 AS BIGINT)", + "presto": "CAST(1 AS BIGINT)", + "hive": "CAST(1 AS BIGINT)", + "spark": "CAST(1 AS LONG)", + }, + ) + self.validate_all( + "1.0bd", + write={ + "duckdb": "CAST(1.0 AS DECIMAL)", + "presto": "CAST(1.0 AS DECIMAL)", + "hive": "CAST(1.0 AS DECIMAL)", + "spark": "CAST(1.0 AS DECIMAL)", + }, + ) + self.validate_all( + "CAST(1 AS INT)", + read={ + "presto": "TRY_CAST(1 AS INT)", + }, + write={ + "duckdb": "TRY_CAST(1 AS INT)", + "presto": "TRY_CAST(1 AS INTEGER)", + "hive": "CAST(1 AS INT)", + "spark": "CAST(1 AS INT)", + }, + ) + + def test_ddl(self): + self.validate_all( + "CREATE TABLE test STORED AS parquet TBLPROPERTIES ('x' = '1', 'Z' = '2') AS SELECT 1", + write={ + "presto": "CREATE TABLE test WITH (FORMAT = 'parquet', x = '1', Z = '2') AS SELECT 1", + "hive": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('x' = '1', 'Z' = '2') AS SELECT 1", + "spark": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('x' = '1', 'Z' = '2') AS SELECT 1", + }, + ) + self.validate_all( + "CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)", + write={ + "presto": "CREATE TABLE x (w VARCHAR, y INTEGER, z INTEGER) WITH (PARTITIONED_BY = ARRAY['y', 'z'])", + "hive": "CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)", + "spark": "CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)", + }, + ) + + def test_lateral_view(self): + self.validate_all( + "SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) u AS b", + write={ + "presto": "SELECT a, b FROM x CROSS JOIN UNNEST(y) AS t(a) CROSS JOIN UNNEST(z) AS u(b)", + "hive": "SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) u AS b", + "spark": "SELECT a, b FROM x LATERAL VIEW EXPLODE(y) t AS a LATERAL VIEW EXPLODE(z) u AS b", + }, + ) + self.validate_all( + "SELECT a FROM x LATERAL VIEW EXPLODE(y) t AS a", + write={ + "presto": "SELECT a FROM x CROSS JOIN UNNEST(y) AS t(a)", + "hive": "SELECT a FROM x LATERAL VIEW EXPLODE(y) t AS a", + "spark": "SELECT a FROM x LATERAL VIEW EXPLODE(y) t AS a", + }, + ) + self.validate_all( + "SELECT a FROM x LATERAL VIEW POSEXPLODE(y) t AS a", + write={ + "presto": "SELECT a FROM x CROSS JOIN UNNEST(y) WITH ORDINALITY AS t(a)", + "hive": "SELECT a FROM x LATERAL VIEW POSEXPLODE(y) t AS a", + "spark": "SELECT a FROM x LATERAL VIEW POSEXPLODE(y) t AS a", + }, + ) + self.validate_all( + "SELECT a FROM x LATERAL VIEW EXPLODE(ARRAY(y)) t AS a", + write={ + "presto": "SELECT a FROM x CROSS JOIN UNNEST(ARRAY[y]) AS t(a)", + "hive": "SELECT a FROM x LATERAL VIEW EXPLODE(ARRAY(y)) t AS a", + "spark": "SELECT a FROM x LATERAL VIEW EXPLODE(ARRAY(y)) t AS a", + }, + ) + + def test_quotes(self): + self.validate_all( + "'\\''", + write={ + "duckdb": "''''", + "presto": "''''", + "hive": "'\\''", + "spark": "'\\''", + }, + ) + self.validate_all( + "'\"x\"'", + write={ + "duckdb": "'\"x\"'", + "presto": "'\"x\"'", + "hive": "'\"x\"'", + "spark": "'\"x\"'", + }, + ) + self.validate_all( + "\"'x'\"", + write={ + "duckdb": "'''x'''", + "presto": "'''x'''", + "hive": "'\\'x\\''", + "spark": "'\\'x\\''", + }, + ) + self.validate_all( + "'\\\\a'", + read={ + "presto": "'\\a'", + }, + write={ + "duckdb": "'\\a'", + "presto": "'\\a'", + "hive": "'\\\\a'", + "spark": "'\\\\a'", + }, + ) + + def test_regex(self): + self.validate_all( + "a RLIKE 'x'", + write={ + "duckdb": "REGEXP_MATCHES(a, 'x')", + "presto": "REGEXP_LIKE(a, 'x')", + "hive": "a RLIKE 
'x'", + "spark": "a RLIKE 'x'", + }, + ) + + self.validate_all( + "a REGEXP 'x'", + write={ + "duckdb": "REGEXP_MATCHES(a, 'x')", + "presto": "REGEXP_LIKE(a, 'x')", + "hive": "a RLIKE 'x'", + "spark": "a RLIKE 'x'", + }, + ) + + def test_time(self): + self.validate_all( + "DATEDIFF(a, b)", + write={ + "duckdb": "DATE_DIFF('day', CAST(b AS DATE), CAST(a AS DATE))", + "presto": "DATE_DIFF('day', CAST(SUBSTR(CAST(b AS VARCHAR), 1, 10) AS DATE), CAST(SUBSTR(CAST(a AS VARCHAR), 1, 10) AS DATE))", + "hive": "DATEDIFF(TO_DATE(a), TO_DATE(b))", + "spark": "DATEDIFF(TO_DATE(a), TO_DATE(b))", + "": "DATE_DIFF(TS_OR_DS_TO_DATE(a), TS_OR_DS_TO_DATE(b))", + }, + ) + self.validate_all( + """from_unixtime(x, "yyyy-MM-dd'T'HH")""", + write={ + "duckdb": "STRFTIME(TO_TIMESTAMP(CAST(x AS BIGINT)), '%Y-%m-%d''T''%H')", + "presto": "DATE_FORMAT(FROM_UNIXTIME(x), '%Y-%m-%d''T''%H')", + "hive": "FROM_UNIXTIME(x, 'yyyy-MM-dd\\'T\\'HH')", + "spark": "FROM_UNIXTIME(x, 'yyyy-MM-dd\\'T\\'HH')", + }, + ) + self.validate_all( + "DATE_FORMAT('2020-01-01', 'yyyy-MM-dd HH:mm:ss')", + write={ + "duckdb": "STRFTIME('2020-01-01', '%Y-%m-%d %H:%M:%S')", + "presto": "DATE_FORMAT('2020-01-01', '%Y-%m-%d %H:%i:%S')", + "hive": "DATE_FORMAT('2020-01-01', 'yyyy-MM-dd HH:mm:ss')", + "spark": "DATE_FORMAT('2020-01-01', 'yyyy-MM-dd HH:mm:ss')", + }, + ) + self.validate_all( + "DATE_ADD('2020-01-01', 1)", + write={ + "duckdb": "CAST('2020-01-01' AS DATE) + INTERVAL 1 DAY", + "presto": "DATE_ADD('DAY', 1, DATE_PARSE(SUBSTR('2020-01-01', 1, 10), '%Y-%m-%d'))", + "hive": "DATE_ADD('2020-01-01', 1)", + "spark": "DATE_ADD('2020-01-01', 1)", + "": "TS_OR_DS_ADD('2020-01-01', 1, 'DAY')", + }, + ) + self.validate_all( + "DATE_SUB('2020-01-01', 1)", + write={ + "duckdb": "CAST('2020-01-01' AS DATE) + INTERVAL 1 * -1 DAY", + "presto": "DATE_ADD('DAY', 1 * -1, DATE_PARSE(SUBSTR('2020-01-01', 1, 10), '%Y-%m-%d'))", + "hive": "DATE_ADD('2020-01-01', 1 * -1)", + "spark": "DATE_ADD('2020-01-01', 1 * -1)", + "": "TS_OR_DS_ADD('2020-01-01', 1 * -1, 'DAY')", + }, + ) + self.validate_all( + "DATEDIFF(TO_DATE(y), x)", + write={ + "duckdb": "DATE_DIFF('day', CAST(x AS DATE), CAST(CAST(y AS DATE) AS DATE))", + "presto": "DATE_DIFF('day', CAST(SUBSTR(CAST(x AS VARCHAR), 1, 10) AS DATE), CAST(SUBSTR(CAST(CAST(SUBSTR(CAST(y AS VARCHAR), 1, 10) AS DATE) AS VARCHAR), 1, 10) AS DATE))", + "hive": "DATEDIFF(TO_DATE(TO_DATE(y)), TO_DATE(x))", + "spark": "DATEDIFF(TO_DATE(TO_DATE(y)), TO_DATE(x))", + "": "DATE_DIFF(TS_OR_DS_TO_DATE(TS_OR_DS_TO_DATE(y)), TS_OR_DS_TO_DATE(x))", + }, + ) + self.validate_all( + "UNIX_TIMESTAMP(x)", + write={ + "duckdb": "EPOCH(STRPTIME(x, '%Y-%m-%d %H:%M:%S'))", + "presto": "TO_UNIXTIME(DATE_PARSE(x, '%Y-%m-%d %H:%i:%S'))", + "hive": "UNIX_TIMESTAMP(x)", + "spark": "UNIX_TIMESTAMP(x)", + "": "STR_TO_UNIX(x, '%Y-%m-%d %H:%M:%S')", + }, + ) + + for unit in ("DAY", "MONTH", "YEAR"): + self.validate_all( + f"{unit}(x)", + write={ + "duckdb": f"{unit}(CAST(x AS DATE))", + "presto": f"{unit}(CAST(SUBSTR(CAST(x AS VARCHAR), 1, 10) AS DATE))", + "hive": f"{unit}(TO_DATE(x))", + "spark": f"{unit}(TO_DATE(x))", + }, + ) + + def test_order_by(self): + self.validate_all( + "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", + write={ + "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", + "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname NULLS FIRST", + "hive": "SELECT fname, lname, age FROM person ORDER BY 
age DESC NULLS FIRST, fname NULLS LAST, lname", + "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", + }, + ) + + def test_hive(self): + self.validate_all( + "PERCENTILE(x, 0.5)", + write={ + "duckdb": "QUANTILE(x, 0.5)", + "presto": "APPROX_PERCENTILE(x, 0.5)", + "hive": "PERCENTILE(x, 0.5)", + "spark": "PERCENTILE(x, 0.5)", + }, + ) + self.validate_all( + "APPROX_COUNT_DISTINCT(a)", + write={ + "duckdb": "APPROX_COUNT_DISTINCT(a)", + "presto": "APPROX_DISTINCT(a)", + "hive": "APPROX_COUNT_DISTINCT(a)", + "spark": "APPROX_COUNT_DISTINCT(a)", + }, + ) + self.validate_all( + "ARRAY_CONTAINS(x, 1)", + write={ + "duckdb": "ARRAY_CONTAINS(x, 1)", + "presto": "CONTAINS(x, 1)", + "hive": "ARRAY_CONTAINS(x, 1)", + "spark": "ARRAY_CONTAINS(x, 1)", + }, + ) + self.validate_all( + "SIZE(x)", + write={ + "duckdb": "ARRAY_LENGTH(x)", + "presto": "CARDINALITY(x)", + "hive": "SIZE(x)", + "spark": "SIZE(x)", + }, + ) + self.validate_all( + "LOCATE('a', x)", + write={ + "duckdb": "STRPOS(x, 'a')", + "presto": "STRPOS(x, 'a')", + "hive": "LOCATE('a', x)", + "spark": "LOCATE('a', x)", + }, + ) + self.validate_all( + "LOCATE('a', x, 3)", + write={ + "duckdb": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1", + "presto": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1", + "hive": "LOCATE('a', x, 3)", + "spark": "LOCATE('a', x, 3)", + }, + ) + self.validate_all( + "INITCAP('new york')", + write={ + "duckdb": "INITCAP('new york')", + "presto": "REGEXP_REPLACE('new york', '(\w)(\w*)', x -> UPPER(x[1]) || LOWER(x[2]))", + "hive": "INITCAP('new york')", + "spark": "INITCAP('new york')", + }, + ) + self.validate_all( + "SELECT * FROM x TABLESAMPLE(10) y", + write={ + "presto": "SELECT * FROM x AS y TABLESAMPLE(10)", + "hive": "SELECT * FROM x TABLESAMPLE(10) AS y", + "spark": "SELECT * FROM x TABLESAMPLE(10) AS y", + }, + ) + self.validate_all( + "SELECT SORT_ARRAY(x)", + write={ + "duckdb": "SELECT ARRAY_SORT(x)", + "presto": "SELECT ARRAY_SORT(x)", + "hive": "SELECT SORT_ARRAY(x)", + "spark": "SELECT SORT_ARRAY(x)", + }, + ) + self.validate_all( + "SELECT SORT_ARRAY(x, FALSE)", + read={ + "duckdb": "SELECT ARRAY_REVERSE_SORT(x)", + "spark": "SELECT SORT_ARRAY(x, FALSE)", + }, + write={ + "duckdb": "SELECT ARRAY_REVERSE_SORT(x)", + "presto": "SELECT ARRAY_SORT(x, (a, b) -> CASE WHEN a < b THEN 1 WHEN a > b THEN -1 ELSE 0 END)", + "hive": "SELECT SORT_ARRAY(x, FALSE)", + "spark": "SELECT SORT_ARRAY(x, FALSE)", + }, + ) + self.validate_all( + "GET_JSON_OBJECT(x, '$.name')", + write={ + "presto": "JSON_EXTRACT_SCALAR(x, '$.name')", + "hive": "GET_JSON_OBJECT(x, '$.name')", + "spark": "GET_JSON_OBJECT(x, '$.name')", + }, + ) + self.validate_all( + "MAP(a, b, c, d)", + write={ + "duckdb": "MAP(LIST_VALUE(a, c), LIST_VALUE(b, d))", + "presto": "MAP(ARRAY[a, c], ARRAY[b, d])", + "hive": "MAP(a, b, c, d)", + "spark": "MAP_FROM_ARRAYS(ARRAY(a, c), ARRAY(b, d))", + }, + ) + self.validate_all( + "MAP(a, b)", + write={ + "duckdb": "MAP(LIST_VALUE(a), LIST_VALUE(b))", + "presto": "MAP(ARRAY[a], ARRAY[b])", + "hive": "MAP(a, b)", + "spark": "MAP_FROM_ARRAYS(ARRAY(a), ARRAY(b))", + }, + ) + self.validate_all( + "LOG(10)", + write={ + "duckdb": "LN(10)", + "presto": "LN(10)", + "hive": "LN(10)", + "spark": "LN(10)", + }, + ) + self.validate_all( + "LOG(10, 2)", + write={ + "duckdb": "LOG(10, 2)", + "presto": "LOG(10, 2)", + "hive": "LOG(10, 2)", + "spark": "LOG(10, 2)", + }, + ) + self.validate_all( + 'ds = "2020-01-01"', + write={ + "duckdb": "ds = '2020-01-01'", + "presto": "ds = '2020-01-01'", + 
"hive": "ds = '2020-01-01'", + "spark": "ds = '2020-01-01'", + }, + ) + self.validate_all( + "ds = \"1''2\"", + write={ + "duckdb": "ds = '1''''2'", + "presto": "ds = '1''''2'", + "hive": "ds = '1\\'\\'2'", + "spark": "ds = '1\\'\\'2'", + }, + ) + self.validate_all( + "x == 1", + write={ + "duckdb": "x = 1", + "presto": "x = 1", + "hive": "x = 1", + "spark": "x = 1", + }, + ) + self.validate_all( + "x div y", + write={ + "duckdb": "CAST(x / y AS INT)", + "presto": "CAST(x / y AS INTEGER)", + "hive": "CAST(x / y AS INT)", + "spark": "CAST(x / y AS INT)", + }, + ) + self.validate_all( + "COLLECT_LIST(x)", + read={ + "presto": "ARRAY_AGG(x)", + }, + write={ + "duckdb": "ARRAY_AGG(x)", + "presto": "ARRAY_AGG(x)", + "hive": "COLLECT_LIST(x)", + "spark": "COLLECT_LIST(x)", + }, + ) + self.validate_all( + "COLLECT_SET(x)", + read={ + "presto": "SET_AGG(x)", + }, + write={ + "presto": "SET_AGG(x)", + "hive": "COLLECT_SET(x)", + "spark": "COLLECT_SET(x)", + }, + ) + self.validate_all( + "SELECT * FROM x TABLESAMPLE(1) AS foo", + read={ + "presto": "SELECT * FROM x AS foo TABLESAMPLE(1)", + }, + write={ + "presto": "SELECT * FROM x AS foo TABLESAMPLE(1)", + "hive": "SELECT * FROM x TABLESAMPLE(1) AS foo", + "spark": "SELECT * FROM x TABLESAMPLE(1) AS foo", + }, + ) diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py new file mode 100644 index 0000000..ee0c5f5 --- /dev/null +++ b/tests/dialects/test_mysql.py @@ -0,0 +1,79 @@ +from tests.dialects.test_dialect import Validator + + +class TestMySQL(Validator): + dialect = "mysql" + + def test_ddl(self): + self.validate_all( + "CREATE TABLE z (a INT) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x'", + write={ + "mysql": "CREATE TABLE z (a INT) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x'", + "spark": "CREATE TABLE z (a INT) COMMENT 'x'", + }, + ) + + def test_identity(self): + self.validate_identity("SELECT CAST(`a`.`b` AS INT) FROM foo") + + def test_introducers(self): + self.validate_all( + "_utf8mb4 'hola'", + read={ + "mysql": "_utf8mb4'hola'", + }, + write={ + "mysql": "_utf8mb4 'hola'", + }, + ) + + def test_binary_literal(self): + self.validate_all( + "SELECT 0xCC", + write={ + "mysql": "SELECT b'11001100'", + "spark": "SELECT X'11001100'", + }, + ) + self.validate_all( + "SELECT 0xz", + write={ + "mysql": "SELECT `0xz`", + }, + ) + self.validate_all( + "SELECT 0XCC", + write={ + "mysql": "SELECT 0 AS XCC", + }, + ) + + def test_string_literals(self): + self.validate_all( + 'SELECT "2021-01-01" + INTERVAL 1 MONTH', + write={ + "mysql": "SELECT '2021-01-01' + INTERVAL 1 MONTH", + }, + ) + + def test_convert(self): + self.validate_all( + "CONVERT(x USING latin1)", + write={ + "mysql": "CAST(x AS CHAR CHARACTER SET latin1)", + }, + ) + self.validate_all( + "CAST(x AS CHAR CHARACTER SET latin1)", + write={ + "mysql": "CAST(x AS CHAR CHARACTER SET latin1)", + }, + ) + + def test_hash_comments(self): + self.validate_all( + "SELECT 1 # arbitrary content,,, until end-of-line", + write={ + "mysql": "SELECT 1", + }, + ) diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py new file mode 100644 index 0000000..15dbfd0 --- /dev/null +++ b/tests/dialects/test_postgres.py @@ -0,0 +1,93 @@ +from sqlglot import ParseError, transpile +from tests.dialects.test_dialect import Validator + + +class TestPostgres(Validator): + dialect = "postgres" + + def test_ddl(self): + self.validate_all( + "CREATE TABLE products (product_no INT UNIQUE, name 
TEXT, price DECIMAL)", + write={ + "postgres": "CREATE TABLE products (product_no INT UNIQUE, name TEXT, price DECIMAL)" + }, + ) + self.validate_all( + "CREATE TABLE products (product_no INT CONSTRAINT must_be_different UNIQUE, name TEXT CONSTRAINT present NOT NULL, price DECIMAL)", + write={ + "postgres": "CREATE TABLE products (product_no INT CONSTRAINT must_be_different UNIQUE, name TEXT CONSTRAINT present NOT NULL, price DECIMAL)" + }, + ) + self.validate_all( + "CREATE TABLE products (product_no INT, name TEXT, price DECIMAL, UNIQUE (product_no, name))", + write={ + "postgres": "CREATE TABLE products (product_no INT, name TEXT, price DECIMAL, UNIQUE (product_no, name))" + }, + ) + self.validate_all( + "CREATE TABLE products (" + "product_no INT UNIQUE," + " name TEXT," + " price DECIMAL CHECK (price > 0)," + " discounted_price DECIMAL CONSTRAINT positive_discount CHECK (discounted_price > 0)," + " CHECK (product_no > 1)," + " CONSTRAINT valid_discount CHECK (price > discounted_price))", + write={ + "postgres": "CREATE TABLE products (" + "product_no INT UNIQUE," + " name TEXT," + " price DECIMAL CHECK (price > 0)," + " discounted_price DECIMAL CONSTRAINT positive_discount CHECK (discounted_price > 0)," + " CHECK (product_no > 1)," + " CONSTRAINT valid_discount CHECK (price > discounted_price))" + }, + ) + + with self.assertRaises(ParseError): + transpile( + "CREATE TABLE products (price DECIMAL CHECK price > 0)", read="postgres" + ) + with self.assertRaises(ParseError): + transpile( + "CREATE TABLE products (price DECIMAL, CHECK price > 1)", + read="postgres", + ) + + def test_postgres(self): + self.validate_all( + "CREATE TABLE x (a INT SERIAL)", + read={"sqlite": "CREATE TABLE x (a INTEGER AUTOINCREMENT)"}, + write={"sqlite": "CREATE TABLE x (a INTEGER AUTOINCREMENT)"}, + ) + self.validate_all( + "CREATE TABLE x (a UUID, b BYTEA)", + write={ + "presto": "CREATE TABLE x (a UUID, b VARBINARY)", + "hive": "CREATE TABLE x (a UUID, b BINARY)", + "spark": "CREATE TABLE x (a UUID, b BINARY)", + }, + ) + self.validate_all( + "SELECT SUM(x) OVER (PARTITION BY a ORDER BY d ROWS 1 PRECEDING)", + write={ + "postgres": "SELECT SUM(x) OVER (PARTITION BY a ORDER BY d ROWS BETWEEN 1 PRECEDING AND CURRENT ROW)", + }, + ) + self.validate_all( + "SELECT * FROM x FETCH 1 ROW", + write={ + "postgres": "SELECT * FROM x FETCH FIRST 1 ROWS ONLY", + "presto": "SELECT * FROM x FETCH FIRST 1 ROWS ONLY", + "hive": "SELECT * FROM x FETCH FIRST 1 ROWS ONLY", + "spark": "SELECT * FROM x FETCH FIRST 1 ROWS ONLY", + }, + ) + self.validate_all( + "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", + write={ + "postgres": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname, lname", + "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname", + "hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST", + "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST", + }, + ) diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py new file mode 100644 index 0000000..eb9aa5c --- /dev/null +++ b/tests/dialects/test_presto.py @@ -0,0 +1,422 @@ +from sqlglot import UnsupportedError +from tests.dialects.test_dialect import Validator + + +class TestPresto(Validator): + dialect = "presto" + + def test_cast(self): + self.validate_all( + "CAST(a AS ARRAY(INT))", + write={ + "bigquery": "CAST(a AS ARRAY<INT64>)", 
+ "duckdb": "CAST(a AS ARRAY<INT>)", + "presto": "CAST(a AS ARRAY(INTEGER))", + "spark": "CAST(a AS ARRAY<INT>)", + }, + ) + self.validate_all( + "CAST(a AS VARCHAR)", + write={ + "bigquery": "CAST(a AS STRING)", + "duckdb": "CAST(a AS TEXT)", + "presto": "CAST(a AS VARCHAR)", + "spark": "CAST(a AS STRING)", + }, + ) + self.validate_all( + "CAST(ARRAY[1, 2] AS ARRAY(BIGINT))", + write={ + "bigquery": "CAST([1, 2] AS ARRAY<INT64>)", + "duckdb": "CAST(LIST_VALUE(1, 2) AS ARRAY<BIGINT>)", + "presto": "CAST(ARRAY[1, 2] AS ARRAY(BIGINT))", + "spark": "CAST(ARRAY(1, 2) AS ARRAY<LONG>)", + }, + ) + self.validate_all( + "CAST(MAP(ARRAY[1], ARRAY[1]) AS MAP(INT,INT))", + write={ + "bigquery": "CAST(MAP([1], [1]) AS MAP<INT64, INT64>)", + "duckdb": "CAST(MAP(LIST_VALUE(1), LIST_VALUE(1)) AS MAP<INT, INT>)", + "presto": "CAST(MAP(ARRAY[1], ARRAY[1]) AS MAP(INTEGER, INTEGER))", + "hive": "CAST(MAP(1, 1) AS MAP<INT, INT>)", + "spark": "CAST(MAP_FROM_ARRAYS(ARRAY(1), ARRAY(1)) AS MAP<INT, INT>)", + }, + ) + self.validate_all( + "CAST(MAP(ARRAY['a','b','c'], ARRAY[ARRAY[1], ARRAY[2], ARRAY[3]]) AS MAP(VARCHAR, ARRAY(INT)))", + write={ + "bigquery": "CAST(MAP(['a', 'b', 'c'], [[1], [2], [3]]) AS MAP<STRING, ARRAY<INT64>>)", + "duckdb": "CAST(MAP(LIST_VALUE('a', 'b', 'c'), LIST_VALUE(LIST_VALUE(1), LIST_VALUE(2), LIST_VALUE(3))) AS MAP<TEXT, ARRAY<INT>>)", + "presto": "CAST(MAP(ARRAY['a', 'b', 'c'], ARRAY[ARRAY[1], ARRAY[2], ARRAY[3]]) AS MAP(VARCHAR, ARRAY(INTEGER)))", + "hive": "CAST(MAP('a', ARRAY(1), 'b', ARRAY(2), 'c', ARRAY(3)) AS MAP<STRING, ARRAY<INT>>)", + "spark": "CAST(MAP_FROM_ARRAYS(ARRAY('a', 'b', 'c'), ARRAY(ARRAY(1), ARRAY(2), ARRAY(3))) AS MAP<STRING, ARRAY<INT>>)", + }, + ) + self.validate_all( + "CAST(x AS TIMESTAMP(9) WITH TIME ZONE)", + write={ + "bigquery": "CAST(x AS TIMESTAMPTZ(9))", + "duckdb": "CAST(x AS TIMESTAMPTZ(9))", + "presto": "CAST(x AS TIMESTAMP(9) WITH TIME ZONE)", + "hive": "CAST(x AS TIMESTAMPTZ(9))", + "spark": "CAST(x AS TIMESTAMPTZ(9))", + }, + ) + + def test_regex(self): + self.validate_all( + "REGEXP_LIKE(a, 'x')", + write={ + "duckdb": "REGEXP_MATCHES(a, 'x')", + "presto": "REGEXP_LIKE(a, 'x')", + "hive": "a RLIKE 'x'", + "spark": "a RLIKE 'x'", + }, + ) + self.validate_all( + "SPLIT(x, 'a.')", + write={ + "duckdb": "STR_SPLIT(x, 'a.')", + "presto": "SPLIT(x, 'a.')", + "hive": "SPLIT(x, CONCAT('\\\\Q', 'a.'))", + "spark": "SPLIT(x, CONCAT('\\\\Q', 'a.'))", + }, + ) + self.validate_all( + "REGEXP_SPLIT(x, 'a.')", + write={ + "duckdb": "STR_SPLIT_REGEX(x, 'a.')", + "presto": "REGEXP_SPLIT(x, 'a.')", + "hive": "SPLIT(x, 'a.')", + "spark": "SPLIT(x, 'a.')", + }, + ) + self.validate_all( + "CARDINALITY(x)", + write={ + "duckdb": "ARRAY_LENGTH(x)", + "presto": "CARDINALITY(x)", + "hive": "SIZE(x)", + "spark": "SIZE(x)", + }, + ) + + def test_time(self): + self.validate_all( + "DATE_FORMAT(x, '%Y-%m-%d %H:%i:%S')", + write={ + "duckdb": "STRFTIME(x, '%Y-%m-%d %H:%M:%S')", + "presto": "DATE_FORMAT(x, '%Y-%m-%d %H:%i:%S')", + "hive": "DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss')", + "spark": "DATE_FORMAT(x, 'yyyy-MM-dd HH:mm:ss')", + }, + ) + self.validate_all( + "DATE_PARSE(x, '%Y-%m-%d %H:%i:%S')", + write={ + "duckdb": "STRPTIME(x, '%Y-%m-%d %H:%M:%S')", + "presto": "DATE_PARSE(x, '%Y-%m-%d %H:%i:%S')", + "hive": "CAST(x AS TIMESTAMP)", + "spark": "TO_TIMESTAMP(x, 'yyyy-MM-dd HH:mm:ss')", + }, + ) + self.validate_all( + "DATE_PARSE(x, '%Y-%m-%d')", + write={ + "duckdb": "STRPTIME(x, '%Y-%m-%d')", + "presto": "DATE_PARSE(x, '%Y-%m-%d')", + "hive": "CAST(x AS TIMESTAMP)", + 
"spark": "TO_TIMESTAMP(x, 'yyyy-MM-dd')", + }, + ) + self.validate_all( + "DATE_PARSE(SUBSTR(x, 1, 10), '%Y-%m-%d')", + write={ + "duckdb": "STRPTIME(SUBSTR(x, 1, 10), '%Y-%m-%d')", + "presto": "DATE_PARSE(SUBSTR(x, 1, 10), '%Y-%m-%d')", + "hive": "CAST(SUBSTR(x, 1, 10) AS TIMESTAMP)", + "spark": "TO_TIMESTAMP(SUBSTR(x, 1, 10), 'yyyy-MM-dd')", + }, + ) + self.validate_all( + "FROM_UNIXTIME(x)", + write={ + "duckdb": "TO_TIMESTAMP(CAST(x AS BIGINT))", + "presto": "FROM_UNIXTIME(x)", + "hive": "FROM_UNIXTIME(x)", + "spark": "FROM_UNIXTIME(x)", + }, + ) + self.validate_all( + "TO_UNIXTIME(x)", + write={ + "duckdb": "EPOCH(x)", + "presto": "TO_UNIXTIME(x)", + "hive": "UNIX_TIMESTAMP(x)", + "spark": "UNIX_TIMESTAMP(x)", + }, + ) + self.validate_all( + "DATE_ADD('day', 1, x)", + write={ + "duckdb": "x + INTERVAL 1 day", + "presto": "DATE_ADD('day', 1, x)", + "hive": "DATE_ADD(x, 1)", + "spark": "DATE_ADD(x, 1)", + }, + ) + + def test_ddl(self): + self.validate_all( + "CREATE TABLE test WITH (FORMAT = 'PARQUET') AS SELECT 1", + write={ + "presto": "CREATE TABLE test WITH (FORMAT = 'PARQUET') AS SELECT 1", + "hive": "CREATE TABLE test STORED AS PARQUET AS SELECT 1", + "spark": "CREATE TABLE test STORED AS PARQUET AS SELECT 1", + }, + ) + self.validate_all( + "CREATE TABLE test WITH (FORMAT = 'PARQUET', X = '1', Z = '2') AS SELECT 1", + write={ + "presto": "CREATE TABLE test WITH (FORMAT = 'PARQUET', X = '1', Z = '2') AS SELECT 1", + "hive": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('X' = '1', 'Z' = '2') AS SELECT 1", + "spark": "CREATE TABLE test STORED AS PARQUET TBLPROPERTIES ('X' = '1', 'Z' = '2') AS SELECT 1", + }, + ) + self.validate_all( + "CREATE TABLE x (w VARCHAR, y INTEGER, z INTEGER) WITH (PARTITIONED_BY = ARRAY['y', 'z'])", + write={ + "presto": "CREATE TABLE x (w VARCHAR, y INTEGER, z INTEGER) WITH (PARTITIONED_BY = ARRAY['y', 'z'])", + "hive": "CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)", + "spark": "CREATE TABLE x (w STRING) PARTITIONED BY (y INT, z INT)", + }, + ) + self.validate_all( + "CREATE TABLE x WITH (bucket_by = ARRAY['y'], bucket_count = 64) AS SELECT 1 AS y", + write={ + "presto": "CREATE TABLE x WITH (bucket_by = ARRAY['y'], bucket_count = 64) AS SELECT 1 AS y", + "hive": "CREATE TABLE x TBLPROPERTIES ('bucket_by' = ARRAY('y'), 'bucket_count' = 64) AS SELECT 1 AS y", + "spark": "CREATE TABLE x TBLPROPERTIES ('bucket_by' = ARRAY('y'), 'bucket_count' = 64) AS SELECT 1 AS y", + }, + ) + self.validate_all( + "CREATE TABLE db.example_table (col_a ROW(struct_col_a INTEGER, struct_col_b VARCHAR))", + write={ + "presto": "CREATE TABLE db.example_table (col_a ROW(struct_col_a INTEGER, struct_col_b VARCHAR))", + "hive": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a INT, struct_col_b STRING>)", + "spark": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a: INT, struct_col_b: STRING>)", + }, + ) + self.validate_all( + "CREATE TABLE db.example_table (col_a ROW(struct_col_a INTEGER, struct_col_b ROW(nested_col_a VARCHAR, nested_col_b VARCHAR)))", + write={ + "presto": "CREATE TABLE db.example_table (col_a ROW(struct_col_a INTEGER, struct_col_b ROW(nested_col_a VARCHAR, nested_col_b VARCHAR)))", + "hive": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a INT, struct_col_b STRUCT<nested_col_a STRING, nested_col_b STRING>>)", + "spark": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a: INT, struct_col_b: STRUCT<nested_col_a: STRING, nested_col_b: STRING>>)", + }, + ) + + self.validate( + "SELECT fname, lname, age FROM person 
ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", + "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname", + read="presto", + write="presto", + ) + + def test_quotes(self): + self.validate_all( + "''''", + write={ + "duckdb": "''''", + "presto": "''''", + "hive": "'\\''", + "spark": "'\\''", + }, + ) + self.validate_all( + "'x'", + write={ + "duckdb": "'x'", + "presto": "'x'", + "hive": "'x'", + "spark": "'x'", + }, + ) + self.validate_all( + "'''x'''", + write={ + "duckdb": "'''x'''", + "presto": "'''x'''", + "hive": "'\\'x\\''", + "spark": "'\\'x\\''", + }, + ) + self.validate_all( + "'''x'", + write={ + "duckdb": "'''x'", + "presto": "'''x'", + "hive": "'\\'x'", + "spark": "'\\'x'", + }, + ) + self.validate_all( + "x IN ('a', 'a''b')", + write={ + "duckdb": "x IN ('a', 'a''b')", + "presto": "x IN ('a', 'a''b')", + "hive": "x IN ('a', 'a\\'b')", + "spark": "x IN ('a', 'a\\'b')", + }, + ) + + def test_unnest(self): + self.validate_all( + "SELECT a FROM x CROSS JOIN UNNEST(ARRAY(y)) AS t (a)", + write={ + "presto": "SELECT a FROM x CROSS JOIN UNNEST(ARRAY[y]) AS t(a)", + "hive": "SELECT a FROM x LATERAL VIEW EXPLODE(ARRAY(y)) t AS a", + "spark": "SELECT a FROM x LATERAL VIEW EXPLODE(ARRAY(y)) t AS a", + }, + ) + + def test_presto(self): + self.validate_all( + 'SELECT a."b" FROM "foo"', + write={ + "duckdb": 'SELECT a."b" FROM "foo"', + "presto": 'SELECT a."b" FROM "foo"', + "spark": "SELECT a.`b` FROM `foo`", + }, + ) + self.validate_all( + "SELECT ARRAY[1, 2]", + write={ + "bigquery": "SELECT [1, 2]", + "duckdb": "SELECT LIST_VALUE(1, 2)", + "presto": "SELECT ARRAY[1, 2]", + "spark": "SELECT ARRAY(1, 2)", + }, + ) + self.validate_all( + "SELECT APPROX_DISTINCT(a) FROM foo", + write={ + "duckdb": "SELECT APPROX_COUNT_DISTINCT(a) FROM foo", + "presto": "SELECT APPROX_DISTINCT(a) FROM foo", + "hive": "SELECT APPROX_COUNT_DISTINCT(a) FROM foo", + "spark": "SELECT APPROX_COUNT_DISTINCT(a) FROM foo", + }, + ) + self.validate_all( + "SELECT APPROX_DISTINCT(a, 0.1) FROM foo", + write={ + "duckdb": "SELECT APPROX_COUNT_DISTINCT(a) FROM foo", + "presto": "SELECT APPROX_DISTINCT(a, 0.1) FROM foo", + "hive": "SELECT APPROX_COUNT_DISTINCT(a) FROM foo", + "spark": "SELECT APPROX_COUNT_DISTINCT(a) FROM foo", + }, + ) + self.validate_all( + "SELECT APPROX_DISTINCT(a, 0.1) FROM foo", + write={ + "presto": "SELECT APPROX_DISTINCT(a, 0.1) FROM foo", + "hive": UnsupportedError, + "spark": UnsupportedError, + }, + ) + self.validate_all( + "SELECT JSON_EXTRACT(x, '$.name')", + write={ + "presto": "SELECT JSON_EXTRACT(x, '$.name')", + "hive": "SELECT GET_JSON_OBJECT(x, '$.name')", + "spark": "SELECT GET_JSON_OBJECT(x, '$.name')", + }, + ) + self.validate_all( + "SELECT JSON_EXTRACT_SCALAR(x, '$.name')", + write={ + "presto": "SELECT JSON_EXTRACT_SCALAR(x, '$.name')", + "hive": "SELECT GET_JSON_OBJECT(x, '$.name')", + "spark": "SELECT GET_JSON_OBJECT(x, '$.name')", + }, + ) + self.validate_all( + "'\u6bdb'", + write={ + "presto": "'\u6bdb'", + "hive": "'\u6bdb'", + "spark": "'\u6bdb'", + }, + ) + self.validate_all( + "SELECT ARRAY_SORT(x, (left, right) -> -1)", + write={ + "duckdb": "SELECT ARRAY_SORT(x)", + "presto": "SELECT ARRAY_SORT(x, (left, right) -> -1)", + "hive": "SELECT SORT_ARRAY(x)", + "spark": "SELECT ARRAY_SORT(x, (left, right) -> -1)", + }, + ) + self.validate_all( + "SELECT ARRAY_SORT(x)", + write={ + "presto": "SELECT ARRAY_SORT(x)", + "hive": "SELECT SORT_ARRAY(x)", + "spark": "SELECT ARRAY_SORT(x)", + }, + ) + self.validate_all( + "SELECT ARRAY_SORT(x, 
(left, right) -> -1)", + write={ + "hive": UnsupportedError, + }, + ) + self.validate_all( + "MAP(a, b)", + write={ + "hive": UnsupportedError, + "spark": "MAP_FROM_ARRAYS(a, b)", + }, + ) + self.validate_all( + "MAP(ARRAY(a, b), ARRAY(c, d))", + write={ + "hive": "MAP(a, c, b, d)", + "presto": "MAP(ARRAY[a, b], ARRAY[c, d])", + "spark": "MAP_FROM_ARRAYS(ARRAY(a, b), ARRAY(c, d))", + }, + ) + self.validate_all( + "MAP(ARRAY('a'), ARRAY('b'))", + write={ + "hive": "MAP('a', 'b')", + "presto": "MAP(ARRAY['a'], ARRAY['b'])", + "spark": "MAP_FROM_ARRAYS(ARRAY('a'), ARRAY('b'))", + }, + ) + self.validate_all( + "SELECT * FROM UNNEST(ARRAY['7', '14']) AS x", + write={ + "bigquery": "SELECT * FROM UNNEST(['7', '14'])", + "presto": "SELECT * FROM UNNEST(ARRAY['7', '14']) AS x", + "hive": "SELECT * FROM UNNEST(ARRAY('7', '14')) AS x", + "spark": "SELECT * FROM UNNEST(ARRAY('7', '14')) AS x", + }, + ) + self.validate_all( + "SELECT * FROM UNNEST(ARRAY['7', '14']) AS x(y)", + write={ + "bigquery": "SELECT * FROM UNNEST(['7', '14']) AS y", + "presto": "SELECT * FROM UNNEST(ARRAY['7', '14']) AS x(y)", + "hive": "SELECT * FROM UNNEST(ARRAY('7', '14')) AS x(y)", + "spark": "SELECT * FROM UNNEST(ARRAY('7', '14')) AS x(y)", + }, + ) + self.validate_all( + "WITH RECURSIVE t(n) AS (VALUES (1) UNION ALL SELECT n+1 FROM t WHERE n < 100 ) SELECT sum(n) FROM t", + write={ + "presto": "WITH RECURSIVE t(n) AS (VALUES (1) UNION ALL SELECT n + 1 FROM t WHERE n < 100) SELECT SUM(n) FROM t", + "spark": UnsupportedError, + }, + ) diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py new file mode 100644 index 0000000..62f78e1 --- /dev/null +++ b/tests/dialects/test_snowflake.py @@ -0,0 +1,145 @@ +from sqlglot import UnsupportedError +from tests.dialects.test_dialect import Validator + + +class TestSnowflake(Validator): + dialect = "snowflake" + + def test_snowflake(self): + self.validate_all( + 'x:a:"b c"', + write={ + "duckdb": "x['a']['b c']", + "hive": "x['a']['b c']", + "presto": "x['a']['b c']", + "snowflake": "x['a']['b c']", + "spark": "x['a']['b c']", + }, + ) + self.validate_all( + "SELECT a FROM test WHERE a = 1 GROUP BY a HAVING a = 2 QUALIFY z ORDER BY a LIMIT 10", + write={ + "bigquery": "SELECT a FROM test WHERE a = 1 GROUP BY a HAVING a = 2 QUALIFY z ORDER BY a NULLS LAST LIMIT 10", + "snowflake": "SELECT a FROM test WHERE a = 1 GROUP BY a HAVING a = 2 QUALIFY z ORDER BY a LIMIT 10", + }, + ) + self.validate_all( + "SELECT a FROM test AS t QUALIFY ROW_NUMBER() OVER (PARTITION BY a ORDER BY Z) = 1", + write={ + "bigquery": "SELECT a FROM test AS t QUALIFY ROW_NUMBER() OVER (PARTITION BY a ORDER BY Z NULLS LAST) = 1", + "snowflake": "SELECT a FROM test AS t QUALIFY ROW_NUMBER() OVER (PARTITION BY a ORDER BY Z) = 1", + }, + ) + self.validate_all( + "SELECT TO_TIMESTAMP(1659981729)", + write={ + "bigquery": "SELECT UNIX_TO_TIME(1659981729)", + "snowflake": "SELECT TO_TIMESTAMP(1659981729)", + "spark": "SELECT FROM_UNIXTIME(1659981729)", + }, + ) + self.validate_all( + "SELECT TO_TIMESTAMP(1659981729000, 3)", + write={ + "bigquery": "SELECT UNIX_TO_TIME(1659981729000, 'millis')", + "snowflake": "SELECT TO_TIMESTAMP(1659981729000, 3)", + "spark": "SELECT TIMESTAMP_MILLIS(1659981729000)", + }, + ) + self.validate_all( + "SELECT TO_TIMESTAMP('1659981729')", + write={ + "bigquery": "SELECT UNIX_TO_TIME('1659981729')", + "snowflake": "SELECT TO_TIMESTAMP('1659981729')", + "spark": "SELECT FROM_UNIXTIME('1659981729')", + }, + ) + self.validate_all( + "SELECT 
TO_TIMESTAMP(1659981729000000000, 9)", + write={ + "bigquery": "SELECT UNIX_TO_TIME(1659981729000000000, 'micros')", + "snowflake": "SELECT TO_TIMESTAMP(1659981729000000000, 9)", + "spark": "SELECT TIMESTAMP_MICROS(1659981729000000000)", + }, + ) + self.validate_all( + "SELECT TO_TIMESTAMP('2013-04-05 01:02:03')", + write={ + "bigquery": "SELECT STR_TO_TIME('2013-04-05 01:02:03', '%Y-%m-%d %H:%M:%S')", + "snowflake": "SELECT TO_TIMESTAMP('2013-04-05 01:02:03', 'yyyy-mm-dd hh24:mi:ss')", + "spark": "SELECT TO_TIMESTAMP('2013-04-05 01:02:03', 'yyyy-MM-dd HH:mm:ss')", + }, + ) + self.validate_all( + "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'mm/dd/yyyy hh24:mi:ss')", + read={ + "bigquery": "SELECT STR_TO_TIME('04/05/2013 01:02:03', '%m/%d/%Y %H:%M:%S')", + "duckdb": "SELECT STRPTIME('04/05/2013 01:02:03', '%m/%d/%Y %H:%M:%S')", + "snowflake": "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'mm/dd/yyyy hh24:mi:ss')", + }, + write={ + "bigquery": "SELECT STR_TO_TIME('04/05/2013 01:02:03', '%m/%d/%Y %H:%M:%S')", + "snowflake": "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'mm/dd/yyyy hh24:mi:ss')", + "spark": "SELECT TO_TIMESTAMP('04/05/2013 01:02:03', 'MM/dd/yyyy HH:mm:ss')", + }, + ) + self.validate_all( + "SELECT IFF(TRUE, 'true', 'false')", + write={ + "snowflake": "SELECT IFF(TRUE, 'true', 'false')", + }, + ) + self.validate_all( + "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", + write={ + "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST", + "postgres": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname, lname", + "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname", + "hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST", + "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname NULLS LAST", + "snowflake": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname, lname", + }, + ) + self.validate_all( + "SELECT ARRAY_AGG(DISTINCT a)", + write={ + "spark": "SELECT COLLECT_LIST(DISTINCT a)", + "snowflake": "SELECT ARRAY_AGG(DISTINCT a)", + }, + ) + self.validate_all( + "SELECT * FROM a INTERSECT ALL SELECT * FROM b", + write={ + "snowflake": UnsupportedError, + }, + ) + self.validate_all( + "SELECT * FROM a EXCEPT ALL SELECT * FROM b", + write={ + "snowflake": UnsupportedError, + }, + ) + self.validate_all( + "SELECT ARRAY_UNION_AGG(a)", + write={ + "snowflake": "SELECT ARRAY_UNION_AGG(a)", + }, + ) + self.validate_all( + "SELECT NVL2(a, b, c)", + write={ + "snowflake": "SELECT NVL2(a, b, c)", + }, + ) + self.validate_all( + "SELECT $$a$$", + write={ + "snowflake": "SELECT 'a'", + }, + ) + self.validate_all( + r"SELECT $$a ' \ \t \x21 z $ $$", + write={ + "snowflake": r"SELECT 'a \' \\ \\t \\x21 z $ '", + }, + ) diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py new file mode 100644 index 0000000..8794fed --- /dev/null +++ b/tests/dialects/test_spark.py @@ -0,0 +1,226 @@ +from tests.dialects.test_dialect import Validator + + +class TestSpark(Validator): + dialect = "spark" + + def test_ddl(self): + self.validate_all( + "CREATE TABLE db.example_table (col_a struct<struct_col_a:int, struct_col_b:string>)", + write={ + "presto": "CREATE TABLE db.example_table (col_a ROW(struct_col_a INTEGER, struct_col_b VARCHAR))", + "hive": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a INT, struct_col_b 
STRING>)", + "spark": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a: INT, struct_col_b: STRING>)", + }, + ) + self.validate_all( + "CREATE TABLE db.example_table (col_a struct<struct_col_a:int, struct_col_b:struct<nested_col_a:string, nested_col_b:string>>)", + write={ + "bigquery": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a INT64, struct_col_b STRUCT<nested_col_a STRING, nested_col_b STRING>>)", + "presto": "CREATE TABLE db.example_table (col_a ROW(struct_col_a INTEGER, struct_col_b ROW(nested_col_a VARCHAR, nested_col_b VARCHAR)))", + "hive": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a INT, struct_col_b STRUCT<nested_col_a STRING, nested_col_b STRING>>)", + "spark": "CREATE TABLE db.example_table (col_a STRUCT<struct_col_a: INT, struct_col_b: STRUCT<nested_col_a: STRING, nested_col_b: STRING>>)", + }, + ) + self.validate_all( + "CREATE TABLE db.example_table (col_a array<int>, col_b array<array<int>>)", + write={ + "bigquery": "CREATE TABLE db.example_table (col_a ARRAY<INT64>, col_b ARRAY<ARRAY<INT64>>)", + "presto": "CREATE TABLE db.example_table (col_a ARRAY(INTEGER), col_b ARRAY(ARRAY(INTEGER)))", + "hive": "CREATE TABLE db.example_table (col_a ARRAY<INT>, col_b ARRAY<ARRAY<INT>>)", + "spark": "CREATE TABLE db.example_table (col_a ARRAY<INT>, col_b ARRAY<ARRAY<INT>>)", + }, + ) + self.validate_all( + "CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'", + write={ + "presto": "CREATE TABLE x WITH (TABLE_FORMAT = 'ICEBERG', PARTITIONED_BY = ARRAY['MONTHS'])", + "hive": "CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'", + "spark": "CREATE TABLE x USING ICEBERG PARTITIONED BY (MONTHS(y)) LOCATION 's3://z'", + }, + ) + self.validate_all( + "CREATE TABLE test STORED AS PARQUET AS SELECT 1", + write={ + "presto": "CREATE TABLE test WITH (FORMAT = 'PARQUET') AS SELECT 1", + "hive": "CREATE TABLE test STORED AS PARQUET AS SELECT 1", + "spark": "CREATE TABLE test STORED AS PARQUET AS SELECT 1", + }, + ) + self.validate_all( + "CREATE TABLE test USING ICEBERG STORED AS PARQUET AS SELECT 1", + write={ + "presto": "CREATE TABLE test WITH (TABLE_FORMAT = 'ICEBERG', FORMAT = 'PARQUET') AS SELECT 1", + "hive": "CREATE TABLE test USING ICEBERG STORED AS PARQUET AS SELECT 1", + "spark": "CREATE TABLE test USING ICEBERG STORED AS PARQUET AS SELECT 1", + }, + ) + self.validate_all( + """CREATE TABLE blah (col_a INT) COMMENT "Test comment: blah" PARTITIONED BY (date STRING) STORED AS ICEBERG TBLPROPERTIES('x' = '1')""", + write={ + "presto": """CREATE TABLE blah ( + col_a INTEGER, + date VARCHAR +) +COMMENT='Test comment: blah' +WITH ( + PARTITIONED_BY = ARRAY['date'], + FORMAT = 'ICEBERG', + x = '1' +)""", + "hive": """CREATE TABLE blah ( + col_a INT +) +COMMENT 'Test comment: blah' +PARTITIONED BY ( + date STRING +) +STORED AS ICEBERG +TBLPROPERTIES ( + 'x' = '1' +)""", + "spark": """CREATE TABLE blah ( + col_a INT +) +COMMENT 'Test comment: blah' +PARTITIONED BY ( + date STRING +) +STORED AS ICEBERG +TBLPROPERTIES ( + 'x' = '1' +)""", + }, + pretty=True, + ) + + def test_to_date(self): + self.validate_all( + "TO_DATE(x, 'yyyy-MM-dd')", + write={ + "duckdb": "CAST(x AS DATE)", + "hive": "TO_DATE(x)", + "presto": "CAST(SUBSTR(CAST(x AS VARCHAR), 1, 10) AS DATE)", + "spark": "TO_DATE(x)", + }, + ) + self.validate_all( + "TO_DATE(x, 'yyyy')", + write={ + "duckdb": "CAST(STRPTIME(x, '%Y') AS DATE)", + "hive": "TO_DATE(x, 'yyyy')", + "presto": "CAST(DATE_PARSE(x, '%Y') AS DATE)", + "spark": "TO_DATE(x, 'yyyy')", + }, + 
) + + def test_hint(self): + self.validate_all( + "SELECT /*+ COALESCE(3) */ * FROM x", + write={ + "spark": "SELECT /*+ COALESCE(3) */ * FROM x", + }, + ) + self.validate_all( + "SELECT /*+ COALESCE(3), REPARTITION(1) */ * FROM x", + write={ + "spark": "SELECT /*+ COALESCE(3), REPARTITION(1) */ * FROM x", + }, + ) + + def test_spark(self): + self.validate_all( + "ARRAY_SORT(x, (left, right) -> -1)", + write={ + "duckdb": "ARRAY_SORT(x)", + "presto": "ARRAY_SORT(x, (left, right) -> -1)", + "hive": "SORT_ARRAY(x)", + "spark": "ARRAY_SORT(x, (left, right) -> -1)", + }, + ) + self.validate_all( + "ARRAY(0, 1, 2)", + write={ + "bigquery": "[0, 1, 2]", + "duckdb": "LIST_VALUE(0, 1, 2)", + "presto": "ARRAY[0, 1, 2]", + "hive": "ARRAY(0, 1, 2)", + "spark": "ARRAY(0, 1, 2)", + }, + ) + + self.validate_all( + "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", + write={ + "clickhouse": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname NULLS FIRST", + "duckdb": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", + "postgres": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname, lname NULLS FIRST", + "presto": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname, lname NULLS FIRST", + "hive": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", + "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", + "snowflake": "SELECT fname, lname, age FROM person ORDER BY age DESC, fname, lname NULLS FIRST", + }, + ) + self.validate_all( + "SELECT APPROX_COUNT_DISTINCT(a) FROM foo", + write={ + "duckdb": "SELECT APPROX_COUNT_DISTINCT(a) FROM foo", + "presto": "SELECT APPROX_DISTINCT(a) FROM foo", + "hive": "SELECT APPROX_COUNT_DISTINCT(a) FROM foo", + "spark": "SELECT APPROX_COUNT_DISTINCT(a) FROM foo", + }, + ) + self.validate_all( + "MONTH('2021-03-01')", + write={ + "duckdb": "MONTH(CAST('2021-03-01' AS DATE))", + "presto": "MONTH(CAST(SUBSTR(CAST('2021-03-01' AS VARCHAR), 1, 10) AS DATE))", + "hive": "MONTH(TO_DATE('2021-03-01'))", + "spark": "MONTH(TO_DATE('2021-03-01'))", + }, + ) + self.validate_all( + "YEAR('2021-03-01')", + write={ + "duckdb": "YEAR(CAST('2021-03-01' AS DATE))", + "presto": "YEAR(CAST(SUBSTR(CAST('2021-03-01' AS VARCHAR), 1, 10) AS DATE))", + "hive": "YEAR(TO_DATE('2021-03-01'))", + "spark": "YEAR(TO_DATE('2021-03-01'))", + }, + ) + self.validate_all( + "'\u6bdb'", + write={ + "duckdb": "'毛'", + "presto": "'毛'", + "hive": "'毛'", + "spark": "'毛'", + }, + ) + self.validate_all( + "SELECT LEFT(x, 2), RIGHT(x, 2)", + write={ + "duckdb": "SELECT SUBSTRING(x, 1, 2), SUBSTRING(x, LENGTH(x) - 2 + 1, 2)", + "presto": "SELECT SUBSTRING(x, 1, 2), SUBSTRING(x, LENGTH(x) - 2 + 1, 2)", + "hive": "SELECT SUBSTRING(x, 1, 2), SUBSTRING(x, LENGTH(x) - 2 + 1, 2)", + "spark": "SELECT SUBSTRING(x, 1, 2), SUBSTRING(x, LENGTH(x) - 2 + 1, 2)", + }, + ) + self.validate_all( + "MAP_FROM_ARRAYS(ARRAY(1), c)", + write={ + "duckdb": "MAP(LIST_VALUE(1), c)", + "presto": "MAP(ARRAY[1], c)", + "hive": "MAP(ARRAY(1), c)", + "spark": "MAP_FROM_ARRAYS(ARRAY(1), c)", + }, + ) + self.validate_all( + "SELECT ARRAY_SORT(x)", + write={ + "duckdb": "SELECT ARRAY_SORT(x)", + "presto": "SELECT ARRAY_SORT(x)", + "hive": "SELECT SORT_ARRAY(x)", + "spark": "SELECT ARRAY_SORT(x)", + }, + ) diff --git a/tests/dialects/test_sqlite.py b/tests/dialects/test_sqlite.py new file mode 100644 index 
0000000..a0576de --- /dev/null +++ b/tests/dialects/test_sqlite.py @@ -0,0 +1,72 @@ +from tests.dialects.test_dialect import Validator + + +class TestSQLite(Validator): + dialect = "sqlite" + + def test_ddl(self): + self.validate_all( + """ + CREATE TABLE "Track" + ( + CONSTRAINT "PK_Track" FOREIGN KEY ("TrackId"), + FOREIGN KEY ("AlbumId") REFERENCES "Album" ("AlbumId") + ON DELETE NO ACTION ON UPDATE NO ACTION, + FOREIGN KEY ("AlbumId") ON DELETE CASCADE ON UPDATE RESTRICT, + FOREIGN KEY ("AlbumId") ON DELETE SET NULL ON UPDATE SET DEFAULT + ) + """, + write={ + "sqlite": """CREATE TABLE "Track" ( + CONSTRAINT "PK_Track" FOREIGN KEY ("TrackId"), + FOREIGN KEY ("AlbumId") REFERENCES "Album"("AlbumId") ON DELETE NO ACTION ON UPDATE NO ACTION, + FOREIGN KEY ("AlbumId") ON DELETE CASCADE ON UPDATE RESTRICT, + FOREIGN KEY ("AlbumId") ON DELETE SET NULL ON UPDATE SET DEFAULT +)""", + }, + pretty=True, + ) + self.validate_all( + "CREATE TABLE z (a INTEGER UNIQUE PRIMARY KEY AUTOINCREMENT)", + read={ + "mysql": "CREATE TABLE z (a INT UNIQUE PRIMARY KEY AUTO_INCREMENT)", + }, + write={ + "sqlite": "CREATE TABLE z (a INTEGER UNIQUE PRIMARY KEY AUTOINCREMENT)", + "mysql": "CREATE TABLE z (a INT UNIQUE PRIMARY KEY AUTO_INCREMENT)", + }, + ) + self.validate_all( + """CREATE TABLE "x" ("Name" NVARCHAR(200) NOT NULL)""", + write={ + "sqlite": """CREATE TABLE "x" ("Name" TEXT(200) NOT NULL)""", + "mysql": "CREATE TABLE `x` (`Name` VARCHAR(200) NOT NULL)", + }, + ) + + def test_sqlite(self): + self.validate_all( + "SELECT CAST([a].[b] AS SMALLINT) FROM foo", + write={ + "sqlite": 'SELECT CAST("a"."b" AS INTEGER) FROM foo', + "spark": "SELECT CAST(`a`.`b` AS SHORT) FROM foo", + }, + ) + self.validate_all( + "EDITDIST3(col1, col2)", + read={ + "sqlite": "EDITDIST3(col1, col2)", + "spark": "LEVENSHTEIN(col1, col2)", + }, + write={ + "sqlite": "EDITDIST3(col1, col2)", + "spark": "LEVENSHTEIN(col1, col2)", + }, + ) + self.validate_all( + "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname", + write={ + "spark": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", + "sqlite": "SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname NULLS LAST, lname", + }, + ) diff --git a/tests/dialects/test_starrocks.py b/tests/dialects/test_starrocks.py new file mode 100644 index 0000000..1fe1a57 --- /dev/null +++ b/tests/dialects/test_starrocks.py @@ -0,0 +1,8 @@ +from tests.dialects.test_dialect import Validator + + +class TestMySQL(Validator): + dialect = "starrocks" + + def test_identity(self): + self.validate_identity("SELECT CAST(`a`.`b` AS INT) FROM foo") diff --git a/tests/dialects/test_tableau.py b/tests/dialects/test_tableau.py new file mode 100644 index 0000000..0f612dd --- /dev/null +++ b/tests/dialects/test_tableau.py @@ -0,0 +1,62 @@ +from tests.dialects.test_dialect import Validator + + +class TestTableau(Validator): + dialect = "tableau" + + def test_tableau(self): + self.validate_all( + "IF x = 'a' THEN y ELSE NULL END", + read={ + "presto": "IF(x = 'a', y, NULL)", + }, + write={ + "presto": "IF(x = 'a', y, NULL)", + "hive": "IF(x = 'a', y, NULL)", + "tableau": "IF x = 'a' THEN y ELSE NULL END", + }, + ) + self.validate_all( + "IFNULL(a, 0)", + read={ + "presto": "COALESCE(a, 0)", + }, + write={ + "presto": "COALESCE(a, 0)", + "hive": "COALESCE(a, 0)", + "tableau": "IFNULL(a, 0)", + }, + ) + self.validate_all( + "COUNTD(a)", + read={ + "presto": "COUNT(DISTINCT a)", + }, + write={ + "presto": 
"COUNT(DISTINCT a)", + "hive": "COUNT(DISTINCT a)", + "tableau": "COUNTD(a)", + }, + ) + self.validate_all( + "COUNTD((a))", + read={ + "presto": "COUNT(DISTINCT(a))", + }, + write={ + "presto": "COUNT(DISTINCT (a))", + "hive": "COUNT(DISTINCT (a))", + "tableau": "COUNTD((a))", + }, + ) + self.validate_all( + "COUNT(a)", + read={ + "presto": "COUNT(a)", + }, + write={ + "presto": "COUNT(a)", + "hive": "COUNT(a)", + "tableau": "COUNT(a)", + }, + ) diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql new file mode 100644 index 0000000..40f11a2 --- /dev/null +++ b/tests/fixtures/identity.sql @@ -0,0 +1,514 @@ +SUM(1) +SUM(CASE WHEN x > 1 THEN 1 ELSE 0 END) / y +1 +1.0 +1E2 +1E+2 +1E-2 +1.1E10 +1.12e-10 +-11.023E7 * 3 +(1 * 2) / (3 - 5) +((TRUE)) +'' +'''' +'x' +'\x' +"x" +"" +x +x % 1 +x < 1 +x <= 1 +x > 1 +x >= 1 +x <> 1 +x = y OR x > 1 +x & 1 +x | 1 +x ^ 1 +~x +x << 1 +x >> 1 +x >> 1 | 1 & 1 ^ 1 +x || y +1 - -1 +dec.x + y +a.filter +a.b.c +a.b.c.d +a.b.c.d.e +a.b.c.d.e[0] +a.b.c.d.e[0].f +a[0][0].b.c[1].d.e.f[1][1] +a[0].b[1] +a[0].b.c['d'] +a.b.C() +a['x'].b.C() +a.B() +a['x'].C() +int.x +map.x +x IN (-1, 1) +x IN ('a', 'a''a') +x IN ((1)) +x BETWEEN -1 AND 1 +x BETWEEN 'a' || b AND 'c' || d +NOT x IS NULL +x IS TRUE +x IS FALSE +time +zone +ARRAY<TEXT> +CURRENT_DATE +CURRENT_DATE('UTC') +CURRENT_DATE AT TIME ZONE 'UTC' +CURRENT_DATE AT TIME ZONE zone_column +CURRENT_DATE AT TIME ZONE 'UTC' AT TIME ZONE 'Asia/Tokio' +ARRAY() +ARRAY(1, 2) +ARRAY_CONTAINS(x, 1) +EXTRACT(x FROM y) +EXTRACT(DATE FROM y) +CONCAT_WS('-', 'a', 'b') +CONCAT_WS('-', 'a', 'b', 'c') +POSEXPLODE("x") AS ("a", "b") +POSEXPLODE("x") AS ("a", "b", "c") +STR_POSITION(x, 'a') +STR_POSITION(x, 'a', 3) +SPLIT(SPLIT(referrer, 'utm_source=')[OFFSET(1)], "&")[OFFSET(0)] +x[ORDINAL(1)][SAFE_OFFSET(2)] +x LIKE SUBSTR('abc', 1, 1) +x LIKE y +x LIKE a.y +x LIKE '%y%' +x ILIKE '%y%' +x LIKE '%y%' ESCAPE '\' +x ILIKE '%y%' ESCAPE '\' +1 AS escape +INTERVAL '1' day +INTERVAL '1' month +INTERVAL '1 day' +INTERVAL 2 months +INTERVAL 1 + 3 days +TIMESTAMP_DIFF(CURRENT_TIMESTAMP(), 1, DAY) +DATETIME_DIFF(CURRENT_DATE, 1, DAY) +QUANTILE(x, 0.5) +REGEXP_REPLACE('new york', '(\w)(\w*)', x -> UPPER(x[1]) || LOWER(x[2])) +REGEXP_LIKE('new york', '.') +REGEXP_SPLIT('new york', '.') +SPLIT('new york', '.') +X((y AS z)).1 +(x AS y, y AS z) +REPLACE(1) +DATE(x) = DATE(y) +TIMESTAMP(DATE(x)) +TIMESTAMP_TRUNC(COALESCE(time_field, CURRENT_TIMESTAMP()), DAY) +COUNT(DISTINCT CASE WHEN DATE_TRUNC(DATE(time_field), isoweek) = DATE_TRUNC(DATE(time_field2), isoweek) THEN report_id ELSE NULL END) +x[y - 1] +CASE WHEN SUM(x) > 3 THEN 1 END OVER (PARTITION BY x) +SUM(ROW() OVER (PARTITION BY x)) +SUM(ROW() OVER (PARTITION BY x + 1)) +SUM(ROW() OVER (PARTITION BY x AND y)) +(ROW() OVER ()) +CASE WHEN (x > 1) THEN 1 ELSE 0 END +CASE (1) WHEN 1 THEN 1 ELSE 0 END +CASE 1 WHEN 1 THEN 1 ELSE 0 END +x AT TIME ZONE 'UTC' +CAST('2025-11-20 00:00:00+00' AS TIMESTAMP) AT TIME ZONE 'Africa/Cairo' +SET x = 1 +SET -v +ADD JAR s3://bucket +ADD JARS s3://bucket, c +ADD FILE s3://file +ADD FILES s3://file, s3://a +ADD ARCHIVE s3://file +ADD ARCHIVES s3://file, s3://a +BEGIN IMMEDIATE TRANSACTION +COMMIT +USE db +NOT 1 +NOT NOT 1 +SELECT * FROM test +SELECT *, 1 FROM test +SELECT * FROM a.b +SELECT * FROM a.b.c +SELECT * FROM table +SELECT 1 +SELECT 1 FROM test +SELECT * FROM a, b, (SELECT 1) AS c +SELECT a FROM test +SELECT 1 AS filter +SELECT SUM(x) AS filter +SELECT 1 AS range FROM test +SELECT 1 AS count FROM test +SELECT 1 AS comment FROM test +SELECT 1 
AS numeric FROM test +SELECT 1 AS number FROM test +SELECT t.count +SELECT DISTINCT x FROM test +SELECT DISTINCT x, y FROM test +SELECT DISTINCT TIMESTAMP_TRUNC(time_field, MONTH) AS time_value FROM "table" +SELECT DISTINCT ON (x) x, y FROM z +SELECT DISTINCT ON (x, y + 1) * FROM z +SELECT DISTINCT ON (x.y) * FROM z +SELECT top.x +SELECT TIMESTAMP(DATE_TRUNC(DATE(time_field), MONTH)) AS time_value FROM "table" +SELECT GREATEST((3 + 1), LEAST(3, 4)) +SELECT TRANSFORM(a, b -> b) AS x +SELECT AGGREGATE(a, (a, b) -> a + b) AS x +SELECT SUM(DISTINCT x) +SELECT SUM(x IGNORE NULLS) AS x +SELECT ARRAY_AGG(DISTINCT x IGNORE NULLS ORDER BY a, b DESC LIMIT 10) AS x +SELECT ARRAY_AGG(STRUCT(x, x AS y) ORDER BY z DESC) AS x +SELECT LAST_VALUE(x IGNORE NULLS) OVER y AS x +SELECT LAG(x) OVER (ORDER BY y) AS x +SELECT LEAD(a) OVER (ORDER BY b) AS a +SELECT LEAD(a, 1) OVER (PARTITION BY a ORDER BY a) AS x +SELECT LEAD(a, 1, b) OVER (PARTITION BY a ORDER BY a) AS x +SELECT X((a, b) -> a + b, z -> z) AS x +SELECT X(a -> "a" + ("z" - 1)) +SELECT EXISTS(ARRAY(2, 3), x -> x % 2 = 0) +SELECT test.* FROM test +SELECT a AS b FROM test +SELECT "a"."b" FROM "a" +SELECT "a".b FROM a +SELECT a.b FROM "a" +SELECT a.b FROM a +SELECT '"hi' AS x FROM x +SELECT 1 AS "|sum" FROM x +SELECT '\"hi' AS x FROM x +SELECT 1 AS b FROM test +SELECT 1 AS "b" FROM test +SELECT 1 + 1 FROM test +SELECT 1 - 1 FROM test +SELECT 1 * 1 FROM test +SELECT 1 % 1 FROM test +SELECT 1 / 1 FROM test +SELECT 1 < 2 FROM test +SELECT 1 <= 2 FROM test +SELECT 1 > 2 FROM test +SELECT 1 >= 2 FROM test +SELECT 1 <> 2 FROM test +SELECT JSON_EXTRACT(x, '$.name') +SELECT JSON_EXTRACT_SCALAR(x, '$.name') +SELECT x LIKE '%x%' FROM test +SELECT * FROM test LIMIT 100 +SELECT * FROM test LIMIT 100 OFFSET 200 +SELECT * FROM test FETCH FIRST 1 ROWS ONLY +SELECT * FROM test FETCH NEXT 1 ROWS ONLY +SELECT (1 > 2) AS x FROM test +SELECT NOT (1 > 2) FROM test +SELECT 1 + 2 AS x FROM test +SELECT a, b, 1 < 1 FROM test +SELECT a FROM test WHERE NOT FALSE +SELECT a FROM test WHERE a = 1 +SELECT a FROM test WHERE a = 1 AND b = 2 +SELECT a FROM test WHERE a IN (SELECT b FROM z) +SELECT a FROM test WHERE a IN ((SELECT 1), 2) +SELECT * FROM x WHERE y IN ((SELECT 1) EXCEPT (SELECT 2)) +SELECT * FROM x WHERE y IN (SELECT 1 UNION SELECT 2) +SELECT * FROM x WHERE y IN ((SELECT 1 UNION SELECT 2)) +SELECT * FROM x WHERE y IN (WITH z AS (SELECT 1) SELECT * FROM z) +SELECT a FROM test WHERE (a > 1) +SELECT a FROM test WHERE a > (SELECT 1 FROM x GROUP BY y) +SELECT a FROM test WHERE EXISTS(SELECT 1) +SELECT a FROM test WHERE EXISTS(SELECT * FROM x UNION SELECT * FROM Y) OR TRUE +SELECT a FROM test WHERE TRUE OR NOT EXISTS(SELECT * FROM x) +SELECT a AS any, b AS some, c AS all, d AS exists FROM test WHERE a = ANY (SELECT 1) +SELECT a FROM test WHERE a > ALL (SELECT 1) +SELECT a FROM test WHERE (a, b) IN (SELECT 1, 2) +SELECT a FROM test ORDER BY a +SELECT a FROM test ORDER BY a, b +SELECT x FROM tests ORDER BY a DESC, b DESC, c +SELECT a FROM test ORDER BY a > 1 +SELECT * FROM test ORDER BY DATE DESC, TIMESTAMP DESC +SELECT * FROM test DISTRIBUTE BY y SORT BY x DESC ORDER BY l +SELECT * FROM test CLUSTER BY y +SELECT * FROM test CLUSTER BY y +SELECT * FROM test WHERE RAND() <= 0.1 DISTRIBUTE BY RAND() SORT BY RAND() +SELECT a, b FROM test GROUP BY 1 +SELECT a, b FROM test GROUP BY a +SELECT a, b FROM test WHERE a = 1 GROUP BY a HAVING a = 2 +SELECT a, b FROM test WHERE a = 1 GROUP BY a HAVING a = 2 ORDER BY a +SELECT a, b FROM test WHERE a = 1 GROUP BY CASE 1 WHEN 1 THEN 1 END +SELECT 
a FROM test GROUP BY GROUPING SETS (()) +SELECT a FROM test GROUP BY GROUPING SETS (x, ()) +SELECT a FROM test GROUP BY GROUPING SETS (x, (x, y), (x, y, z), q) +SELECT a FROM test GROUP BY CUBE (x) +SELECT a FROM test GROUP BY ROLLUP (x) +SELECT a FROM test GROUP BY CUBE (x) ROLLUP (x, y, z) +SELECT CASE WHEN a < b THEN 1 WHEN a < c THEN 2 ELSE 3 END FROM test +SELECT CASE 1 WHEN 1 THEN 1 ELSE 2 END +SELECT CASE 1 WHEN 1 THEN MAP('a', 'b') ELSE MAP('b', 'c') END['a'] +SELECT CASE 1 + 2 WHEN 1 THEN 1 ELSE 2 END +SELECT CASE TEST(1) + x[0] WHEN 1 THEN 1 ELSE 2 END +SELECT CASE x[0] WHEN 1 THEN 1 ELSE 2 END +SELECT CASE a.b WHEN 1 THEN 1 ELSE 2 END +SELECT CASE CASE x > 1 WHEN TRUE THEN 1 END WHEN 1 THEN 1 ELSE 2 END +SELECT a FROM (SELECT a FROM test) AS x +SELECT a FROM (SELECT a FROM (SELECT a FROM test) AS y) AS x +SELECT a FROM test WHERE a IN (1, 2, 3) OR b BETWEEN 1 AND 4 +SELECT a FROM test AS x TABLESAMPLE(BUCKET 1 OUT OF 5) +SELECT a FROM test TABLESAMPLE(BUCKET 1 OUT OF 5) +SELECT a FROM test TABLESAMPLE(BUCKET 1 OUT OF 5 ON x) +SELECT a FROM test TABLESAMPLE(BUCKET 1 OUT OF 5 ON RAND()) +SELECT a FROM test TABLESAMPLE(0.1 PERCENT) +SELECT a FROM test TABLESAMPLE(100) +SELECT a FROM test TABLESAMPLE(100 ROWS) +SELECT a FROM test TABLESAMPLE BERNOULLI (50) +SELECT a FROM test TABLESAMPLE SYSTEM (75) +SELECT ABS(a) FROM test +SELECT AVG(a) FROM test +SELECT CEIL(a) FROM test +SELECT COUNT(a) FROM test +SELECT COUNT(1) FROM test +SELECT COUNT(*) FROM test +SELECT COUNT(DISTINCT a) FROM test +SELECT EXP(a) FROM test +SELECT FLOOR(a) FROM test +SELECT FIRST(a) FROM test +SELECT GREATEST(a, b, c) FROM test +SELECT LAST(a) FROM test +SELECT LN(a) FROM test +SELECT LOG10(a) FROM test +SELECT MAX(a) FROM test +SELECT MIN(a) FROM test +SELECT POWER(a, 2) FROM test +SELECT QUANTILE(a, 0.95) FROM test +SELECT ROUND(a) FROM test +SELECT ROUND(a, 2) FROM test +SELECT SUM(a) FROM test +SELECT SQRT(a) FROM test +SELECT STDDEV(a) FROM test +SELECT STDDEV_POP(a) FROM test +SELECT STDDEV_SAMP(a) FROM test +SELECT VARIANCE(a) FROM test +SELECT VARIANCE_POP(a) FROM test +SELECT CAST(a AS INT) FROM test +SELECT CAST(a AS DATETIME) FROM test +SELECT CAST(a AS VARCHAR) FROM test +SELECT CAST(a < 1 AS INT) FROM test +SELECT CAST(a IS NULL AS INT) FROM test +SELECT COUNT(CAST(1 < 2 AS INT)) FROM test +SELECT COUNT(CASE WHEN CAST(1 < 2 AS BOOLEAN) THEN 1 END) FROM test +SELECT CAST(a AS DECIMAL) FROM test +SELECT CAST(a AS DECIMAL(1)) FROM test +SELECT CAST(a AS DECIMAL(1, 2)) FROM test +SELECT CAST(a AS MAP<INT, INT>) FROM test +SELECT CAST(a AS TIMESTAMP) FROM test +SELECT CAST(a AS DATE) FROM test +SELECT CAST(a AS ARRAY<INT>) FROM test +SELECT TRY_CAST(a AS INT) FROM test +SELECT COALESCE(a, b, c) FROM test +SELECT IFNULL(a, b) FROM test +SELECT ANY_VALUE(a) FROM test +SELECT 1 FROM a JOIN b ON a.x = b.x +SELECT 1 FROM a JOIN b AS c ON a.x = b.x +SELECT 1 FROM a INNER JOIN b ON a.x = b.x +SELECT 1 FROM a LEFT JOIN b ON a.x = b.x +SELECT 1 FROM a RIGHT JOIN b ON a.x = b.x +SELECT 1 FROM a CROSS JOIN b ON a.x = b.x +SELECT 1 FROM a JOIN b USING (x) +SELECT 1 FROM a JOIN b USING (x, y, z) +SELECT 1 FROM a JOIN (SELECT a FROM c) AS b ON a.x = b.x AND a.x < 2 +SELECT 1 FROM a UNION SELECT 2 FROM b +SELECT 1 FROM a UNION ALL SELECT 2 FROM b +SELECT 1 FROM a JOIN b ON a.foo = b.bar JOIN c ON a.foo = c.bar +SELECT 1 FROM a LEFT JOIN b ON a.foo = b.bar JOIN c ON a.foo = c.bar +SELECT 1 FROM a LEFT INNER JOIN b ON a.foo = b.bar +SELECT 1 FROM a LEFT OUTER JOIN b ON a.foo = b.bar +SELECT 1 FROM a OUTER JOIN b ON 
a.foo = b.bar +SELECT 1 FROM a FULL JOIN b ON a.foo = b.bar +SELECT 1 UNION ALL SELECT 2 +SELECT 1 EXCEPT SELECT 2 +SELECT 1 EXCEPT SELECT 2 +SELECT 1 INTERSECT SELECT 2 +SELECT 1 INTERSECT SELECT 2 +SELECT 1 AS delete, 2 AS alter +SELECT * FROM (x) +SELECT * FROM ((x)) +SELECT * FROM ((SELECT 1)) +SELECT * FROM (SELECT 1) AS x +SELECT * FROM (SELECT 1 UNION SELECT 2) AS x +SELECT * FROM (SELECT 1 UNION ALL SELECT 2) AS x +SELECT * FROM (SELECT 1 UNION ALL SELECT 2) +SELECT * FROM ((SELECT 1) AS a UNION ALL (SELECT 2) AS b) +SELECT * FROM ((SELECT 1) AS a(b)) +SELECT * FROM x AS y(a, b) +SELECT * EXCEPT (a, b) +SELECT * REPLACE (a AS b, b AS C) +SELECT * REPLACE (a + 1 AS b, b AS C) +SELECT * EXCEPT (a, b) REPLACE (a AS b, b AS C) +SELECT a.* EXCEPT (a, b), b.* REPLACE (a AS b, b AS C) +SELECT zoo, animals FROM (VALUES ('oakland', ARRAY('a', 'b')), ('sf', ARRAY('b', 'c'))) AS t(zoo, animals) +WITH a AS (SELECT 1) SELECT 1 UNION ALL SELECT 2 +WITH a AS (SELECT 1) SELECT 1 UNION SELECT 2 +WITH a AS (SELECT 1) SELECT 1 INTERSECT SELECT 2 +WITH a AS (SELECT 1) SELECT 1 EXCEPT SELECT 2 +WITH a AS (SELECT 1) SELECT 1 EXCEPT SELECT 2 +(SELECT 1) UNION (SELECT 2) +(SELECT 1) UNION SELECT 2 +SELECT 1 UNION (SELECT 2) +(SELECT 1) ORDER BY x LIMIT 1 OFFSET 1 +(SELECT 1 UNION SELECT 2) UNION (SELECT 2 UNION ALL SELECT 3) +(SELECT 1 UNION SELECT 2) ORDER BY x LIMIT 1 OFFSET 1 +(SELECT 1 UNION SELECT 2) CLUSTER BY y DESC +(SELECT 1 UNION SELECT 2) SORT BY z +(SELECT 1 UNION SELECT 2) DISTRIBUTE BY z +(SELECT 1 UNION SELECT 2) DISTRIBUTE BY z SORT BY x +SELECT 1 UNION (SELECT 2) ORDER BY x +(SELECT 1) UNION SELECT 2 ORDER BY x +SELECT * FROM (((SELECT 1) UNION SELECT 2) ORDER BY x LIMIT 1 OFFSET 1) +SELECT * FROM ((SELECT 1 AS x) CROSS JOIN (SELECT 2 AS y)) AS z +((SELECT 1) EXCEPT (SELECT 2)) +VALUES (1) UNION SELECT * FROM x +WITH a AS (SELECT 1) SELECT a.* FROM a +WITH a AS (SELECT 1), b AS (SELECT 2) SELECT a.*, b.* FROM a CROSS JOIN b +WITH a AS (WITH b AS (SELECT 1 AS x) SELECT b.x FROM b) SELECT a.x FROM a +WITH RECURSIVE T(n) AS (VALUES (1) UNION ALL SELECT n + 1 FROM t WHERE n < 100) SELECT SUM(n) FROM t +WITH RECURSIVE T(n, m) AS (VALUES (1, 2) UNION ALL SELECT n + 1, n + 2 FROM t) SELECT SUM(n) FROM t +WITH baz AS (SELECT 1 AS col) UPDATE bar SET cid = baz.col1 FROM baz +SELECT * FROM (WITH y AS (SELECT 1 AS z) SELECT z FROM y) AS x +SELECT RANK() OVER () FROM x +SELECT RANK() OVER () AS y FROM x +SELECT RANK() OVER (PARTITION BY a) FROM x +SELECT RANK() OVER (PARTITION BY a, b) FROM x +SELECT RANK() OVER (ORDER BY a) FROM x +SELECT RANK() OVER (ORDER BY a, b) FROM x +SELECT RANK() OVER (PARTITION BY a ORDER BY a) FROM x +SELECT RANK() OVER (PARTITION BY a, b ORDER BY a, b DESC) FROM x +SELECT SUM(x) OVER (PARTITION BY a) AS y FROM x +SELECT SUM(x) OVER (PARTITION BY a ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) +SELECT SUM(x) OVER (PARTITION BY a ORDER BY b ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) +SELECT SUM(x) OVER (PARTITION BY a ORDER BY b ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) +SELECT SUM(x) OVER (PARTITION BY a ORDER BY b RANGE BETWEEN INTERVAL '1' DAY PRECEDING AND CURRENT ROW) +SELECT SUM(x) OVER (PARTITION BY a ORDER BY b RANGE BETWEEN INTERVAL '1' DAY PRECEDING AND INTERVAL '2' DAYS FOLLOWING) +SELECT SUM(x) OVER (PARTITION BY a ORDER BY b RANGE BETWEEN INTERVAL '1' DAY PRECEDING AND UNBOUNDED FOLLOWING) +SELECT SUM(x) OVER (PARTITION BY a ROWS BETWEEN UNBOUNDED PRECEDING AND PRECEDING) +SELECT SUM(x) OVER (PARTITION BY a ROWS BETWEEN UNBOUNDED PRECEDING 
AND UNBOUNDED FOLLOWING) +SELECT SUM(x) OVER (PARTITION BY a ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) +SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) +SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN 1 AND 3) +SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN 1 FOLLOWING AND 3) +SELECT SUM(x) OVER (PARTITION BY a RANGE BETWEEN 1 FOLLOWING AND UNBOUNDED FOLLOWING) +SELECT LISTAGG(x) WITHIN GROUP (ORDER BY x) AS y +SELECT LISTAGG(x) WITHIN GROUP (ORDER BY x DESC) +SELECT SUM(x) FILTER(WHERE x > 1) +SELECT SUM(x) FILTER(WHERE x > 1) OVER (ORDER BY y) +SELECT COUNT(DISTINCT a) OVER (PARTITION BY c ORDER BY d ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) +SELECT a['1'], b[0], x.c[0], "x".d['1'] FROM x +SELECT ARRAY(1, 2, 3) FROM x +SELECT ARRAY(ARRAY(1), ARRAY(2)) FROM x +SELECT MAP[ARRAY(1), ARRAY(2)] FROM x +SELECT MAP(ARRAY(1), ARRAY(2)) FROM x +SELECT MAX(ARRAY(1, 2, 3)) FROM x +SELECT ARRAY(ARRAY(0))[0][0] FROM x +SELECT MAP[ARRAY('x'), ARRAY(0)]['x'] FROM x +SELECT student, score FROM tests LATERAL VIEW EXPLODE(scores) +SELECT student, score FROM tests LATERAL VIEW EXPLODE(scores) AS score +SELECT student, score FROM tests LATERAL VIEW EXPLODE(scores) t AS score +SELECT student, score FROM tests LATERAL VIEW EXPLODE(scores) t AS score, name +SELECT student, score FROM tests LATERAL VIEW OUTER EXPLODE(scores) t AS score, name +SELECT tf.* FROM (SELECT 0) AS t LATERAL VIEW STACK(1, 2) tf +SELECT tf.* FROM (SELECT 0) AS t LATERAL VIEW STACK(1, 2) tf AS col0, col1, col2 +SELECT student, score FROM tests CROSS JOIN UNNEST(scores) AS t(score) +SELECT student, score FROM tests CROSS JOIN UNNEST(scores) AS t(a, b) +SELECT student, score FROM tests CROSS JOIN UNNEST(scores) WITH ORDINALITY AS t(a, b) +SELECT student, score FROM tests CROSS JOIN UNNEST(x.scores) AS t(score) +SELECT student, score FROM tests CROSS JOIN UNNEST(ARRAY(x.scores)) AS t(score) +CREATE TABLE a.b AS SELECT 1 +CREATE TABLE a.b AS SELECT a FROM a.c +CREATE TABLE IF NOT EXISTS x AS SELECT a FROM d +CREATE TEMPORARY TABLE x AS SELECT a FROM d +CREATE TEMPORARY TABLE IF NOT EXISTS x AS SELECT a FROM d +CREATE VIEW x AS SELECT a FROM b +CREATE VIEW IF NOT EXISTS x AS SELECT a FROM b +CREATE OR REPLACE VIEW x AS SELECT * +CREATE OR REPLACE TEMPORARY VIEW x AS SELECT * +CREATE TEMPORARY VIEW x AS SELECT a FROM d +CREATE TEMPORARY VIEW IF NOT EXISTS x AS SELECT a FROM d +CREATE TEMPORARY VIEW x AS WITH y AS (SELECT 1) SELECT * FROM y +CREATE TABLE z (a INT, b VARCHAR, c VARCHAR(100), d DECIMAL(5, 3)) +CREATE TABLE z (a ARRAY<TEXT>, b MAP<TEXT, DOUBLE>, c DECIMAL(5, 3)) +CREATE TABLE z (a INT, b VARCHAR COMMENT 'z', c VARCHAR(100) COMMENT 'z', d DECIMAL(5, 3)) +CREATE TABLE z (a INT(11) DEFAULT UUID()) +CREATE TABLE z (a INT(11) DEFAULT NULL COMMENT '客户id') +CREATE TABLE z (a INT(11) NOT NULL DEFAULT 1) +CREATE TABLE z (a INT(11) NOT NULL COLLATE utf8_bin AUTO_INCREMENT) +CREATE TABLE z (a INT, PRIMARY KEY(a)) +CREATE TABLE z (a INT) ENGINE=InnoDB AUTO_INCREMENT=1 CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x' +CREATE TABLE z (a INT) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x' +CREATE TABLE z (a INT DEFAULT NULL, PRIMARY KEY(a)) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='x' +CREATE TABLE z WITH (FORMAT='parquet') AS SELECT 1 +CREATE TABLE z WITH (FORMAT='ORC', x = '2') AS SELECT 1 +CREATE TABLE z WITH (TABLE_FORMAT='iceberg', FORMAT='parquet') AS SELECT 1 +CREATE TABLE z WITH (TABLE_FORMAT='iceberg', 
FORMAT='ORC', x = '2') AS SELECT 1 +CREATE TABLE z (z INT) WITH (PARTITIONED_BY=(x INT, y INT)) +CREATE TABLE z (z INT) WITH (PARTITIONED_BY=(x INT)) AS SELECT 1 +CREATE TABLE z AS (WITH cte AS (SELECT 1) SELECT * FROM cte) +CREATE TABLE z AS ((WITH cte AS (SELECT 1) SELECT * FROM cte)) +CREATE TABLE z (a INT UNIQUE) +CREATE TABLE z (a INT AUTO_INCREMENT) +CREATE TABLE z (a INT UNIQUE AUTO_INCREMENT) +CREATE TEMPORARY FUNCTION f +CREATE TEMPORARY FUNCTION f AS 'g' +CREATE FUNCTION f +CREATE FUNCTION f AS 'g' +CREATE INDEX abc ON t (a) +CREATE INDEX abc ON t (a, b, b) +CREATE UNIQUE INDEX abc ON t (a, b, b) +CREATE UNIQUE INDEX IF NOT EXISTS my_idx ON tbl (a, b) +CACHE TABLE x +CACHE LAZY TABLE x +CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') +CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS SELECT 1 +CACHE LAZY TABLE x OPTIONS('storageLevel' = 'value') AS WITH a AS (SELECT 1) SELECT a.* FROM a +CACHE LAZY TABLE x AS WITH a AS (SELECT 1) SELECT a.* FROM a +CACHE TABLE x AS WITH a AS (SELECT 1) SELECT a.* FROM a +CALL catalog.system.iceberg_procedure_name(named_arg_1 => 'arg_1', named_arg_2 => 'arg_2') +INSERT OVERWRITE TABLE a.b PARTITION(ds) SELECT x FROM y +INSERT OVERWRITE TABLE a.b PARTITION(ds='YYYY-MM-DD') SELECT x FROM y +INSERT OVERWRITE TABLE a.b PARTITION(ds, hour) SELECT x FROM y +INSERT OVERWRITE TABLE a.b PARTITION(ds='YYYY-MM-DD', hour='hh') SELECT x FROM y +ALTER TYPE electronic_mail RENAME TO email +ANALYZE a.y +DELETE FROM x WHERE y > 1 +DELETE FROM y +DROP TABLE a +DROP TABLE a.b +DROP TABLE IF EXISTS a +DROP TABLE IF EXISTS a.b +DROP VIEW a +DROP VIEW a.b +DROP VIEW IF EXISTS a +DROP VIEW IF EXISTS a.b +SHOW TABLES +EXPLAIN SELECT * FROM x +INSERT INTO x SELECT * FROM y +INSERT INTO x (SELECT * FROM y) +INSERT INTO x WITH y AS (SELECT 1) SELECT * FROM y +INSERT INTO x.z IF EXISTS SELECT * FROM y +INSERT INTO x VALUES (1, 'a', 2.0) +INSERT INTO x VALUES (1, 'a', 2.0), (1, 'a', 3.0), (X(), y[1], z.x) +INSERT INTO y (a, b, c) SELECT a, b, c FROM x +INSERT OVERWRITE TABLE x IF EXISTS SELECT * FROM y +INSERT OVERWRITE TABLE a.b IF EXISTS SELECT * FROM y +SELECT 1 FROM PARQUET_SCAN('/x/y/*') AS y +UNCACHE TABLE x +UNCACHE TABLE IF EXISTS x +UPDATE tbl_name SET foo = 123 +UPDATE tbl_name SET foo = 123, bar = 345 +UPDATE db.tbl_name SET foo = 123 WHERE tbl_name.bar = 234 +UPDATE db.tbl_name SET foo = 123, foo_1 = 234 WHERE tbl_name.bar = 234 +TRUNCATE TABLE x +OPTIMIZE TABLE y +WITH a AS (SELECT 1) INSERT INTO b SELECT * FROM a +WITH a AS (SELECT * FROM b) UPDATE a SET col = 1 +WITH a AS (SELECT * FROM b) CREATE TABLE b AS SELECT * FROM a +WITH a AS (SELECT * FROM b) DELETE FROM a +WITH a AS (SELECT * FROM b) CACHE TABLE a +SELECT ? AS ? FROM x WHERE b BETWEEN ? AND ? GROUP BY ?, 1 LIMIT ? 
+WITH a AS ((SELECT b.foo AS foo, b.bar AS bar FROM b) UNION ALL (SELECT c.foo AS foo, c.bar AS bar FROM c)) SELECT * FROM a +WITH a AS ((SELECT 1 AS b) UNION ALL (SELECT 1 AS b)) SELECT * FROM a +SELECT (WITH x AS (SELECT 1 AS y) SELECT * FROM x) AS z diff --git a/tests/fixtures/optimizer/eliminate_subqueries.sql b/tests/fixtures/optimizer/eliminate_subqueries.sql new file mode 100644 index 0000000..aae5f2a --- /dev/null +++ b/tests/fixtures/optimizer/eliminate_subqueries.sql @@ -0,0 +1,42 @@ +SELECT 1 AS x, 2 AS y +UNION ALL +SELECT 1 AS x, 2 AS y; +WITH _e_0 AS ( + SELECT + 1 AS x, + 2 AS y +) +SELECT + * +FROM _e_0 +UNION ALL +SELECT + * +FROM _e_0; + +SELECT x.id +FROM ( + SELECT * + FROM x AS x + JOIN y AS y + ON x.id = y.id +) AS x +JOIN ( + SELECT * + FROM x AS x + JOIN y AS y + ON x.id = y.id +) AS y +ON x.id = y.id; +WITH _e_0 AS ( + SELECT + * + FROM x AS x + JOIN y AS y + ON x.id = y.id +) +SELECT + x.id +FROM "_e_0" AS x +JOIN "_e_0" AS y + ON x.id = y.id; diff --git a/tests/fixtures/optimizer/expand_multi_table_selects.sql b/tests/fixtures/optimizer/expand_multi_table_selects.sql new file mode 100644 index 0000000..a5a4664 --- /dev/null +++ b/tests/fixtures/optimizer/expand_multi_table_selects.sql @@ -0,0 +1,11 @@ +-------------------------------------- +-- Multi Table Selects +-------------------------------------- +SELECT * FROM x AS x, y AS y WHERE x.a = y.a; +SELECT * FROM x AS x CROSS JOIN y AS y WHERE x.a = y.a; + +SELECT * FROM x AS x, y AS y WHERE x.a = y.a AND x.a = 1 and y.b = 1; +SELECT * FROM x AS x CROSS JOIN y AS y WHERE x.a = y.a AND x.a = 1 AND y.b = 1; + +SELECT * FROM x AS x, y AS y WHERE x.a > y.a; +SELECT * FROM x AS x CROSS JOIN y AS y WHERE x.a > y.a; diff --git a/tests/fixtures/optimizer/isolate_table_selects.sql b/tests/fixtures/optimizer/isolate_table_selects.sql new file mode 100644 index 0000000..3b9a938 --- /dev/null +++ b/tests/fixtures/optimizer/isolate_table_selects.sql @@ -0,0 +1,20 @@ +SELECT * FROM x AS x, y AS y2; +SELECT * FROM (SELECT * FROM x AS x) AS x, (SELECT * FROM y AS y) AS y2; + +SELECT * FROM x AS x WHERE x = 1; +SELECT * FROM x AS x WHERE x = 1; + +SELECT * FROM x AS x JOIN y AS y; +SELECT * FROM (SELECT * FROM x AS x) AS x JOIN (SELECT * FROM y AS y) AS y; + +SELECT * FROM (SELECT 1) AS x JOIN y AS y; +SELECT * FROM (SELECT 1) AS x JOIN (SELECT * FROM y AS y) AS y; + +SELECT * FROM x AS x JOIN (SELECT * FROM y) AS y; +SELECT * FROM (SELECT * FROM x AS x) AS x JOIN (SELECT * FROM y) AS y; + +WITH y AS (SELECT *) SELECT * FROM x AS x; +WITH y AS (SELECT *) SELECT * FROM x AS x; + +WITH y AS (SELECT * FROM y AS y2 JOIN x AS z2) SELECT * FROM x AS x JOIN y as y; +WITH y AS (SELECT * FROM (SELECT * FROM y AS y) AS y2 JOIN (SELECT * FROM x AS x) AS z2) SELECT * FROM (SELECT * FROM x AS x) AS x JOIN y AS y; diff --git a/tests/fixtures/optimizer/normalize.sql b/tests/fixtures/optimizer/normalize.sql new file mode 100644 index 0000000..a84fadf --- /dev/null +++ b/tests/fixtures/optimizer/normalize.sql @@ -0,0 +1,41 @@ +(A OR B) AND (B OR C) AND (E OR F); +(A OR B) AND (B OR C) AND (E OR F); + +(A AND B) OR (B AND C AND D); +(A OR C) AND (A OR D) AND B; + +(A OR B) AND (A OR C) AND (A OR D) AND (B OR C) AND (B OR D) AND B; +(A OR C) AND (A OR D) AND B; + +(A AND E) OR (B AND C) OR (D AND (E OR F)); +(A OR B OR D) AND (A OR C OR D) AND (B OR D OR E) AND (B OR E OR F) AND (C OR D OR E) AND (C OR E OR F); + +(A AND B AND C AND D AND E AND F AND G) OR (H AND I AND J AND K AND L AND M AND N) OR (O AND P AND Q); +(A AND B AND C AND D AND E AND 
F AND G) OR (H AND I AND J AND K AND L AND M AND N) OR (O AND P AND Q); + +NOT NOT NOT (A OR B); +NOT A AND NOT B; + +A OR B; +A OR B; + +A AND (B AND C); +A AND B AND C; + +A OR (B AND C); +(A OR B) AND (A OR C); + +(A AND B) OR C; +(A OR C) AND (B OR C); + +A OR (B OR (C AND D)); +(A OR B OR C) AND (A OR B OR D); + +A OR ((((B OR C) AND (B OR D)) OR C) AND (((B OR C) AND (B OR D)) OR D)); +(A OR B OR C) AND (A OR B OR D); + +(A AND B) OR (C AND D); +(A OR C) AND (A OR D) AND (B OR C) AND (B OR D); + +(A AND B) OR (C OR (D AND E)); +(A OR C OR D) AND (A OR C OR E) AND (B OR C OR D) AND (B OR C OR E); diff --git a/tests/fixtures/optimizer/optimize_joins.sql b/tests/fixtures/optimizer/optimize_joins.sql new file mode 100644 index 0000000..b64544e --- /dev/null +++ b/tests/fixtures/optimizer/optimize_joins.sql @@ -0,0 +1,20 @@ +SELECT * FROM x JOIN y ON y.a = 1 JOIN z ON x.a = z.a AND y.a = z.a; +SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = 1 AND y.a = z.a; + +SELECT * FROM x JOIN y ON y.a = 1 JOIN z ON x.a = z.a; +SELECT * FROM x JOIN y ON y.a = 1 JOIN z ON x.a = z.a; + +SELECT * FROM x CROSS JOIN y JOIN z ON x.a = z.a AND y.a = z.a; +SELECT * FROM x JOIN z ON x.a = z.a AND TRUE JOIN y ON y.a = z.a; + +SELECT * FROM x LEFT JOIN y ON y.a = 1 JOIN z ON x.a = z.a AND y.a = z.a; +SELECT * FROM x JOIN z ON x.a = z.a AND TRUE LEFT JOIN y ON y.a = 1 AND y.a = z.a; + +SELECT * FROM x INNER JOIN z; +SELECT * FROM x JOIN z; + +SELECT * FROM x LEFT OUTER JOIN z; +SELECT * FROM x LEFT JOIN z; + +SELECT * FROM x CROSS JOIN z; +SELECT * FROM x CROSS JOIN z; diff --git a/tests/fixtures/optimizer/optimizer.sql b/tests/fixtures/optimizer/optimizer.sql new file mode 100644 index 0000000..f7bbdda --- /dev/null +++ b/tests/fixtures/optimizer/optimizer.sql @@ -0,0 +1,148 @@ +SELECT a, m FROM z LATERAL VIEW EXPLODE([1, 2]) q AS m; +SELECT + "z"."a" AS "a", + "q"."m" AS "m" +FROM ( + SELECT + "z"."a" AS "a" + FROM "z" AS "z" +) AS "z" +LATERAL VIEW +EXPLODE(ARRAY(1, 2)) q AS "m"; + +SELECT x FROM UNNEST([1, 2]) AS q(x, y); +SELECT + "q"."x" AS "x" +FROM UNNEST(ARRAY(1, 2)) AS "q"("x", "y"); + +WITH cte AS ( + ( + SELECT + a + FROM + x + ) + UNION ALL + ( + SELECT + a + FROM + y + ) +) +SELECT + * +FROM + cte; +WITH "cte" AS ( + ( + SELECT + "x"."a" AS "a" + FROM "x" AS "x" + ) + UNION ALL + ( + SELECT + "y"."a" AS "a" + FROM "y" AS "y" + ) +) +SELECT + "cte"."a" AS "a" +FROM "cte"; + +WITH cte1 AS ( + SELECT a + FROM x +), cte2 AS ( + SELECT a + 1 AS a + FROM cte1 +) +SELECT + a +FROM cte1 +UNION ALL +SELECT + a +FROM cte2; +WITH "cte1" AS ( + SELECT + "x"."a" AS "a" + FROM "x" AS "x" +), "cte2" AS ( + SELECT + "cte1"."a" + 1 AS "a" + FROM "cte1" +) +SELECT + "cte1"."a" AS "a" +FROM "cte1" +UNION ALL +SELECT + "cte2"."a" AS "a" +FROM "cte2"; + +SELECT a, SUM(b) +FROM ( + SELECT x.a, y.b + FROM x, y + WHERE (SELECT max(b) FROM y WHERE x.a = y.a) >= 0 AND x.a = y.a +) d +WHERE (TRUE AND TRUE OR 'a' = 'b') AND a > 1 +GROUP BY a; +SELECT + "d"."a" AS "a", + SUM("d"."b") AS "_col_1" +FROM ( + SELECT + "x"."a" AS "a", + "y"."b" AS "b" + FROM ( + SELECT + "x"."a" AS "a" + FROM "x" AS "x" + WHERE + "x"."a" > 1 + ) AS "x" + LEFT JOIN ( + SELECT + MAX("y"."b") AS "_col_0", + "y"."a" AS "_u_1" + FROM "y" AS "y" + GROUP BY + "y"."a" + ) AS "_u_0" + ON "x"."a" = "_u_0"."_u_1" + JOIN ( + SELECT + "y"."a" AS "a", + "y"."b" AS "b" + FROM "y" AS "y" + ) AS "y" + ON "x"."a" = "y"."a" + WHERE + "_u_0"."_col_0" >= 0 + AND NOT "_u_0"."_u_1" IS NULL +) AS "d" +GROUP BY + "d"."a"; + +(SELECT a FROM x) LIMIT 1; +( + 
SELECT + "x"."a" AS "a" + FROM "x" AS "x" +) +LIMIT 1; + +(SELECT b FROM x UNION SELECT b FROM y) LIMIT 1; +( + SELECT + "x"."b" AS "b" + FROM "x" AS "x" + UNION + SELECT + "y"."b" AS "b" + FROM "y" AS "y" +) +LIMIT 1; diff --git a/tests/fixtures/optimizer/pushdown_predicates.sql b/tests/fixtures/optimizer/pushdown_predicates.sql new file mode 100644 index 0000000..676cb96 --- /dev/null +++ b/tests/fixtures/optimizer/pushdown_predicates.sql @@ -0,0 +1,32 @@ +SELECT x.a AS a FROM (SELECT x.a FROM x AS x) AS x JOIN y WHERE x.a = 1 AND x.b = 1 AND y.a = 1; +SELECT x.a AS a FROM (SELECT x.a FROM x AS x WHERE x.a = 1 AND x.b = 1) AS x JOIN y ON y.a = 1 WHERE TRUE AND TRUE AND TRUE; + +WITH x AS (SELECT y.a FROM y) SELECT * FROM x WHERE x.a = 1; +WITH x AS (SELECT y.a FROM y WHERE y.a = 1) SELECT * FROM x WHERE TRUE; + +SELECT x.a FROM (SELECT * FROM x) AS x JOIN y WHERE y.a = 1 OR (x.a = 1 AND x.b = 1); +SELECT x.a FROM (SELECT * FROM x) AS x JOIN y WHERE (x.a = 1 AND x.b = 1) OR y.a = 1; + +SELECT x.a FROM (SELECT * FROM x) AS x JOIN y WHERE (x.a = y.a AND x.a = 1 AND x.b = 1) OR x.a = y.a; +SELECT x.a FROM (SELECT * FROM x) AS x JOIN y ON x.a = y.a WHERE TRUE; + +SELECT x.a FROM (SELECT * FROM x) AS x JOIN y WHERE (x.a = y.a AND x.a = 1 AND x.b = 1) OR x.a = y.b; +SELECT x.a FROM (SELECT * FROM x) AS x JOIN y ON x.a = y.a OR x.a = y.b WHERE (x.a = y.a AND x.a = 1 AND x.b = 1) OR x.a = y.b; + +SELECT x.a FROM (SELECT x.a AS a, x.b * 1 AS c FROM x) AS x WHERE x.c = 1; +SELECT x.a FROM (SELECT x.a AS a, x.b * 1 AS c FROM x WHERE x.b * 1 = 1) AS x WHERE TRUE; + +SELECT x.a FROM (SELECT x.a AS a, x.b * 1 AS c FROM x) AS x WHERE x.c = 1 or x.c = 2; +SELECT x.a FROM (SELECT x.a AS a, x.b * 1 AS c FROM x WHERE x.b * 1 = 1 OR x.b * 1 = 2) AS x WHERE TRUE; + +SELECT x.a AS a FROM (SELECT x.a FROM x AS x) AS x JOIN y WHERE x.a = 1 AND x.b = 1 AND (x.c = 1 OR y.c = 1); +SELECT x.a AS a FROM (SELECT x.a FROM x AS x WHERE x.a = 1 AND x.b = 1) AS x JOIN y ON x.c = 1 OR y.c = 1 WHERE TRUE AND TRUE AND (TRUE); + +SELECT x.a FROM x AS x JOIN (SELECT y.a FROM y AS y) AS y ON y.a = 1 AND x.a = y.a; +SELECT x.a FROM x AS x JOIN (SELECT y.a FROM y AS y WHERE y.a = 1) AS y ON x.a = y.a AND TRUE; + +SELECT x.a AS a FROM x AS x JOIN (SELECT * FROM y AS y) AS y ON y.a = 1 WHERE x.a = 1 AND x.b = 1 AND y.a = x; +SELECT x.a AS a FROM x AS x JOIN (SELECT * FROM y AS y WHERE y.a = 1) AS y ON y.a = x AND TRUE WHERE x.a = 1 AND x.b = 1 AND TRUE; + +SELECT x.a AS a FROM x AS x CROSS JOIN (SELECT * FROM y AS y) AS y WHERE x.a = 1 AND x.b = 1 AND y.a = x.a AND y.a = 1; +SELECT x.a AS a FROM x AS x JOIN (SELECT * FROM y AS y WHERE y.a = 1) AS y ON y.a = x.a AND TRUE WHERE x.a = 1 AND x.b = 1 AND TRUE AND TRUE; diff --git a/tests/fixtures/optimizer/pushdown_projections.sql b/tests/fixtures/optimizer/pushdown_projections.sql new file mode 100644 index 0000000..9deceb6 --- /dev/null +++ b/tests/fixtures/optimizer/pushdown_projections.sql @@ -0,0 +1,41 @@ +SELECT a FROM (SELECT * FROM x); +SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0"; + +SELECT 1 FROM (SELECT * FROM x) WHERE b = 2; +SELECT 1 AS "_col_0" FROM (SELECT x.b AS b FROM x AS x) AS "_q_0" WHERE "_q_0".b = 2; + +SELECT (SELECT c FROM y WHERE q.b = y.b) FROM (SELECT * FROM x) AS q; +SELECT (SELECT y.c AS c FROM y AS y WHERE q.b = y.b) AS "_col_0" FROM (SELECT x.b AS b FROM x AS x) AS q; + +SELECT a FROM x JOIN (SELECT b, c FROM y) AS z ON x.b = z.b; +SELECT x.a AS a FROM x AS x JOIN (SELECT y.b AS b FROM y AS y) AS z ON x.b = z.b; + +SELECT x1.a FROM 
(SELECT * FROM x) AS x1, (SELECT * FROM x) AS x2; +SELECT x1.a AS a FROM (SELECT x.a AS a FROM x AS x) AS x1, (SELECT 1 AS "_" FROM x AS x) AS x2; + +SELECT x1.a FROM (SELECT * FROM x) AS x1, (SELECT * FROM x) AS x2; +SELECT x1.a AS a FROM (SELECT x.a AS a FROM x AS x) AS x1, (SELECT 1 AS "_" FROM x AS x) AS x2; + +SELECT a FROM (SELECT DISTINCT a, b FROM x); +SELECT "_q_0".a AS a FROM (SELECT DISTINCT x.a AS a, x.b AS b FROM x AS x) AS "_q_0"; + +SELECT a FROM (SELECT a, b FROM x UNION ALL SELECT a, b FROM x); +SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x UNION ALL SELECT x.a AS a FROM x AS x) AS "_q_0"; + +SELECT a FROM (SELECT a, b FROM x UNION SELECT a, b FROM x); +SELECT "_q_0".a AS a FROM (SELECT x.a AS a, x.b AS b FROM x AS x UNION SELECT x.a AS a, x.b AS b FROM x AS x) AS "_q_0"; + +WITH y AS (SELECT * FROM x) SELECT a FROM y; +WITH y AS (SELECT x.a AS a FROM x AS x) SELECT y.a AS a FROM y; + +WITH z AS (SELECT * FROM x), q AS (SELECT b FROM z) SELECT b FROM q; +WITH z AS (SELECT x.b AS b FROM x AS x), q AS (SELECT z.b AS b FROM z) SELECT q.b AS b FROM q; + +WITH z AS (SELECT * FROM x) SELECT a FROM z UNION SELECT a FROM z; +WITH z AS (SELECT x.a AS a FROM x AS x) SELECT z.a AS a FROM z UNION SELECT z.a AS a FROM z; + +SELECT b FROM (SELECT a, SUM(b) AS b FROM x GROUP BY a); +SELECT "_q_0".b AS b FROM (SELECT SUM(x.b) AS b FROM x AS x GROUP BY x.a) AS "_q_0"; + +SELECT b FROM (SELECT a, SUM(b) AS b FROM x ORDER BY a); +SELECT "_q_0".b AS b FROM (SELECT x.a AS a, SUM(x.b) AS b FROM x AS x ORDER BY a) AS "_q_0"; diff --git a/tests/fixtures/optimizer/qualify_columns.sql b/tests/fixtures/optimizer/qualify_columns.sql new file mode 100644 index 0000000..004c57c --- /dev/null +++ b/tests/fixtures/optimizer/qualify_columns.sql @@ -0,0 +1,233 @@ +-------------------------------------- +-- Qualify columns +-------------------------------------- +SELECT a FROM x; +SELECT x.a AS a FROM x AS x; + +SELECT a FROM x AS z; +SELECT z.a AS a FROM x AS z; + +SELECT a AS a FROM x; +SELECT x.a AS a FROM x AS x; + +SELECT x.a FROM x; +SELECT x.a AS a FROM x AS x; + +SELECT x.a AS a FROM x; +SELECT x.a AS a FROM x AS x; + +SELECT a AS b FROM x; +SELECT x.a AS b FROM x AS x; + +SELECT 1, 2 FROM x; +SELECT 1 AS "_col_0", 2 AS "_col_1" FROM x AS x; + +SELECT a + b FROM x; +SELECT x.a + x.b AS "_col_0" FROM x AS x; + +SELECT a + b FROM x; +SELECT x.a + x.b AS "_col_0" FROM x AS x; + +SELECT a, SUM(b) FROM x WHERE a > 1 AND b > 1 GROUP BY a; +SELECT x.a AS a, SUM(x.b) AS "_col_1" FROM x AS x WHERE x.a > 1 AND x.b > 1 GROUP BY x.a; + +SELECT a AS j, b FROM x ORDER BY j; +SELECT x.a AS j, x.b AS b FROM x AS x ORDER BY j; + +SELECT a AS j, b FROM x GROUP BY j; +SELECT x.a AS j, x.b AS b FROM x AS x GROUP BY x.a; + +SELECT a, b FROM x GROUP BY 1, 2; +SELECT x.a AS a, x.b AS b FROM x AS x GROUP BY x.a, x.b; + +SELECT a, b FROM x ORDER BY 1, 2; +SELECT x.a AS a, x.b AS b FROM x AS x ORDER BY a, b; + +SELECT DATE(a), DATE(b) AS c FROM x GROUP BY 1, 2; +SELECT DATE(x.a) AS "_col_0", DATE(x.b) AS c FROM x AS x GROUP BY DATE(x.a), DATE(x.b); + +SELECT x.a AS c FROM x JOIN y ON x.b = y.b GROUP BY c; +SELECT x.a AS c FROM x AS x JOIN y AS y ON x.b = y.b GROUP BY y.c; + +SELECT DATE(x.a) AS d FROM x JOIN y ON x.b = y.b GROUP BY d; +SELECT DATE(x.a) AS d FROM x AS x JOIN y AS y ON x.b = y.b GROUP BY DATE(x.a); + +SELECT a AS a, b FROM x ORDER BY a; +SELECT x.a AS a, x.b AS b FROM x AS x ORDER BY a; + +SELECT a, b FROM x ORDER BY a; +SELECT x.a AS a, x.b AS b FROM x AS x ORDER BY a; + +SELECT a FROM x ORDER BY b; 
+SELECT x.a AS a FROM x AS x ORDER BY x.b; + +# dialect: bigquery +SELECT ROW_NUMBER() OVER (PARTITION BY a ORDER BY b) AS row_num FROM x QUALIFY row_num = 1; +SELECT ROW_NUMBER() OVER (PARTITION BY x.a ORDER BY x.b) AS row_num FROM x AS x QUALIFY row_num = 1; + +# dialect: bigquery +SELECT x.b, x.a FROM x LEFT JOIN y ON x.b = y.b QUALIFY ROW_NUMBER() OVER(PARTITION BY x.b ORDER BY x.a DESC) = 1; +SELECT x.b AS b, x.a AS a FROM x AS x LEFT JOIN y AS y ON x.b = y.b QUALIFY ROW_NUMBER() OVER (PARTITION BY x.b ORDER BY x.a DESC) = 1; + +-------------------------------------- +-- Derived tables +-------------------------------------- +SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y; +SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y; + +SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y(a); +SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y; + +SELECT y.c AS c FROM (SELECT x.a AS a, x.b AS b FROM x AS x) AS y(c); +SELECT y.c AS c FROM (SELECT x.a AS c, x.b AS b FROM x AS x) AS y; + +SELECT a FROM (SELECT a FROM x AS x) y; +SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y; + +SELECT a FROM (SELECT a AS a FROM x); +SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0"; + +SELECT a FROM (SELECT a FROM (SELECT a FROM x)); +SELECT "_q_1".a AS a FROM (SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0") AS "_q_1"; + +SELECT x.a FROM x AS x JOIN (SELECT * FROM x); +SELECT x.a AS a FROM x AS x JOIN (SELECT x.a AS a, x.b AS b FROM x AS x) AS "_q_0"; + +-------------------------------------- +-- Joins +-------------------------------------- +SELECT a, c FROM x JOIN y ON x.b = y.b; +SELECT x.a AS a, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b; + +SELECT a, c FROM x, y; +SELECT x.a AS a, y.c AS c FROM x AS x, y AS y; + +-------------------------------------- +-- Unions +-------------------------------------- +SELECT a FROM x UNION SELECT a FROM x; +SELECT x.a AS a FROM x AS x UNION SELECT x.a AS a FROM x AS x; + +SELECT a FROM x UNION SELECT a FROM x UNION SELECT a FROM x; +SELECT x.a AS a FROM x AS x UNION SELECT x.a AS a FROM x AS x UNION SELECT x.a AS a FROM x AS x; + +SELECT a FROM (SELECT a FROM x UNION SELECT a FROM x); +SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x UNION SELECT x.a AS a FROM x AS x) AS "_q_0"; + +-------------------------------------- +-- Subqueries +-------------------------------------- +SELECT a FROM x WHERE b IN (SELECT c FROM y); +SELECT x.a AS a FROM x AS x WHERE x.b IN (SELECT y.c AS c FROM y AS y); + +SELECT (SELECT c FROM y) FROM x; +SELECT (SELECT y.c AS c FROM y AS y) AS "_col_0" FROM x AS x; + +SELECT a FROM (SELECT a FROM x) WHERE a IN (SELECT b FROM (SELECT b FROM y)); +SELECT "_q_1".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_1" WHERE "_q_1".a IN (SELECT "_q_0".b AS b FROM (SELECT y.b AS b FROM y AS y) AS "_q_0"); + +-------------------------------------- +-- Correlated subqueries +-------------------------------------- +SELECT a FROM x WHERE b IN (SELECT c FROM y WHERE y.b = x.a); +SELECT x.a AS a FROM x AS x WHERE x.b IN (SELECT y.c AS c FROM y AS y WHERE y.b = x.a); + +SELECT a FROM x WHERE b IN (SELECT c FROM y WHERE y.b = a); +SELECT x.a AS a FROM x AS x WHERE x.b IN (SELECT y.c AS c FROM y AS y WHERE y.b = x.a); + +SELECT a FROM x WHERE b IN (SELECT b FROM y AS x); +SELECT x.a AS a FROM x AS x WHERE x.b IN (SELECT x.b AS b FROM y AS x); + +SELECT a FROM x AS i WHERE b IN (SELECT b FROM y AS j WHERE j.b IN (SELECT c FROM y AS k WHERE k.b = j.b)); +SELECT i.a AS a FROM x AS i WHERE i.b IN 
(SELECT j.b AS b FROM y AS j WHERE j.b IN (SELECT k.c AS c FROM y AS k WHERE k.b = j.b)); + +# dialect: bigquery +SELECT aa FROM x, UNNEST(a) AS aa; +SELECT aa AS aa FROM x AS x, UNNEST(x.a) AS aa; + +SELECT aa FROM x, UNNEST(a) AS t(aa); +SELECT t.aa AS aa FROM x AS x, UNNEST(x.a) AS t(aa); + +-------------------------------------- +-- Expand * +-------------------------------------- +SELECT * FROM x; +SELECT x.a AS a, x.b AS b FROM x AS x; + +SELECT x.* FROM x; +SELECT x.a AS a, x.b AS b FROM x AS x; + +SELECT * FROM x JOIN y ON x.b = y.b; +SELECT x.a AS a, x.b AS b, y.b AS b, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b; + +SELECT x.* FROM x JOIN y ON x.b = y.b; +SELECT x.a AS a, x.b AS b FROM x AS x JOIN y AS y ON x.b = y.b; + +SELECT x.*, y.* FROM x JOIN y ON x.b = y.b; +SELECT x.a AS a, x.b AS b, y.b AS b, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b; + +SELECT a FROM (SELECT * FROM x); +SELECT "_q_0".a AS a FROM (SELECT x.a AS a, x.b AS b FROM x AS x) AS "_q_0"; + +SELECT * FROM (SELECT a FROM x); +SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0"; + +-------------------------------------- +-- CTEs +-------------------------------------- +WITH z AS (SELECT x.a AS a FROM x) SELECT z.a AS a FROM z; +WITH z AS (SELECT x.a AS a FROM x AS x) SELECT z.a AS a FROM z; + +WITH z(a) AS (SELECT a FROM x) SELECT * FROM z; +WITH z AS (SELECT x.a AS a FROM x AS x) SELECT z.a AS a FROM z; + +WITH z AS (SELECT a FROM x) SELECT * FROM z as q; +WITH z AS (SELECT x.a AS a FROM x AS x) SELECT q.a AS a FROM z AS q; + +WITH z AS (SELECT a FROM x) SELECT * FROM z; +WITH z AS (SELECT x.a AS a FROM x AS x) SELECT z.a AS a FROM z; + +WITH z AS (SELECT a FROM x), q AS (SELECT * FROM z) SELECT * FROM q; +WITH z AS (SELECT x.a AS a FROM x AS x), q AS (SELECT z.a AS a FROM z) SELECT q.a AS a FROM q; + +WITH z AS (SELECT * FROM x) SELECT * FROM z UNION SELECT * FROM z; +WITH z AS (SELECT x.a AS a, x.b AS b FROM x AS x) SELECT z.a AS a, z.b AS b FROM z UNION SELECT z.a AS a, z.b AS b FROM z; + +WITH z AS (SELECT * FROM x), q AS (SELECT b FROM z) SELECT b FROM q; +WITH z AS (SELECT x.a AS a, x.b AS b FROM x AS x), q AS (SELECT z.b AS b FROM z) SELECT q.b AS b FROM q; + +WITH z AS ((SELECT b FROM x UNION ALL SELECT b FROM y) ORDER BY b) SELECT * FROM z; +WITH z AS ((SELECT x.b AS b FROM x AS x UNION ALL SELECT y.b AS b FROM y AS y) ORDER BY b) SELECT z.b AS b FROM z; + +-------------------------------------- +-- Except and Replace +-------------------------------------- +SELECT * REPLACE(a AS d) FROM x; +SELECT x.a AS d, x.b AS b FROM x AS x; + +SELECT * EXCEPT(b) REPLACE(a AS d) FROM x; +SELECT x.a AS d FROM x AS x; + +SELECT x.* EXCEPT(a), y.* FROM x, y; +SELECT x.b AS b, y.b AS b, y.c AS c FROM x AS x, y AS y; + +SELECT * EXCEPT(a) FROM x; +SELECT x.b AS b FROM x AS x; + +-------------------------------------- +-- Using +-------------------------------------- +SELECT x.b FROM x JOIN y USING (b); +SELECT x.b AS b FROM x AS x JOIN y AS y ON x.b = y.b; + +SELECT x.b FROM x JOIN y USING (b) JOIN z USING (b); +SELECT x.b AS b FROM x AS x JOIN y AS y ON x.b = y.b JOIN z AS z ON x.b = z.b; + +SELECT b FROM x AS x2 JOIN y AS y2 USING (b); +SELECT COALESCE(x2.b, y2.b) AS b FROM x AS x2 JOIN y AS y2 ON x2.b = y2.b; + +SELECT b FROM x JOIN y USING (b) WHERE b = 1 and y.b = 2; +SELECT COALESCE(x.b, y.b) AS b FROM x AS x JOIN y AS y ON x.b = y.b WHERE COALESCE(x.b, y.b) = 1 AND y.b = 2; + +SELECT b FROM x JOIN y USING (b) JOIN z USING (b); +SELECT COALESCE(x.b, y.b, z.b) AS b FROM x AS x JOIN y AS y ON x.b = 
y.b JOIN z AS z ON x.b = z.b; diff --git a/tests/fixtures/optimizer/qualify_columns__invalid.sql b/tests/fixtures/optimizer/qualify_columns__invalid.sql new file mode 100644 index 0000000..056b0e9 --- /dev/null +++ b/tests/fixtures/optimizer/qualify_columns__invalid.sql @@ -0,0 +1,14 @@ +SELECT a FROM zz; +SELECT * FROM zz; +SELECT z.a FROM x; +SELECT z.* FROM x; +SELECT x FROM x; +INSERT INTO x VALUES (1, 2); +SELECT a FROM x AS z JOIN y AS z; +WITH z AS (SELECT * FROM x) SELECT * FROM x AS z; +SELECT a FROM x JOIN (SELECT b FROM y WHERE y.b = x.c); +SELECT a FROM x AS y JOIN (SELECT a FROM y) AS q ON y.a = q.a; +SELECT q.a FROM (SELECT x.b FROM x) AS z JOIN (SELECT a FROM z) AS q ON z.b = q.a; +SELECT b FROM x AS a CROSS JOIN y AS b CROSS JOIN y AS c; +SELECT x.a FROM x JOIN y USING (a); +SELECT a, SUM(b) FROM x GROUP BY 3; diff --git a/tests/fixtures/optimizer/qualify_tables.sql b/tests/fixtures/optimizer/qualify_tables.sql new file mode 100644 index 0000000..2cea85d --- /dev/null +++ b/tests/fixtures/optimizer/qualify_tables.sql @@ -0,0 +1,17 @@ +SELECT 1 FROM z; +SELECT 1 FROM c.db.z AS z; + +SELECT 1 FROM y.z; +SELECT 1 FROM c.y.z AS z; + +SELECT 1 FROM x.y.z; +SELECT 1 FROM x.y.z AS z; + +SELECT 1 FROM x.y.z AS z; +SELECT 1 FROM x.y.z AS z; + +WITH a AS (SELECT 1 FROM z) SELECT 1 FROM a; +WITH a AS (SELECT 1 FROM c.db.z AS z) SELECT 1 FROM a; + +SELECT (SELECT y.c FROM y AS y) FROM x; +SELECT (SELECT y.c FROM c.db.y AS y) FROM c.db.x AS x; diff --git a/tests/fixtures/optimizer/quote_identities.sql b/tests/fixtures/optimizer/quote_identities.sql new file mode 100644 index 0000000..407b7f6 --- /dev/null +++ b/tests/fixtures/optimizer/quote_identities.sql @@ -0,0 +1,8 @@ +SELECT a FROM x; +SELECT "a" FROM "x"; + +SELECT "a" FROM "x"; +SELECT "a" FROM "x"; + +SELECT x.a AS a FROM db.x; +SELECT "x"."a" AS "a" FROM "db"."x"; diff --git a/tests/fixtures/optimizer/simplify.sql b/tests/fixtures/optimizer/simplify.sql new file mode 100644 index 0000000..d7217cf --- /dev/null +++ b/tests/fixtures/optimizer/simplify.sql @@ -0,0 +1,350 @@ +-------------------------------------- +-- Conditions +-------------------------------------- +x AND x; +x; + +y OR y; +y; + +x AND NOT x; +FALSE; + +x OR NOT x; +TRUE; + +1 AND TRUE; +TRUE; + +TRUE AND TRUE; +TRUE; + +1 AND TRUE AND 1 AND 1; +TRUE; + +TRUE AND FALSE; +FALSE; + +FALSE AND FALSE; +FALSE; + +FALSE AND TRUE AND TRUE; +FALSE; + +x > y OR FALSE; +x > y; + +FALSE OR x = y; +x = y; + +1 = 1; +TRUE; + +1.0 = 1; +TRUE; + +'x' = 'y'; +FALSE; + +'x' = 'x'; +TRUE; + +NULL AND TRUE; +NULL; + +NULL AND NULL; +NULL; + +NULL OR TRUE; +TRUE; + +NULL OR NULL; +NULL; + +FALSE OR NULL; +NULL; + +NOT TRUE; +FALSE; + +NOT FALSE; +TRUE; + +NULL = NULL; +NULL; + +NOT (NOT TRUE); +TRUE; + +a AND (b OR b); +a AND b; + +a AND (b AND b); +a AND b; + +-------------------------------------- +-- Absorption +-------------------------------------- +(A OR B) AND (C OR NOT A); +(A OR B) AND (C OR NOT A); + +A AND (A OR B); +A; + +A AND D AND E AND (B OR A); +A AND D AND E; + +D AND A AND E AND (B OR A); +A AND D AND E; + +(A OR B) AND A; +A; + +C AND D AND (A OR B) AND E AND F AND A; +A AND C AND D AND E AND F; + +A OR (A AND B); +A; + +(A AND B) OR A; +A; + +A AND (NOT A OR B); +A AND B; + +(NOT A OR B) AND A; +A AND B; + +A OR (NOT A AND B); +A OR B; + +(A OR C) AND ((A OR C) OR B); +A OR C; + +(A OR C) AND (A OR B OR C); +A OR C; + +-------------------------------------- +-- Elimination +-------------------------------------- +(A AND B) OR (A AND NOT B); +A; + +(A AND B) OR 
(NOT A AND B); +B; + +(A AND NOT B) OR (A AND B); +A; + +(NOT A AND B) OR (A AND B); +B; + +(A OR B) AND (A OR NOT B); +A; + +(A OR B) AND (NOT A OR B); +B; + +(A OR NOT B) AND (A OR B); +A; + +(NOT A OR B) AND (A OR B); +B; + +(NOT A OR NOT B) AND (NOT A OR B); +NOT A; + +(NOT A OR NOT B) AND (NOT A OR NOT NOT B); +NOT A; + +E OR (A AND B) OR C OR D OR (A AND NOT B); +A OR C OR D OR E; + +-------------------------------------- +-- Associativity +-------------------------------------- +(A AND B) AND C; +A AND B AND C; + +A AND (B AND C); +A AND B AND C; + +(A OR B) OR C; +A OR B OR C; + +A OR (B OR C); +A OR B OR C; + +((A AND B) AND C) AND D; +A AND B AND C AND D; + +(((((A) AND B)) AND C)) AND D; +A AND B AND C AND D; + +-------------------------------------- +-- Comparison and Pruning +-------------------------------------- +A AND D AND B AND E AND F AND G AND E AND A; +A AND B AND D AND E AND F AND G; + +A AND NOT B AND C AND B; +FALSE; + +(a AND b AND c AND d) AND (d AND c AND b AND a); +a AND b AND c AND d; + +(c AND (a AND b)) AND ((b AND a) AND c); +a AND b AND c; + +(A AND B AND C) OR (C AND B AND A); +A AND B AND C; + +-------------------------------------- +-- Where removal +-------------------------------------- +SELECT x WHERE TRUE; +SELECT x; + +-------------------------------------- +-- Parenthesis removal +-------------------------------------- +(TRUE); +TRUE; + +(FALSE); +FALSE; + +(FALSE OR TRUE); +TRUE; + +TRUE OR (((FALSE) OR (TRUE)) OR FALSE); +TRUE; + +(NOT FALSE) AND (NOT TRUE); +FALSE; + +((NOT FALSE) AND (x = x)) AND (TRUE OR 1 <> 3); +TRUE; + +((NOT FALSE) AND (x = x)) AND (FALSE OR 1 <> 2); +TRUE; + +(('a' = 'a') AND TRUE and NOT FALSE); +TRUE; + +-------------------------------------- +-- Literals +-------------------------------------- +1 + 1; +2; + +0.06 + 0.01; +0.07; + +0.06 + 1; +1.06; + +1.2E+1 + 15E-3; +12.015; + +1.2E1 + 15E-3; +12.015; + +1 - 2; +-1; + +-1 + 3; +2; + +-(-1); +1; + +0.06 - 0.01; +0.05; + +3 * 4; +12; + +3.0 * 9; +27.0; + +0.03 * 0.73; +0.0219; + +1 / 3; +0; + +20.0 / 6; +3.333333333333333333333333333; + +10 / 5; +2; + +(1.0 * 3) * 4 - 2 * (5 / 2); +8.0; + +6 - 2 + 4 * 2 + a; +12 + a; + +a + 1 + 1 + 2; +a + 4; + +a + (1 + 1) + (10); +a + 12; + +5 + 4 * 3; +17; + +1 < 2; +TRUE; + +2 <= 2; +TRUE; + +2 >= 2; +TRUE; + +2 > 1; +TRUE; + +2 > 2.5; +FALSE; + +3 > 2.5; +TRUE; + +1 > NULL; +NULL; + +1 <= NULL; +NULL; + +1 IS NULL; +FALSE; + +NULL IS NULL; +TRUE; + +NULL IS NOT NULL; +FALSE; + +1 IS NOT NULL; +TRUE; + +date '1998-12-01' - interval '90' day; +CAST('1998-09-02' AS DATE); + +date '1998-12-01' + interval '1' week; +CAST('1998-12-08' AS DATE); + +interval '1' year + date '1998-01-01'; +CAST('1999-01-01' AS DATE); + +interval '1' year + date '1998-01-01' + 3 * 7 * 4; +CAST('1999-01-01' AS DATE) + 84; + +date '1998-12-01' - interval '90' foo; +CAST('1998-12-01' AS DATE) - INTERVAL '90' foo; + +date '1998-12-01' + interval '90' foo; +CAST('1998-12-01' AS DATE) + INTERVAL '90' foo; diff --git a/tests/fixtures/optimizer/tpc-h/customer.csv.gz b/tests/fixtures/optimizer/tpc-h/customer.csv.gz Binary files differnew file mode 100644 index 0000000..e0d149c --- /dev/null +++ b/tests/fixtures/optimizer/tpc-h/customer.csv.gz diff --git a/tests/fixtures/optimizer/tpc-h/lineitem.csv.gz b/tests/fixtures/optimizer/tpc-h/lineitem.csv.gz Binary files differnew file mode 100644 index 0000000..08e40d8 --- /dev/null +++ b/tests/fixtures/optimizer/tpc-h/lineitem.csv.gz diff --git a/tests/fixtures/optimizer/tpc-h/nation.csv.gz 
b/tests/fixtures/optimizer/tpc-h/nation.csv.gz Binary files differnew file mode 100644 index 0000000..d5bf6e3 --- /dev/null +++ b/tests/fixtures/optimizer/tpc-h/nation.csv.gz diff --git a/tests/fixtures/optimizer/tpc-h/orders.csv.gz b/tests/fixtures/optimizer/tpc-h/orders.csv.gz Binary files differnew file mode 100644 index 0000000..9b572bc --- /dev/null +++ b/tests/fixtures/optimizer/tpc-h/orders.csv.gz diff --git a/tests/fixtures/optimizer/tpc-h/part.csv.gz b/tests/fixtures/optimizer/tpc-h/part.csv.gz Binary files differnew file mode 100644 index 0000000..2dfdaa5 --- /dev/null +++ b/tests/fixtures/optimizer/tpc-h/part.csv.gz diff --git a/tests/fixtures/optimizer/tpc-h/partsupp.csv.gz b/tests/fixtures/optimizer/tpc-h/partsupp.csv.gz Binary files differnew file mode 100644 index 0000000..de9a2ce --- /dev/null +++ b/tests/fixtures/optimizer/tpc-h/partsupp.csv.gz diff --git a/tests/fixtures/optimizer/tpc-h/region.csv.gz b/tests/fixtures/optimizer/tpc-h/region.csv.gz Binary files differnew file mode 100644 index 0000000..3dbd31a --- /dev/null +++ b/tests/fixtures/optimizer/tpc-h/region.csv.gz diff --git a/tests/fixtures/optimizer/tpc-h/supplier.csv.gz b/tests/fixtures/optimizer/tpc-h/supplier.csv.gz Binary files differnew file mode 100644 index 0000000..8dad82a --- /dev/null +++ b/tests/fixtures/optimizer/tpc-h/supplier.csv.gz diff --git a/tests/fixtures/optimizer/tpc-h/tpc-h.sql b/tests/fixtures/optimizer/tpc-h/tpc-h.sql new file mode 100644 index 0000000..482e231 --- /dev/null +++ b/tests/fixtures/optimizer/tpc-h/tpc-h.sql @@ -0,0 +1,1810 @@ +-------------------------------------- +-- TPC-H 1 +-------------------------------------- +select + l_returnflag, + l_linestatus, + sum(l_quantity) as sum_qty, + sum(l_extendedprice) as sum_base_price, + sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, + sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, + avg(l_quantity) as avg_qty, + avg(l_extendedprice) as avg_price, + avg(l_discount) as avg_disc, + count(*) as count_order +from + lineitem +where + CAST(l_shipdate AS DATE) <= date '1998-12-01' - interval '90' day +group by + l_returnflag, + l_linestatus +order by + l_returnflag, + l_linestatus; +SELECT + "lineitem"."l_returnflag" AS "l_returnflag", + "lineitem"."l_linestatus" AS "l_linestatus", + SUM("lineitem"."l_quantity") AS "sum_qty", + SUM("lineitem"."l_extendedprice") AS "sum_base_price", + SUM("lineitem"."l_extendedprice" * ( + 1 - "lineitem"."l_discount" + )) AS "sum_disc_price", + SUM("lineitem"."l_extendedprice" * ( + 1 - "lineitem"."l_discount" + ) * ( + 1 + "lineitem"."l_tax" + )) AS "sum_charge", + AVG("lineitem"."l_quantity") AS "avg_qty", + AVG("lineitem"."l_extendedprice") AS "avg_price", + AVG("lineitem"."l_discount") AS "avg_disc", + COUNT(*) AS "count_order" +FROM "lineitem" AS "lineitem" +WHERE + CAST("lineitem"."l_shipdate" AS DATE) <= CAST('1998-09-02' AS DATE) +GROUP BY + "lineitem"."l_returnflag", + "lineitem"."l_linestatus" +ORDER BY + "l_returnflag", + "l_linestatus"; + +-------------------------------------- +-- TPC-H 2 +-------------------------------------- +select + s_acctbal, + s_name, + n_name, + p_partkey, + p_mfgr, + s_address, + s_phone, + s_comment +from + part, + supplier, + partsupp, + nation, + region +where + p_partkey = ps_partkey + and s_suppkey = ps_suppkey + and p_size = 15 + and p_type like '%BRASS' + and s_nationkey = n_nationkey + and n_regionkey = r_regionkey + and r_name = 'EUROPE' + and ps_supplycost = ( + select + min(ps_supplycost) + from + partsupp, + supplier, + 
nation, + region + where + p_partkey = ps_partkey + and s_suppkey = ps_suppkey + and s_nationkey = n_nationkey + and n_regionkey = r_regionkey + and r_name = 'EUROPE' + ) +order by + s_acctbal desc, + n_name, + s_name, + p_partkey +limit + 100; +WITH "_e_0" AS ( + SELECT + "partsupp"."ps_partkey" AS "ps_partkey", + "partsupp"."ps_suppkey" AS "ps_suppkey", + "partsupp"."ps_supplycost" AS "ps_supplycost" + FROM "partsupp" AS "partsupp" +), "_e_1" AS ( + SELECT + "region"."r_regionkey" AS "r_regionkey", + "region"."r_name" AS "r_name" + FROM "region" AS "region" + WHERE + "region"."r_name" = 'EUROPE' +) +SELECT + "supplier"."s_acctbal" AS "s_acctbal", + "supplier"."s_name" AS "s_name", + "nation"."n_name" AS "n_name", + "part"."p_partkey" AS "p_partkey", + "part"."p_mfgr" AS "p_mfgr", + "supplier"."s_address" AS "s_address", + "supplier"."s_phone" AS "s_phone", + "supplier"."s_comment" AS "s_comment" +FROM ( + SELECT + "part"."p_partkey" AS "p_partkey", + "part"."p_mfgr" AS "p_mfgr", + "part"."p_type" AS "p_type", + "part"."p_size" AS "p_size" + FROM "part" AS "part" + WHERE + "part"."p_size" = 15 + AND "part"."p_type" LIKE '%BRASS' +) AS "part" +LEFT JOIN ( + SELECT + MIN("partsupp"."ps_supplycost") AS "_col_0", + "partsupp"."ps_partkey" AS "_u_1" + FROM "_e_0" AS "partsupp" + CROSS JOIN "_e_1" AS "region" + JOIN ( + SELECT + "nation"."n_nationkey" AS "n_nationkey", + "nation"."n_regionkey" AS "n_regionkey" + FROM "nation" AS "nation" + ) AS "nation" + ON "nation"."n_regionkey" = "region"."r_regionkey" + JOIN ( + SELECT + "supplier"."s_suppkey" AS "s_suppkey", + "supplier"."s_nationkey" AS "s_nationkey" + FROM "supplier" AS "supplier" + ) AS "supplier" + ON "supplier"."s_nationkey" = "nation"."n_nationkey" + AND "supplier"."s_suppkey" = "partsupp"."ps_suppkey" + GROUP BY + "partsupp"."ps_partkey" +) AS "_u_0" + ON "part"."p_partkey" = "_u_0"."_u_1" +CROSS JOIN "_e_1" AS "region" +JOIN ( + SELECT + "nation"."n_nationkey" AS "n_nationkey", + "nation"."n_name" AS "n_name", + "nation"."n_regionkey" AS "n_regionkey" + FROM "nation" AS "nation" +) AS "nation" + ON "nation"."n_regionkey" = "region"."r_regionkey" +JOIN "_e_0" AS "partsupp" + ON "part"."p_partkey" = "partsupp"."ps_partkey" +JOIN ( + SELECT + "supplier"."s_suppkey" AS "s_suppkey", + "supplier"."s_name" AS "s_name", + "supplier"."s_address" AS "s_address", + "supplier"."s_nationkey" AS "s_nationkey", + "supplier"."s_phone" AS "s_phone", + "supplier"."s_acctbal" AS "s_acctbal", + "supplier"."s_comment" AS "s_comment" + FROM "supplier" AS "supplier" +) AS "supplier" + ON "supplier"."s_nationkey" = "nation"."n_nationkey" + AND "supplier"."s_suppkey" = "partsupp"."ps_suppkey" +WHERE + "partsupp"."ps_supplycost" = "_u_0"."_col_0" + AND NOT "_u_0"."_u_1" IS NULL +ORDER BY + "s_acctbal" DESC, + "n_name", + "s_name", + "p_partkey" +LIMIT 100; + +-------------------------------------- +-- TPC-H 3 +-------------------------------------- +select + l_orderkey, + sum(l_extendedprice * (1 - l_discount)) as revenue, + CAST(o_orderdate AS STRING) AS o_orderdate, + o_shippriority +from + customer, + orders, + lineitem +where + c_mktsegment = 'BUILDING' + and c_custkey = o_custkey + and l_orderkey = o_orderkey + and o_orderdate < '1995-03-15' + and l_shipdate > '1995-03-15' +group by + l_orderkey, + o_orderdate, + o_shippriority +order by + revenue desc, + o_orderdate +limit + 10; +SELECT + "lineitem"."l_orderkey" AS "l_orderkey", + SUM("lineitem"."l_extendedprice" * ( + 1 - "lineitem"."l_discount" + )) AS "revenue", + CAST("orders"."o_orderdate" AS 
TEXT) AS "o_orderdate", + "orders"."o_shippriority" AS "o_shippriority" +FROM ( + SELECT + "customer"."c_custkey" AS "c_custkey", + "customer"."c_mktsegment" AS "c_mktsegment" + FROM "customer" AS "customer" + WHERE + "customer"."c_mktsegment" = 'BUILDING' +) AS "customer" +JOIN ( + SELECT + "orders"."o_orderkey" AS "o_orderkey", + "orders"."o_custkey" AS "o_custkey", + "orders"."o_orderdate" AS "o_orderdate", + "orders"."o_shippriority" AS "o_shippriority" + FROM "orders" AS "orders" + WHERE + "orders"."o_orderdate" < '1995-03-15' +) AS "orders" + ON "customer"."c_custkey" = "orders"."o_custkey" +JOIN ( + SELECT + "lineitem"."l_orderkey" AS "l_orderkey", + "lineitem"."l_extendedprice" AS "l_extendedprice", + "lineitem"."l_discount" AS "l_discount", + "lineitem"."l_shipdate" AS "l_shipdate" + FROM "lineitem" AS "lineitem" + WHERE + "lineitem"."l_shipdate" > '1995-03-15' +) AS "lineitem" + ON "lineitem"."l_orderkey" = "orders"."o_orderkey" +GROUP BY + "lineitem"."l_orderkey", + "orders"."o_orderdate", + "orders"."o_shippriority" +ORDER BY + "revenue" DESC, + "o_orderdate" +LIMIT 10; + +-------------------------------------- +-- TPC-H 4 +-------------------------------------- +select + o_orderpriority, + count(*) as order_count +from + orders +where + o_orderdate >= date '1993-07-01' + and o_orderdate < date '1993-07-01' + interval '3' month + and exists ( + select + * + from + lineitem + where + l_orderkey = o_orderkey + and l_commitdate < l_receiptdate + ) +group by + o_orderpriority +order by + o_orderpriority; +SELECT + "orders"."o_orderpriority" AS "o_orderpriority", + COUNT(*) AS "order_count" +FROM "orders" AS "orders" +LEFT JOIN ( + SELECT + "lineitem"."l_orderkey" AS "l_orderkey" + FROM "lineitem" AS "lineitem" + WHERE + "lineitem"."l_commitdate" < "lineitem"."l_receiptdate" + GROUP BY + "lineitem"."l_orderkey" +) AS "_u_0" + ON "_u_0"."l_orderkey" = "orders"."o_orderkey" +WHERE + "orders"."o_orderdate" < CAST('1993-10-01' AS DATE) + AND "orders"."o_orderdate" >= CAST('1993-07-01' AS DATE) + AND NOT "_u_0"."l_orderkey" IS NULL +GROUP BY + "orders"."o_orderpriority" +ORDER BY + "o_orderpriority"; + +-------------------------------------- +-- TPC-H 5 +-------------------------------------- +select + n_name, + sum(l_extendedprice * (1 - l_discount)) as revenue +from + customer, + orders, + lineitem, + supplier, + nation, + region +where + c_custkey = o_custkey + and l_orderkey = o_orderkey + and l_suppkey = s_suppkey + and c_nationkey = s_nationkey + and s_nationkey = n_nationkey + and n_regionkey = r_regionkey + and r_name = 'ASIA' + and o_orderdate >= date '1994-01-01' + and o_orderdate < date '1994-01-01' + interval '1' year +group by + n_name +order by + revenue desc; +SELECT + "nation"."n_name" AS "n_name", + SUM("lineitem"."l_extendedprice" * ( + 1 - "lineitem"."l_discount" + )) AS "revenue" +FROM ( + SELECT + "customer"."c_custkey" AS "c_custkey", + "customer"."c_nationkey" AS "c_nationkey" + FROM "customer" AS "customer" +) AS "customer" +JOIN ( + SELECT + "orders"."o_orderkey" AS "o_orderkey", + "orders"."o_custkey" AS "o_custkey", + "orders"."o_orderdate" AS "o_orderdate" + FROM "orders" AS "orders" + WHERE + "orders"."o_orderdate" < CAST('1995-01-01' AS DATE) + AND "orders"."o_orderdate" >= CAST('1994-01-01' AS DATE) +) AS "orders" + ON "customer"."c_custkey" = "orders"."o_custkey" +CROSS JOIN ( + SELECT + "region"."r_regionkey" AS "r_regionkey", + "region"."r_name" AS "r_name" + FROM "region" AS "region" + WHERE + "region"."r_name" = 'ASIA' +) AS "region" +JOIN ( + SELECT + 
"nation"."n_nationkey" AS "n_nationkey", + "nation"."n_name" AS "n_name", + "nation"."n_regionkey" AS "n_regionkey" + FROM "nation" AS "nation" +) AS "nation" + ON "nation"."n_regionkey" = "region"."r_regionkey" +JOIN ( + SELECT + "supplier"."s_suppkey" AS "s_suppkey", + "supplier"."s_nationkey" AS "s_nationkey" + FROM "supplier" AS "supplier" +) AS "supplier" + ON "customer"."c_nationkey" = "supplier"."s_nationkey" + AND "supplier"."s_nationkey" = "nation"."n_nationkey" +JOIN ( + SELECT + "lineitem"."l_orderkey" AS "l_orderkey", + "lineitem"."l_suppkey" AS "l_suppkey", + "lineitem"."l_extendedprice" AS "l_extendedprice", + "lineitem"."l_discount" AS "l_discount" + FROM "lineitem" AS "lineitem" +) AS "lineitem" + ON "lineitem"."l_orderkey" = "orders"."o_orderkey" + AND "lineitem"."l_suppkey" = "supplier"."s_suppkey" +GROUP BY + "nation"."n_name" +ORDER BY + "revenue" DESC; + +-------------------------------------- +-- TPC-H 6 +-------------------------------------- +select + sum(l_extendedprice * l_discount) as revenue +from + lineitem +where + l_shipdate >= date '1994-01-01' + and l_shipdate < date '1994-01-01' + interval '1' year + and l_discount between 0.06 - 0.01 and 0.06 + 0.01 + and l_quantity < 24; +SELECT + SUM("lineitem"."l_extendedprice" * "lineitem"."l_discount") AS "revenue" +FROM "lineitem" AS "lineitem" +WHERE + "lineitem"."l_discount" BETWEEN 0.05 AND 0.07 + AND "lineitem"."l_quantity" < 24 + AND "lineitem"."l_shipdate" < CAST('1995-01-01' AS DATE) + AND "lineitem"."l_shipdate" >= CAST('1994-01-01' AS DATE); + +-------------------------------------- +-- TPC-H 7 +-------------------------------------- +select + supp_nation, + cust_nation, + l_year, + sum(volume) as revenue +from + ( + select + n1.n_name as supp_nation, + n2.n_name as cust_nation, + extract(year from l_shipdate) as l_year, + l_extendedprice * (1 - l_discount) as volume + from + supplier, + lineitem, + orders, + customer, + nation n1, + nation n2 + where + s_suppkey = l_suppkey + and o_orderkey = l_orderkey + and c_custkey = o_custkey + and s_nationkey = n1.n_nationkey + and c_nationkey = n2.n_nationkey + and ( + (n1.n_name = 'FRANCE' and n2.n_name = 'GERMANY') + or (n1.n_name = 'GERMANY' and n2.n_name = 'FRANCE') + ) + and l_shipdate between date '1995-01-01' and date '1996-12-31' + ) as shipping +group by + supp_nation, + cust_nation, + l_year +order by + supp_nation, + cust_nation, + l_year; +WITH "_e_0" AS ( + SELECT + "nation"."n_nationkey" AS "n_nationkey", + "nation"."n_name" AS "n_name" + FROM "nation" AS "nation" + WHERE + "nation"."n_name" = 'FRANCE' + OR "nation"."n_name" = 'GERMANY' +) +SELECT + "shipping"."supp_nation" AS "supp_nation", + "shipping"."cust_nation" AS "cust_nation", + "shipping"."l_year" AS "l_year", + SUM("shipping"."volume") AS "revenue" +FROM ( + SELECT + "n1"."n_name" AS "supp_nation", + "n2"."n_name" AS "cust_nation", + EXTRACT(year FROM "lineitem"."l_shipdate") AS "l_year", + "lineitem"."l_extendedprice" * ( + 1 - "lineitem"."l_discount" + ) AS "volume" + FROM ( + SELECT + "supplier"."s_suppkey" AS "s_suppkey", + "supplier"."s_nationkey" AS "s_nationkey" + FROM "supplier" AS "supplier" + ) AS "supplier" + JOIN ( + SELECT + "lineitem"."l_orderkey" AS "l_orderkey", + "lineitem"."l_suppkey" AS "l_suppkey", + "lineitem"."l_extendedprice" AS "l_extendedprice", + "lineitem"."l_discount" AS "l_discount", + "lineitem"."l_shipdate" AS "l_shipdate" + FROM "lineitem" AS "lineitem" + WHERE + "lineitem"."l_shipdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE) + ) AS 
"lineitem" + ON "supplier"."s_suppkey" = "lineitem"."l_suppkey" + JOIN ( + SELECT + "orders"."o_orderkey" AS "o_orderkey", + "orders"."o_custkey" AS "o_custkey" + FROM "orders" AS "orders" + ) AS "orders" + ON "orders"."o_orderkey" = "lineitem"."l_orderkey" + JOIN ( + SELECT + "customer"."c_custkey" AS "c_custkey", + "customer"."c_nationkey" AS "c_nationkey" + FROM "customer" AS "customer" + ) AS "customer" + ON "customer"."c_custkey" = "orders"."o_custkey" + JOIN "_e_0" AS "n1" + ON "supplier"."s_nationkey" = "n1"."n_nationkey" + JOIN "_e_0" AS "n2" + ON "customer"."c_nationkey" = "n2"."n_nationkey" + AND ( + "n1"."n_name" = 'FRANCE' + OR "n2"."n_name" = 'FRANCE' + ) + AND ( + "n1"."n_name" = 'GERMANY' + OR "n2"."n_name" = 'GERMANY' + ) +) AS "shipping" +GROUP BY + "shipping"."supp_nation", + "shipping"."cust_nation", + "shipping"."l_year" +ORDER BY + "supp_nation", + "cust_nation", + "l_year"; + +-------------------------------------- +-- TPC-H 8 +-------------------------------------- +select + o_year, + sum(case + when nation = 'BRAZIL' then volume + else 0 + end) / sum(volume) as mkt_share +from + ( + select + extract(year from o_orderdate) as o_year, + l_extendedprice * (1 - l_discount) as volume, + n2.n_name as nation + from + part, + supplier, + lineitem, + orders, + customer, + nation n1, + nation n2, + region + where + p_partkey = l_partkey + and s_suppkey = l_suppkey + and l_orderkey = o_orderkey + and o_custkey = c_custkey + and c_nationkey = n1.n_nationkey + and n1.n_regionkey = r_regionkey + and r_name = 'AMERICA' + and s_nationkey = n2.n_nationkey + and o_orderdate between date '1995-01-01' and date '1996-12-31' + and p_type = 'ECONOMY ANODIZED STEEL' + ) as all_nations +group by + o_year +order by + o_year; +SELECT + "all_nations"."o_year" AS "o_year", + SUM(CASE + WHEN "all_nations"."nation" = 'BRAZIL' + THEN "all_nations"."volume" + ELSE 0 + END) / SUM("all_nations"."volume") AS "mkt_share" +FROM ( + SELECT + EXTRACT(year FROM "orders"."o_orderdate") AS "o_year", + "lineitem"."l_extendedprice" * ( + 1 - "lineitem"."l_discount" + ) AS "volume", + "n2"."n_name" AS "nation" + FROM ( + SELECT + "part"."p_partkey" AS "p_partkey", + "part"."p_type" AS "p_type" + FROM "part" AS "part" + WHERE + "part"."p_type" = 'ECONOMY ANODIZED STEEL' + ) AS "part" + CROSS JOIN ( + SELECT + "region"."r_regionkey" AS "r_regionkey", + "region"."r_name" AS "r_name" + FROM "region" AS "region" + WHERE + "region"."r_name" = 'AMERICA' + ) AS "region" + JOIN ( + SELECT + "nation"."n_nationkey" AS "n_nationkey", + "nation"."n_regionkey" AS "n_regionkey" + FROM "nation" AS "nation" + ) AS "n1" + ON "n1"."n_regionkey" = "region"."r_regionkey" + JOIN ( + SELECT + "customer"."c_custkey" AS "c_custkey", + "customer"."c_nationkey" AS "c_nationkey" + FROM "customer" AS "customer" + ) AS "customer" + ON "customer"."c_nationkey" = "n1"."n_nationkey" + JOIN ( + SELECT + "orders"."o_orderkey" AS "o_orderkey", + "orders"."o_custkey" AS "o_custkey", + "orders"."o_orderdate" AS "o_orderdate" + FROM "orders" AS "orders" + WHERE + "orders"."o_orderdate" BETWEEN CAST('1995-01-01' AS DATE) AND CAST('1996-12-31' AS DATE) + ) AS "orders" + ON "orders"."o_custkey" = "customer"."c_custkey" + JOIN ( + SELECT + "lineitem"."l_orderkey" AS "l_orderkey", + "lineitem"."l_partkey" AS "l_partkey", + "lineitem"."l_suppkey" AS "l_suppkey", + "lineitem"."l_extendedprice" AS "l_extendedprice", + "lineitem"."l_discount" AS "l_discount" + FROM "lineitem" AS "lineitem" + ) AS "lineitem" + ON "lineitem"."l_orderkey" = 
"orders"."o_orderkey" + AND "part"."p_partkey" = "lineitem"."l_partkey" + JOIN ( + SELECT + "supplier"."s_suppkey" AS "s_suppkey", + "supplier"."s_nationkey" AS "s_nationkey" + FROM "supplier" AS "supplier" + ) AS "supplier" + ON "supplier"."s_suppkey" = "lineitem"."l_suppkey" + JOIN ( + SELECT + "nation"."n_nationkey" AS "n_nationkey", + "nation"."n_name" AS "n_name" + FROM "nation" AS "nation" + ) AS "n2" + ON "supplier"."s_nationkey" = "n2"."n_nationkey" +) AS "all_nations" +GROUP BY + "all_nations"."o_year" +ORDER BY + "o_year"; + +-------------------------------------- +-- TPC-H 9 +-------------------------------------- +select + nation, + o_year, + sum(amount) as sum_profit +from + ( + select + n_name as nation, + extract(year from o_orderdate) as o_year, + l_extendedprice * (1 - l_discount) - ps_supplycost * l_quantity as amount + from + part, + supplier, + lineitem, + partsupp, + orders, + nation + where + s_suppkey = l_suppkey + and ps_suppkey = l_suppkey + and ps_partkey = l_partkey + and p_partkey = l_partkey + and o_orderkey = l_orderkey + and s_nationkey = n_nationkey + and p_name like '%green%' + ) as profit +group by + nation, + o_year +order by + nation, + o_year desc; +SELECT + "profit"."nation" AS "nation", + "profit"."o_year" AS "o_year", + SUM("profit"."amount") AS "sum_profit" +FROM ( + SELECT + "nation"."n_name" AS "nation", + EXTRACT(year FROM "orders"."o_orderdate") AS "o_year", + "lineitem"."l_extendedprice" * ( + 1 - "lineitem"."l_discount" + ) - "partsupp"."ps_supplycost" * "lineitem"."l_quantity" AS "amount" + FROM ( + SELECT + "part"."p_partkey" AS "p_partkey", + "part"."p_name" AS "p_name" + FROM "part" AS "part" + WHERE + "part"."p_name" LIKE '%green%' + ) AS "part" + JOIN ( + SELECT + "lineitem"."l_orderkey" AS "l_orderkey", + "lineitem"."l_partkey" AS "l_partkey", + "lineitem"."l_suppkey" AS "l_suppkey", + "lineitem"."l_quantity" AS "l_quantity", + "lineitem"."l_extendedprice" AS "l_extendedprice", + "lineitem"."l_discount" AS "l_discount" + FROM "lineitem" AS "lineitem" + ) AS "lineitem" + ON "part"."p_partkey" = "lineitem"."l_partkey" + JOIN ( + SELECT + "supplier"."s_suppkey" AS "s_suppkey", + "supplier"."s_nationkey" AS "s_nationkey" + FROM "supplier" AS "supplier" + ) AS "supplier" + ON "supplier"."s_suppkey" = "lineitem"."l_suppkey" + JOIN ( + SELECT + "partsupp"."ps_partkey" AS "ps_partkey", + "partsupp"."ps_suppkey" AS "ps_suppkey", + "partsupp"."ps_supplycost" AS "ps_supplycost" + FROM "partsupp" AS "partsupp" + ) AS "partsupp" + ON "partsupp"."ps_partkey" = "lineitem"."l_partkey" + AND "partsupp"."ps_suppkey" = "lineitem"."l_suppkey" + JOIN ( + SELECT + "orders"."o_orderkey" AS "o_orderkey", + "orders"."o_orderdate" AS "o_orderdate" + FROM "orders" AS "orders" + ) AS "orders" + ON "orders"."o_orderkey" = "lineitem"."l_orderkey" + JOIN ( + SELECT + "nation"."n_nationkey" AS "n_nationkey", + "nation"."n_name" AS "n_name" + FROM "nation" AS "nation" + ) AS "nation" + ON "supplier"."s_nationkey" = "nation"."n_nationkey" +) AS "profit" +GROUP BY + "profit"."nation", + "profit"."o_year" +ORDER BY + "nation", + "o_year" DESC; + +-------------------------------------- +-- TPC-H 10 +-------------------------------------- +select + c_custkey, + c_name, + sum(l_extendedprice * (1 - l_discount)) as revenue, + c_acctbal, + n_name, + c_address, + c_phone, + c_comment +from + customer, + orders, + lineitem, + nation +where + c_custkey = o_custkey + and l_orderkey = o_orderkey + and o_orderdate >= date '1993-10-01' + and o_orderdate < date '1993-10-01' + interval 
'3' month + and l_returnflag = 'R' + and c_nationkey = n_nationkey +group by + c_custkey, + c_name, + c_acctbal, + c_phone, + n_name, + c_address, + c_comment +order by + revenue desc +limit + 20; +SELECT + "customer"."c_custkey" AS "c_custkey", + "customer"."c_name" AS "c_name", + SUM("lineitem"."l_extendedprice" * ( + 1 - "lineitem"."l_discount" + )) AS "revenue", + "customer"."c_acctbal" AS "c_acctbal", + "nation"."n_name" AS "n_name", + "customer"."c_address" AS "c_address", + "customer"."c_phone" AS "c_phone", + "customer"."c_comment" AS "c_comment" +FROM ( + SELECT + "customer"."c_custkey" AS "c_custkey", + "customer"."c_name" AS "c_name", + "customer"."c_address" AS "c_address", + "customer"."c_nationkey" AS "c_nationkey", + "customer"."c_phone" AS "c_phone", + "customer"."c_acctbal" AS "c_acctbal", + "customer"."c_comment" AS "c_comment" + FROM "customer" AS "customer" +) AS "customer" +JOIN ( + SELECT + "orders"."o_orderkey" AS "o_orderkey", + "orders"."o_custkey" AS "o_custkey", + "orders"."o_orderdate" AS "o_orderdate" + FROM "orders" AS "orders" + WHERE + "orders"."o_orderdate" < CAST('1994-01-01' AS DATE) + AND "orders"."o_orderdate" >= CAST('1993-10-01' AS DATE) +) AS "orders" + ON "customer"."c_custkey" = "orders"."o_custkey" +JOIN ( + SELECT + "lineitem"."l_orderkey" AS "l_orderkey", + "lineitem"."l_extendedprice" AS "l_extendedprice", + "lineitem"."l_discount" AS "l_discount", + "lineitem"."l_returnflag" AS "l_returnflag" + FROM "lineitem" AS "lineitem" + WHERE + "lineitem"."l_returnflag" = 'R' +) AS "lineitem" + ON "lineitem"."l_orderkey" = "orders"."o_orderkey" +JOIN ( + SELECT + "nation"."n_nationkey" AS "n_nationkey", + "nation"."n_name" AS "n_name" + FROM "nation" AS "nation" +) AS "nation" + ON "customer"."c_nationkey" = "nation"."n_nationkey" +GROUP BY + "customer"."c_custkey", + "customer"."c_name", + "customer"."c_acctbal", + "customer"."c_phone", + "nation"."n_name", + "customer"."c_address", + "customer"."c_comment" +ORDER BY + "revenue" DESC +LIMIT 20; + +-------------------------------------- +-- TPC-H 11 +-------------------------------------- +select + ps_partkey, + sum(ps_supplycost * ps_availqty) as value +from + partsupp, + supplier, + nation +where + ps_suppkey = s_suppkey + and s_nationkey = n_nationkey + and n_name = 'GERMANY' +group by + ps_partkey having + sum(ps_supplycost * ps_availqty) > ( + select + sum(ps_supplycost * ps_availqty) * 0.0001 + from + partsupp, + supplier, + nation + where + ps_suppkey = s_suppkey + and s_nationkey = n_nationkey + and n_name = 'GERMANY' + ) +order by + value desc; +WITH "_e_0" AS ( + SELECT + "supplier"."s_suppkey" AS "s_suppkey", + "supplier"."s_nationkey" AS "s_nationkey" + FROM "supplier" AS "supplier" +), "_e_1" AS ( + SELECT + "nation"."n_nationkey" AS "n_nationkey", + "nation"."n_name" AS "n_name" + FROM "nation" AS "nation" + WHERE + "nation"."n_name" = 'GERMANY' +) +SELECT + "partsupp"."ps_partkey" AS "ps_partkey", + SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") AS "value" +FROM ( + SELECT + "partsupp"."ps_partkey" AS "ps_partkey", + "partsupp"."ps_suppkey" AS "ps_suppkey", + "partsupp"."ps_availqty" AS "ps_availqty", + "partsupp"."ps_supplycost" AS "ps_supplycost" + FROM "partsupp" AS "partsupp" +) AS "partsupp" +JOIN "_e_0" AS "supplier" + ON "partsupp"."ps_suppkey" = "supplier"."s_suppkey" +JOIN "_e_1" AS "nation" + ON "supplier"."s_nationkey" = "nation"."n_nationkey" +GROUP BY + "partsupp"."ps_partkey" +HAVING + SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") > ( + SELECT + 
SUM("partsupp"."ps_supplycost" * "partsupp"."ps_availqty") * 0.0001 AS "_col_0" + FROM ( + SELECT + "partsupp"."ps_suppkey" AS "ps_suppkey", + "partsupp"."ps_availqty" AS "ps_availqty", + "partsupp"."ps_supplycost" AS "ps_supplycost" + FROM "partsupp" AS "partsupp" + ) AS "partsupp" + JOIN "_e_0" AS "supplier" + ON "partsupp"."ps_suppkey" = "supplier"."s_suppkey" + JOIN "_e_1" AS "nation" + ON "supplier"."s_nationkey" = "nation"."n_nationkey" + ) +ORDER BY + "value" DESC; + +-------------------------------------- +-- TPC-H 12 +-------------------------------------- +select + l_shipmode, + sum(case + when o_orderpriority = '1-URGENT' + or o_orderpriority = '2-HIGH' + then 1 + else 0 + end) as high_line_count, + sum(case + when o_orderpriority <> '1-URGENT' + and o_orderpriority <> '2-HIGH' + then 1 + else 0 + end) as low_line_count +from + orders, + lineitem +where + o_orderkey = l_orderkey + and l_shipmode in ('MAIL', 'SHIP') + and l_commitdate < l_receiptdate + and l_shipdate < l_commitdate + and l_receiptdate >= date '1994-01-01' + and l_receiptdate < date '1994-01-01' + interval '1' year +group by + l_shipmode +order by + l_shipmode; +SELECT + "lineitem"."l_shipmode" AS "l_shipmode", + SUM(CASE + WHEN "orders"."o_orderpriority" = '1-URGENT' + OR "orders"."o_orderpriority" = '2-HIGH' + THEN 1 + ELSE 0 + END) AS "high_line_count", + SUM(CASE + WHEN "orders"."o_orderpriority" <> '1-URGENT' + AND "orders"."o_orderpriority" <> '2-HIGH' + THEN 1 + ELSE 0 + END) AS "low_line_count" +FROM ( + SELECT + "orders"."o_orderkey" AS "o_orderkey", + "orders"."o_orderpriority" AS "o_orderpriority" + FROM "orders" AS "orders" +) AS "orders" +JOIN ( + SELECT + "lineitem"."l_orderkey" AS "l_orderkey", + "lineitem"."l_shipdate" AS "l_shipdate", + "lineitem"."l_commitdate" AS "l_commitdate", + "lineitem"."l_receiptdate" AS "l_receiptdate", + "lineitem"."l_shipmode" AS "l_shipmode" + FROM "lineitem" AS "lineitem" + WHERE + "lineitem"."l_commitdate" < "lineitem"."l_receiptdate" + AND "lineitem"."l_receiptdate" < CAST('1995-01-01' AS DATE) + AND "lineitem"."l_receiptdate" >= CAST('1994-01-01' AS DATE) + AND "lineitem"."l_shipdate" < "lineitem"."l_commitdate" + AND "lineitem"."l_shipmode" IN ('MAIL', 'SHIP') +) AS "lineitem" + ON "orders"."o_orderkey" = "lineitem"."l_orderkey" +GROUP BY + "lineitem"."l_shipmode" +ORDER BY + "l_shipmode"; + +-------------------------------------- +-- TPC-H 13 +-------------------------------------- +select + c_count, + count(*) as custdist +from + ( + select + c_custkey, + count(o_orderkey) + from + customer left outer join orders on + c_custkey = o_custkey + and o_comment not like '%special%requests%' + group by + c_custkey + ) as c_orders (c_custkey, c_count) +group by + c_count +order by + custdist desc, + c_count desc; +SELECT + "c_orders"."c_count" AS "c_count", + COUNT(*) AS "custdist" +FROM ( + SELECT + COUNT("orders"."o_orderkey") AS "c_count" + FROM ( + SELECT + "customer"."c_custkey" AS "c_custkey" + FROM "customer" AS "customer" + ) AS "customer" + LEFT JOIN ( + SELECT + "orders"."o_orderkey" AS "o_orderkey", + "orders"."o_custkey" AS "o_custkey", + "orders"."o_comment" AS "o_comment" + FROM "orders" AS "orders" + WHERE + NOT "orders"."o_comment" LIKE '%special%requests%' + ) AS "orders" + ON "customer"."c_custkey" = "orders"."o_custkey" + GROUP BY + "customer"."c_custkey" +) AS "c_orders" +GROUP BY + "c_orders"."c_count" +ORDER BY + "custdist" DESC, + "c_count" DESC; + +-------------------------------------- +-- TPC-H 14 +-------------------------------------- +select 
+ 100.00 * sum(case + when p_type like 'PROMO%' + then l_extendedprice * (1 - l_discount) + else 0 + end) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue +from + lineitem, + part +where + l_partkey = p_partkey + and l_shipdate >= date '1995-09-01' + and l_shipdate < date '1995-09-01' + interval '1' month; +SELECT + 100.00 * SUM(CASE + WHEN "part"."p_type" LIKE 'PROMO%' + THEN "lineitem"."l_extendedprice" * ( + 1 - "lineitem"."l_discount" + ) + ELSE 0 + END) / SUM("lineitem"."l_extendedprice" * ( + 1 - "lineitem"."l_discount" + )) AS "promo_revenue" +FROM ( + SELECT + "lineitem"."l_partkey" AS "l_partkey", + "lineitem"."l_extendedprice" AS "l_extendedprice", + "lineitem"."l_discount" AS "l_discount", + "lineitem"."l_shipdate" AS "l_shipdate" + FROM "lineitem" AS "lineitem" + WHERE + "lineitem"."l_shipdate" < CAST('1995-10-01' AS DATE) + AND "lineitem"."l_shipdate" >= CAST('1995-09-01' AS DATE) +) AS "lineitem" +JOIN ( + SELECT + "part"."p_partkey" AS "p_partkey", + "part"."p_type" AS "p_type" + FROM "part" AS "part" +) AS "part" + ON "lineitem"."l_partkey" = "part"."p_partkey"; + +-------------------------------------- +-- TPC-H 15 +-------------------------------------- +with revenue (supplier_no, total_revenue) as ( + select + l_suppkey, + sum(l_extendedprice * (1 - l_discount)) + from + lineitem + where + l_shipdate >= date '1996-01-01' + and l_shipdate < date '1996-01-01' + interval '3' month + group by + l_suppkey) +select + s_suppkey, + s_name, + s_address, + s_phone, + total_revenue +from + supplier, + revenue +where + s_suppkey = supplier_no + and total_revenue = ( + select + max(total_revenue) + from + revenue + ) +order by + s_suppkey; +WITH "revenue" AS ( + SELECT + "lineitem"."l_suppkey" AS "supplier_no", + SUM("lineitem"."l_extendedprice" * ( + 1 - "lineitem"."l_discount" + )) AS "total_revenue" + FROM "lineitem" AS "lineitem" + WHERE + "lineitem"."l_shipdate" < CAST('1996-04-01' AS DATE) + AND "lineitem"."l_shipdate" >= CAST('1996-01-01' AS DATE) + GROUP BY + "lineitem"."l_suppkey" +) +SELECT + "supplier"."s_suppkey" AS "s_suppkey", + "supplier"."s_name" AS "s_name", + "supplier"."s_address" AS "s_address", + "supplier"."s_phone" AS "s_phone", + "revenue"."total_revenue" AS "total_revenue" +FROM ( + SELECT + "supplier"."s_suppkey" AS "s_suppkey", + "supplier"."s_name" AS "s_name", + "supplier"."s_address" AS "s_address", + "supplier"."s_phone" AS "s_phone" + FROM "supplier" AS "supplier" +) AS "supplier" +JOIN "revenue" + ON "revenue"."total_revenue" = ( + SELECT + MAX("revenue"."total_revenue") AS "_col_0" + FROM "revenue" + ) + AND "supplier"."s_suppkey" = "revenue"."supplier_no" +ORDER BY + "s_suppkey"; + +-------------------------------------- +-- TPC-H 16 +-------------------------------------- +select + p_brand, + p_type, + p_size, + count(distinct ps_suppkey) as supplier_cnt +from + partsupp, + part +where + p_partkey = ps_partkey + and p_brand <> 'Brand#45' + and p_type not like 'MEDIUM POLISHED%' + and p_size in (49, 14, 23, 45, 19, 3, 36, 9) + and ps_suppkey not in ( + select + s_suppkey + from + supplier + where + s_comment like '%Customer%Complaints%' + ) +group by + p_brand, + p_type, + p_size +order by + supplier_cnt desc, + p_brand, + p_type, + p_size; +SELECT + "part"."p_brand" AS "p_brand", + "part"."p_type" AS "p_type", + "part"."p_size" AS "p_size", + COUNT(DISTINCT "partsupp"."ps_suppkey") AS "supplier_cnt" +FROM ( + SELECT + "partsupp"."ps_partkey" AS "ps_partkey", + "partsupp"."ps_suppkey" AS "ps_suppkey" + FROM "partsupp" AS "partsupp" +) AS 
"partsupp" +LEFT JOIN ( + SELECT + "supplier"."s_suppkey" AS "s_suppkey" + FROM "supplier" AS "supplier" + WHERE + "supplier"."s_comment" LIKE '%Customer%Complaints%' + GROUP BY + "supplier"."s_suppkey" +) AS "_u_0" + ON "partsupp"."ps_suppkey" = "_u_0"."s_suppkey" +JOIN ( + SELECT + "part"."p_partkey" AS "p_partkey", + "part"."p_brand" AS "p_brand", + "part"."p_type" AS "p_type", + "part"."p_size" AS "p_size" + FROM "part" AS "part" + WHERE + "part"."p_brand" <> 'Brand#45' + AND "part"."p_size" IN (49, 14, 23, 45, 19, 3, 36, 9) + AND NOT "part"."p_type" LIKE 'MEDIUM POLISHED%' +) AS "part" + ON "part"."p_partkey" = "partsupp"."ps_partkey" +WHERE + "_u_0"."s_suppkey" IS NULL +GROUP BY + "part"."p_brand", + "part"."p_type", + "part"."p_size" +ORDER BY + "supplier_cnt" DESC, + "p_brand", + "p_type", + "p_size"; + +-------------------------------------- +-- TPC-H 17 +-------------------------------------- +select + sum(l_extendedprice) / 7.0 as avg_yearly +from + lineitem, + part +where + p_partkey = l_partkey + and p_brand = 'Brand#23' + and p_container = 'MED BOX' + and l_quantity < ( + select + 0.2 * avg(l_quantity) + from + lineitem + where + l_partkey = p_partkey + ); +SELECT + SUM("lineitem"."l_extendedprice") / 7.0 AS "avg_yearly" +FROM ( + SELECT + "lineitem"."l_partkey" AS "l_partkey", + "lineitem"."l_quantity" AS "l_quantity", + "lineitem"."l_extendedprice" AS "l_extendedprice" + FROM "lineitem" AS "lineitem" +) AS "lineitem" +JOIN ( + SELECT + "part"."p_partkey" AS "p_partkey", + "part"."p_brand" AS "p_brand", + "part"."p_container" AS "p_container" + FROM "part" AS "part" + WHERE + "part"."p_brand" = 'Brand#23' + AND "part"."p_container" = 'MED BOX' +) AS "part" + ON "part"."p_partkey" = "lineitem"."l_partkey" +LEFT JOIN ( + SELECT + 0.2 * AVG("lineitem"."l_quantity") AS "_col_0", + "lineitem"."l_partkey" AS "_u_1" + FROM "lineitem" AS "lineitem" + GROUP BY + "lineitem"."l_partkey" +) AS "_u_0" + ON "_u_0"."_u_1" = "part"."p_partkey" +WHERE + "lineitem"."l_quantity" < "_u_0"."_col_0" + AND NOT "_u_0"."_u_1" IS NULL; + +-------------------------------------- +-- TPC-H 18 +-------------------------------------- +select + c_name, + c_custkey, + o_orderkey, + o_orderdate, + o_totalprice, + sum(l_quantity) +from + customer, + orders, + lineitem +where + o_orderkey in ( + select + l_orderkey + from + lineitem + group by + l_orderkey having + sum(l_quantity) > 300 + ) + and c_custkey = o_custkey + and o_orderkey = l_orderkey +group by + c_name, + c_custkey, + o_orderkey, + o_orderdate, + o_totalprice +order by + o_totalprice desc, + o_orderdate +limit + 100; +SELECT + "customer"."c_name" AS "c_name", + "customer"."c_custkey" AS "c_custkey", + "orders"."o_orderkey" AS "o_orderkey", + "orders"."o_orderdate" AS "o_orderdate", + "orders"."o_totalprice" AS "o_totalprice", + SUM("lineitem"."l_quantity") AS "_col_5" +FROM ( + SELECT + "customer"."c_custkey" AS "c_custkey", + "customer"."c_name" AS "c_name" + FROM "customer" AS "customer" +) AS "customer" +JOIN ( + SELECT + "orders"."o_orderkey" AS "o_orderkey", + "orders"."o_custkey" AS "o_custkey", + "orders"."o_totalprice" AS "o_totalprice", + "orders"."o_orderdate" AS "o_orderdate" + FROM "orders" AS "orders" +) AS "orders" + ON "customer"."c_custkey" = "orders"."o_custkey" +LEFT JOIN ( + SELECT + "lineitem"."l_orderkey" AS "l_orderkey" + FROM "lineitem" AS "lineitem" + GROUP BY + "lineitem"."l_orderkey", + "lineitem"."l_orderkey" + HAVING + SUM("lineitem"."l_quantity") > 300 +) AS "_u_0" + ON "orders"."o_orderkey" = "_u_0"."l_orderkey" 
+JOIN ( + SELECT + "lineitem"."l_orderkey" AS "l_orderkey", + "lineitem"."l_quantity" AS "l_quantity" + FROM "lineitem" AS "lineitem" +) AS "lineitem" + ON "orders"."o_orderkey" = "lineitem"."l_orderkey" +WHERE + NOT "_u_0"."l_orderkey" IS NULL +GROUP BY + "customer"."c_name", + "customer"."c_custkey", + "orders"."o_orderkey", + "orders"."o_orderdate", + "orders"."o_totalprice" +ORDER BY + "o_totalprice" DESC, + "o_orderdate" +LIMIT 100; + +-------------------------------------- +-- TPC-H 19 +-------------------------------------- +select + sum(l_extendedprice* (1 - l_discount)) as revenue +from + lineitem, + part +where + ( + p_partkey = l_partkey + and p_brand = 'Brand#12' + and p_container in ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG') + and l_quantity >= 1 and l_quantity <= 11 + and p_size between 1 and 5 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ) + or + ( + p_partkey = l_partkey + and p_brand = 'Brand#23' + and p_container in ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK') + and l_quantity >= 10 and l_quantity <= 20 + and p_size between 1 and 10 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ) + or + ( + p_partkey = l_partkey + and p_brand = 'Brand#34' + and p_container in ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG') + and l_quantity >= 20 and l_quantity <= 30 + and p_size between 1 and 15 + and l_shipmode in ('AIR', 'AIR REG') + and l_shipinstruct = 'DELIVER IN PERSON' + ); +SELECT + SUM("lineitem"."l_extendedprice" * ( + 1 - "lineitem"."l_discount" + )) AS "revenue" +FROM ( + SELECT + "lineitem"."l_partkey" AS "l_partkey", + "lineitem"."l_quantity" AS "l_quantity", + "lineitem"."l_extendedprice" AS "l_extendedprice", + "lineitem"."l_discount" AS "l_discount", + "lineitem"."l_shipinstruct" AS "l_shipinstruct", + "lineitem"."l_shipmode" AS "l_shipmode" + FROM "lineitem" AS "lineitem" +) AS "lineitem" +JOIN ( + SELECT + "part"."p_partkey" AS "p_partkey", + "part"."p_brand" AS "p_brand", + "part"."p_size" AS "p_size", + "part"."p_container" AS "p_container" + FROM "part" AS "part" +) AS "part" + ON ( + "part"."p_brand" = 'Brand#12' + AND "part"."p_container" IN ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG') + AND "part"."p_partkey" = "lineitem"."l_partkey" + AND "part"."p_size" BETWEEN 1 AND 5 + ) + OR ( + "part"."p_brand" = 'Brand#23' + AND "part"."p_container" IN ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK') + AND "part"."p_partkey" = "lineitem"."l_partkey" + AND "part"."p_size" BETWEEN 1 AND 10 + ) + OR ( + "part"."p_brand" = 'Brand#34' + AND "part"."p_container" IN ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG') + AND "part"."p_partkey" = "lineitem"."l_partkey" + AND "part"."p_size" BETWEEN 1 AND 15 + ) +WHERE + ( + "lineitem"."l_quantity" <= 11 + AND "lineitem"."l_quantity" >= 1 + AND "lineitem"."l_shipinstruct" = 'DELIVER IN PERSON' + AND "lineitem"."l_shipmode" IN ('AIR', 'AIR REG') + AND "part"."p_brand" = 'Brand#12' + AND "part"."p_container" IN ('SM CASE', 'SM BOX', 'SM PACK', 'SM PKG') + AND "part"."p_partkey" = "lineitem"."l_partkey" + AND "part"."p_size" BETWEEN 1 AND 5 + ) + OR ( + "lineitem"."l_quantity" <= 20 + AND "lineitem"."l_quantity" >= 10 + AND "lineitem"."l_shipinstruct" = 'DELIVER IN PERSON' + AND "lineitem"."l_shipmode" IN ('AIR', 'AIR REG') + AND "part"."p_brand" = 'Brand#23' + AND "part"."p_container" IN ('MED BAG', 'MED BOX', 'MED PKG', 'MED PACK') + AND "part"."p_partkey" = "lineitem"."l_partkey" + AND "part"."p_size" BETWEEN 1 AND 10 + ) + OR ( + "lineitem"."l_quantity" <= 30 + AND 
"lineitem"."l_quantity" >= 20 + AND "lineitem"."l_shipinstruct" = 'DELIVER IN PERSON' + AND "lineitem"."l_shipmode" IN ('AIR', 'AIR REG') + AND "part"."p_brand" = 'Brand#34' + AND "part"."p_container" IN ('LG CASE', 'LG BOX', 'LG PACK', 'LG PKG') + AND "part"."p_partkey" = "lineitem"."l_partkey" + AND "part"."p_size" BETWEEN 1 AND 15 + ); + +-------------------------------------- +-- TPC-H 20 +-------------------------------------- +select + s_name, + s_address +from + supplier, + nation +where + s_suppkey in ( + select + ps_suppkey + from + partsupp + where + ps_partkey in ( + select + p_partkey + from + part + where + p_name like 'forest%' + ) + and ps_availqty > ( + select + 0.5 * sum(l_quantity) + from + lineitem + where + l_partkey = ps_partkey + and l_suppkey = ps_suppkey + and l_shipdate >= date '1994-01-01' + and l_shipdate < date '1994-01-01' + interval '1' year + ) + ) + and s_nationkey = n_nationkey + and n_name = 'CANADA' +order by + s_name; +SELECT + "supplier"."s_name" AS "s_name", + "supplier"."s_address" AS "s_address" +FROM ( + SELECT + "supplier"."s_suppkey" AS "s_suppkey", + "supplier"."s_name" AS "s_name", + "supplier"."s_address" AS "s_address", + "supplier"."s_nationkey" AS "s_nationkey" + FROM "supplier" AS "supplier" +) AS "supplier" +LEFT JOIN ( + SELECT + "partsupp"."ps_suppkey" AS "ps_suppkey" + FROM "partsupp" AS "partsupp" + LEFT JOIN ( + SELECT + 0.5 * SUM("lineitem"."l_quantity") AS "_col_0", + "lineitem"."l_partkey" AS "_u_1", + "lineitem"."l_suppkey" AS "_u_2" + FROM "lineitem" AS "lineitem" + WHERE + "lineitem"."l_shipdate" < CAST('1995-01-01' AS DATE) + AND "lineitem"."l_shipdate" >= CAST('1994-01-01' AS DATE) + GROUP BY + "lineitem"."l_partkey", + "lineitem"."l_suppkey" + ) AS "_u_0" + ON "_u_0"."_u_1" = "partsupp"."ps_partkey" + AND "_u_0"."_u_2" = "partsupp"."ps_suppkey" + LEFT JOIN ( + SELECT + "part"."p_partkey" AS "p_partkey" + FROM "part" AS "part" + WHERE + "part"."p_name" LIKE 'forest%' + GROUP BY + "part"."p_partkey" + ) AS "_u_3" + ON "partsupp"."ps_partkey" = "_u_3"."p_partkey" + WHERE + "partsupp"."ps_availqty" > "_u_0"."_col_0" + AND NOT "_u_0"."_u_1" IS NULL + AND NOT "_u_0"."_u_2" IS NULL + AND NOT "_u_3"."p_partkey" IS NULL + GROUP BY + "partsupp"."ps_suppkey" +) AS "_u_4" + ON "supplier"."s_suppkey" = "_u_4"."ps_suppkey" +JOIN ( + SELECT + "nation"."n_nationkey" AS "n_nationkey", + "nation"."n_name" AS "n_name" + FROM "nation" AS "nation" + WHERE + "nation"."n_name" = 'CANADA' +) AS "nation" + ON "supplier"."s_nationkey" = "nation"."n_nationkey" +WHERE + NOT "_u_4"."ps_suppkey" IS NULL +ORDER BY + "s_name"; + +-------------------------------------- +-- TPC-H 21 +-------------------------------------- +select + s_name, + count(*) as numwait +from + supplier, + lineitem l1, + orders, + nation +where + s_suppkey = l1.l_suppkey + and o_orderkey = l1.l_orderkey + and o_orderstatus = 'F' + and l1.l_receiptdate > l1.l_commitdate + and exists ( + select + * + from + lineitem l2 + where + l2.l_orderkey = l1.l_orderkey + and l2.l_suppkey <> l1.l_suppkey + ) + and not exists ( + select + * + from + lineitem l3 + where + l3.l_orderkey = l1.l_orderkey + and l3.l_suppkey <> l1.l_suppkey + and l3.l_receiptdate > l3.l_commitdate + ) + and s_nationkey = n_nationkey + and n_name = 'SAUDI ARABIA' +group by + s_name +order by + numwait desc, + s_name +limit + 100; +SELECT + "supplier"."s_name" AS "s_name", + COUNT(*) AS "numwait" +FROM ( + SELECT + "supplier"."s_suppkey" AS "s_suppkey", + "supplier"."s_name" AS "s_name", + "supplier"."s_nationkey" AS 
"s_nationkey" + FROM "supplier" AS "supplier" +) AS "supplier" +JOIN ( + SELECT + "lineitem"."l_orderkey" AS "l_orderkey", + "lineitem"."l_suppkey" AS "l_suppkey", + "lineitem"."l_commitdate" AS "l_commitdate", + "lineitem"."l_receiptdate" AS "l_receiptdate" + FROM "lineitem" AS "lineitem" + WHERE + "lineitem"."l_receiptdate" > "lineitem"."l_commitdate" +) AS "l1" + ON "supplier"."s_suppkey" = "l1"."l_suppkey" +LEFT JOIN ( + SELECT + "l2"."l_orderkey" AS "l_orderkey", + ARRAY_AGG("l2"."l_suppkey") AS "_u_1" + FROM "lineitem" AS "l2" + GROUP BY + "l2"."l_orderkey" +) AS "_u_0" + ON "_u_0"."l_orderkey" = "l1"."l_orderkey" +LEFT JOIN ( + SELECT + "l3"."l_orderkey" AS "l_orderkey", + ARRAY_AGG("l3"."l_suppkey") AS "_u_3" + FROM "lineitem" AS "l3" + WHERE + "l3"."l_receiptdate" > "l3"."l_commitdate" + GROUP BY + "l3"."l_orderkey" +) AS "_u_2" + ON "_u_2"."l_orderkey" = "l1"."l_orderkey" +JOIN ( + SELECT + "orders"."o_orderkey" AS "o_orderkey", + "orders"."o_orderstatus" AS "o_orderstatus" + FROM "orders" AS "orders" + WHERE + "orders"."o_orderstatus" = 'F' +) AS "orders" + ON "orders"."o_orderkey" = "l1"."l_orderkey" +JOIN ( + SELECT + "nation"."n_nationkey" AS "n_nationkey", + "nation"."n_name" AS "n_name" + FROM "nation" AS "nation" + WHERE + "nation"."n_name" = 'SAUDI ARABIA' +) AS "nation" + ON "supplier"."s_nationkey" = "nation"."n_nationkey" +WHERE + ( + "_u_2"."l_orderkey" IS NULL + OR NOT ARRAY_ANY("_u_2"."_u_3", "_x" -> "_x" <> "l1"."l_suppkey") + ) + AND ARRAY_ANY("_u_0"."_u_1", "_x" -> "_x" <> "l1"."l_suppkey") + AND NOT "_u_0"."l_orderkey" IS NULL +GROUP BY + "supplier"."s_name" +ORDER BY + "numwait" DESC, + "s_name" +LIMIT 100; + +-------------------------------------- +-- TPC-H 22 +-------------------------------------- +select + cntrycode, + count(*) as numcust, + sum(c_acctbal) as totacctbal +from + ( + select + substring(c_phone, 1, 2) as cntrycode, + c_acctbal + from + customer + where + substring(c_phone, 1, 2) in + ('13', '31', '23', '29', '30', '18', '17') + and c_acctbal > ( + select + avg(c_acctbal) + from + customer + where + c_acctbal > 0.00 + and substring(c_phone, 1, 2) in + ('13', '31', '23', '29', '30', '18', '17') + ) + and not exists ( + select + * + from + orders + where + o_custkey = c_custkey + ) + ) as custsale +group by + cntrycode +order by + cntrycode; +SELECT + "custsale"."cntrycode" AS "cntrycode", + COUNT(*) AS "numcust", + SUM("custsale"."c_acctbal") AS "totacctbal" +FROM ( + SELECT + SUBSTRING("customer"."c_phone", 1, 2) AS "cntrycode", + "customer"."c_acctbal" AS "c_acctbal" + FROM "customer" AS "customer" + LEFT JOIN ( + SELECT + "orders"."o_custkey" AS "_u_1" + FROM "orders" AS "orders" + GROUP BY + "orders"."o_custkey" + ) AS "_u_0" + ON "_u_0"."_u_1" = "customer"."c_custkey" + WHERE + "_u_0"."_u_1" IS NULL + AND "customer"."c_acctbal" > ( + SELECT + AVG("customer"."c_acctbal") AS "_col_0" + FROM "customer" AS "customer" + WHERE + "customer"."c_acctbal" > 0.00 + AND SUBSTRING("customer"."c_phone", 1, 2) IN ('13', '31', '23', '29', '30', '18', '17') + ) + AND SUBSTRING("customer"."c_phone", 1, 2) IN ('13', '31', '23', '29', '30', '18', '17') +) AS "custsale" +GROUP BY + "custsale"."cntrycode" +ORDER BY + "cntrycode"; diff --git a/tests/fixtures/optimizer/unnest_subqueries.sql b/tests/fixtures/optimizer/unnest_subqueries.sql new file mode 100644 index 0000000..9c4bd27 --- /dev/null +++ b/tests/fixtures/optimizer/unnest_subqueries.sql @@ -0,0 +1,206 @@ +-------------------------------------- +-- Unnest Subqueries 
+-------------------------------------- +SELECT * +FROM x AS x +WHERE + x.a IN (SELECT y.a AS a FROM y) + AND x.a IN (SELECT y.b AS b FROM y) + AND x.a = ANY (SELECT y.a AS a FROM y) + AND x.a = (SELECT SUM(y.b) AS b FROM y WHERE x.a = y.a) + AND x.a > (SELECT SUM(y.b) AS b FROM y WHERE x.a = y.a) + AND x.a <> ANY (SELECT y.a AS a FROM y WHERE y.a = x.a) + AND x.a NOT IN (SELECT y.a AS a FROM y WHERE y.a = x.a) + AND x.a IN (SELECT y.a AS a FROM y WHERE y.b = x.a) + AND x.a < (SELECT SUM(y.a) AS a FROM y WHERE y.a = x.a and y.a = x.b and y.b <> x.d) + AND EXISTS (SELECT y.a AS a, y.b AS b FROM y WHERE x.a = y.a) + AND x.a IN (SELECT y.a AS a FROM y LIMIT 10) + AND x.a IN (SELECT y.a AS a FROM y OFFSET 10) + AND x.a IN (SELECT y.a AS a, y.b AS b FROM y) + AND x.a > ANY (SELECT y.a FROM y) + AND x.a = (SELECT SUM(y.c) AS c FROM y WHERE y.a = x.a LIMIT 10) + AND x.a = (SELECT SUM(y.c) AS c FROM y WHERE y.a = x.a OFFSET 10) +; +SELECT + * +FROM x AS x +LEFT JOIN ( + SELECT + y.a AS a + FROM y + GROUP BY + y.a +) AS "_u_0" + ON x.a = "_u_0"."a" +LEFT JOIN ( + SELECT + y.b AS b + FROM y + GROUP BY + y.b +) AS "_u_1" + ON x.a = "_u_1"."b" +LEFT JOIN ( + SELECT + y.a AS a + FROM y + GROUP BY + y.a +) AS "_u_2" + ON x.a = "_u_2"."a" +LEFT JOIN ( + SELECT + SUM(y.b) AS b, + y.a AS _u_4 + FROM y + WHERE + TRUE + GROUP BY + y.a +) AS "_u_3" + ON x.a = "_u_3"."_u_4" +LEFT JOIN ( + SELECT + SUM(y.b) AS b, + y.a AS _u_6 + FROM y + WHERE + TRUE + GROUP BY + y.a +) AS "_u_5" + ON x.a = "_u_5"."_u_6" +LEFT JOIN ( + SELECT + y.a AS a + FROM y + WHERE + TRUE + GROUP BY + y.a +) AS "_u_7" + ON "_u_7".a = x.a +LEFT JOIN ( + SELECT + y.a AS a + FROM y + WHERE + TRUE + GROUP BY + y.a +) AS "_u_8" + ON "_u_8".a = x.a +LEFT JOIN ( + SELECT + ARRAY_AGG(y.a) AS a, + y.b AS _u_10 + FROM y + WHERE + TRUE + GROUP BY + y.b +) AS "_u_9" + ON "_u_9"."_u_10" = x.a +LEFT JOIN ( + SELECT + SUM(y.a) AS a, + y.a AS _u_12, + ARRAY_AGG(y.b) AS _u_13 + FROM y + WHERE + TRUE + AND TRUE + AND TRUE + GROUP BY + y.a +) AS "_u_11" + ON "_u_11"."_u_12" = x.a + AND "_u_11"."_u_12" = x.b +LEFT JOIN ( + SELECT + y.a AS a + FROM y + WHERE + TRUE + GROUP BY + y.a +) AS "_u_14" + ON x.a = "_u_14".a +WHERE + NOT "_u_0"."a" IS NULL + AND NOT "_u_1"."b" IS NULL + AND NOT "_u_2"."a" IS NULL + AND ( + x.a = "_u_3".b + AND NOT "_u_3"."_u_4" IS NULL + ) + AND ( + x.a > "_u_5".b + AND NOT "_u_5"."_u_6" IS NULL + ) + AND ( + None = "_u_7".a + AND NOT "_u_7".a IS NULL + ) + AND NOT ( + x.a = "_u_8".a + AND NOT "_u_8".a IS NULL + ) + AND ( + ARRAY_ANY("_u_9".a, _x -> _x = x.a) + AND NOT "_u_9"."_u_10" IS NULL + ) + AND ( + ( + ( + x.a < "_u_11".a + AND NOT "_u_11"."_u_12" IS NULL + ) + AND NOT "_u_11"."_u_12" IS NULL + ) + AND ARRAY_ANY("_u_11"."_u_13", "_x" -> "_x" <> x.d) + ) + AND ( + NOT "_u_14".a IS NULL + AND NOT "_u_14".a IS NULL + ) + AND x.a IN ( + SELECT + y.a AS a + FROM y + LIMIT 10 + ) + AND x.a IN ( + SELECT + y.a AS a + FROM y + OFFSET 10 + ) + AND x.a IN ( + SELECT + y.a AS a, + y.b AS b + FROM y + ) + AND x.a > ANY ( + SELECT + y.a + FROM y + ) + AND x.a = ( + SELECT + SUM(y.c) AS c + FROM y + WHERE + y.a = x.a + LIMIT 10 + ) + AND x.a = ( + SELECT + SUM(y.c) AS c + FROM y + WHERE + y.a = x.a + OFFSET 10 + ); + diff --git a/tests/fixtures/partial.sql b/tests/fixtures/partial.sql new file mode 100644 index 0000000..c6be364 --- /dev/null +++ b/tests/fixtures/partial.sql @@ -0,0 +1,8 @@ +SELECT a FROM +SELECT a FROM x WHERE +SELECT a + +a * +SELECT a FROM x JOIN +SELECT a FROM x GROUP BY +WITH a AS (SELECT 1), b AS (SELECT 2) +SELECT FROM 
x diff --git a/tests/fixtures/pretty.sql b/tests/fixtures/pretty.sql new file mode 100644 index 0000000..5ed74f4 --- /dev/null +++ b/tests/fixtures/pretty.sql @@ -0,0 +1,285 @@ +SELECT * FROM test; +SELECT + * +FROM test; + +WITH a AS ((SELECT 1 AS b) UNION ALL (SELECT 2 AS b)) SELECT * FROM a; +WITH a AS ( + ( + SELECT + 1 AS b + ) + UNION ALL + ( + SELECT + 2 AS b + ) +) +SELECT + * +FROM a; + +WITH cte1 AS ( + SELECT a, z and e AS b + FROM cte + WHERE x IN (1, 2, 3) AND z < -1 OR z > 1 AND w = 'AND' +), cte2 AS ( + SELECT RANK() OVER (PARTITION BY a, b ORDER BY x DESC) a, b + FROM cte + CROSS JOIN ( + SELECT 1 + UNION ALL + SELECT 2 + UNION ALL + SELECT CASE x AND 1 + 1 = 2 + WHEN TRUE THEN 1 AND 4 + 3 AND Z + WHEN x and y THEN 2 + ELSE 3 AND 4 AND g END + UNION ALL + SELECT 1 + FROM (SELECT 1) AS x, y, (SELECT 2) z + UNION ALL + SELECT MAX(COALESCE(x AND y, a and b and c, d and e)), FOO(CASE WHEN a and b THEN c and d ELSE 3 END) + GROUP BY x, GROUPING SETS (a, (b, c)) CUBE(y, z) + ) x +) +SELECT a, b c FROM ( + SELECT a w, 1 + 1 AS c + FROM foo + WHERE w IN (SELECT z FROM q) + GROUP BY a, b +) x +LEFT JOIN ( + SELECT a, b + FROM (SELECT * FROM bar WHERE (c > 1 AND d > 1) OR e > 1 GROUP BY a HAVING a > 1 LIMIT 10) z +) y ON x.a = y.b AND x.a > 1 OR (x.c = y.d OR x.c = y.e); +WITH cte1 AS ( + SELECT + a, + z + AND e AS b + FROM cte + WHERE + x IN (1, 2, 3) + AND z < -1 + OR z > 1 + AND w = 'AND' +), cte2 AS ( + SELECT + RANK() OVER (PARTITION BY a, b ORDER BY x DESC) AS a, + b + FROM cte + CROSS JOIN ( + SELECT + 1 + UNION ALL + SELECT + 2 + UNION ALL + SELECT + CASE x + AND 1 + 1 = 2 + WHEN TRUE + THEN 1 + AND 4 + 3 + AND Z + WHEN x + AND y + THEN 2 + ELSE 3 + AND 4 + AND g + END + UNION ALL + SELECT + 1 + FROM ( + SELECT + 1 + ) AS x, y, ( + SELECT + 2 + ) AS z + UNION ALL + SELECT + MAX(COALESCE(x + AND y, a + AND b + AND c, d + AND e)), + FOO(CASE + WHEN a + AND b + THEN c + AND d + ELSE 3 + END) + GROUP BY + x + GROUPING SETS ( + a, + (b, c) + ) + CUBE ( + y, + z + ) + ) AS x +) +SELECT + a, + b AS c +FROM ( + SELECT + a AS w, + 1 + 1 AS c + FROM foo + WHERE + w IN ( + SELECT + z + FROM q + ) + GROUP BY + a, + b +) AS x +LEFT JOIN ( + SELECT + a, + b + FROM ( + SELECT + * + FROM bar + WHERE + ( + c > 1 + AND d > 1 + ) + OR e > 1 + GROUP BY + a + HAVING + a > 1 + LIMIT 10 + ) AS z +) AS y + ON x.a = y.b + AND x.a > 1 + OR ( + x.c = y.d + OR x.c = y.e + ); + +SELECT myCol1, myCol2 FROM baseTable LATERAL VIEW OUTER explode(col1) myTable1 AS myCol1 LATERAL VIEW explode(col2) myTable2 AS myCol2 +where a > 1 and b > 2 or c > 3; + +SELECT + myCol1, + myCol2 +FROM baseTable +LATERAL VIEW OUTER +EXPLODE(col1) myTable1 AS myCol1 +LATERAL VIEW +EXPLODE(col2) myTable2 AS myCol2 +WHERE + a > 1 + AND b > 2 + OR c > 3; + +SELECT * FROM (WITH y AS ( SELECT 1 AS z) SELECT z from y) x; +SELECT + * +FROM ( + WITH y AS ( + SELECT + 1 AS z + ) + SELECT + z + FROM y +) AS x; + +INSERT OVERWRITE TABLE x VALUES (1, 2.0, '3.0'), (4, 5.0, '6.0'); +INSERT OVERWRITE TABLE x VALUES + (1, 2.0, '3.0'), + (4, 5.0, '6.0'); + +WITH regional_sales AS ( + SELECT region, SUM(amount) AS total_sales + FROM orders + GROUP BY region + ), top_regions AS ( + SELECT region + FROM regional_sales + WHERE total_sales > (SELECT SUM(total_sales)/10 FROM regional_sales) +) +SELECT region, +product, +SUM(quantity) AS product_units, +SUM(amount) AS product_sales +FROM orders +WHERE region IN (SELECT region FROM top_regions) +GROUP BY region, product; +WITH regional_sales AS ( + SELECT + region, + SUM(amount) AS total_sales + FROM 
orders + GROUP BY + region +), top_regions AS ( + SELECT + region + FROM regional_sales + WHERE + total_sales > ( + SELECT + SUM(total_sales) / 10 + FROM regional_sales + ) +) +SELECT + region, + product, + SUM(quantity) AS product_units, + SUM(amount) AS product_sales +FROM orders +WHERE + region IN ( + SELECT + region + FROM top_regions + ) +GROUP BY + region, + product; + +CREATE TABLE "t_customer_account" ( "id" int, "customer_id" int, "bank" varchar(100), "account_no" varchar(100)); +CREATE TABLE "t_customer_account" ( + "id" INT, + "customer_id" INT, + "bank" VARCHAR(100), + "account_no" VARCHAR(100) +); + +CREATE TABLE "t_customer_account" ( + "id" int(11) NOT NULL AUTO_INCREMENT, + "customer_id" int(11) DEFAULT NULL COMMENT '客户id', + "bank" varchar(100) COLLATE utf8_bin DEFAULT NULL COMMENT '行别', + "account_no" varchar(100) COLLATE utf8_bin DEFAULT NULL COMMENT '账号', + PRIMARY KEY ("id") +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARACTER SET=utf8 COLLATE=utf8_bin COMMENT='客户账户表'; +CREATE TABLE "t_customer_account" ( + "id" INT(11) NOT NULL AUTO_INCREMENT, + "customer_id" INT(11) DEFAULT NULL COMMENT '客户id', + "bank" VARCHAR(100) COLLATE utf8_bin DEFAULT NULL COMMENT '行别', + "account_no" VARCHAR(100) COLLATE utf8_bin DEFAULT NULL COMMENT '账号', + PRIMARY KEY("id") +) +ENGINE=InnoDB +AUTO_INCREMENT=1 +DEFAULT CHARACTER SET=utf8 +COLLATE=utf8_bin +COMMENT='客户账户表'; diff --git a/tests/helpers.py b/tests/helpers.py new file mode 100644 index 0000000..d4edb14 --- /dev/null +++ b/tests/helpers.py @@ -0,0 +1,130 @@ +import os + +FILE_DIR = os.path.dirname(__file__) +FIXTURES_DIR = os.path.join(FILE_DIR, "fixtures") + + +def _filter_comments(s): + return "\n".join( + [line for line in s.splitlines() if line and not line.startswith("--")] + ) + + +def _extract_meta(sql): + meta = {} + sql_lines = sql.split("\n") + i = 0 + while sql_lines[i].startswith("#"): + key, val = sql_lines[i].split(":", maxsplit=1) + meta[key.lstrip("#").strip()] = val.strip() + i += 1 + sql = "\n".join(sql_lines[i:]) + return sql, meta + + +def assert_logger_contains(message, logger, level="error"): + output = "\n".join( + str(args[0][0]) for args in getattr(logger, level).call_args_list + ) + assert message in output + + +def load_sql_fixtures(filename): + with open(os.path.join(FIXTURES_DIR, filename), encoding="utf-8") as f: + for sql in _filter_comments(f.read()).splitlines(): + yield sql + + +def load_sql_fixture_pairs(filename): + with open(os.path.join(FIXTURES_DIR, filename), encoding="utf-8") as f: + statements = _filter_comments(f.read()).split(";") + + size = len(statements) + + for i in range(0, size, 2): + if i + 1 < size: + sql = statements[i].strip() + sql, meta = _extract_meta(sql) + expected = statements[i + 1].strip() + yield meta, sql, expected + + +TPCH_SCHEMA = { + "lineitem": { + "l_orderkey": "uint64", + "l_partkey": "uint64", + "l_suppkey": "uint64", + "l_linenumber": "uint64", + "l_quantity": "float64", + "l_extendedprice": "float64", + "l_discount": "float64", + "l_tax": "float64", + "l_returnflag": "string", + "l_linestatus": "string", + "l_shipdate": "date32", + "l_commitdate": "date32", + "l_receiptdate": "date32", + "l_shipinstruct": "string", + "l_shipmode": "string", + "l_comment": "string", + }, + "orders": { + "o_orderkey": "uint64", + "o_custkey": "uint64", + "o_orderstatus": "string", + "o_totalprice": "float64", + "o_orderdate": "date32", + "o_orderpriority": "string", + "o_clerk": "string", + "o_shippriority": "int32", + "o_comment": "string", + }, + "customer": { + "c_custkey": 
"uint64", + "c_name": "string", + "c_address": "string", + "c_nationkey": "uint64", + "c_phone": "string", + "c_acctbal": "float64", + "c_mktsegment": "string", + "c_comment": "string", + }, + "part": { + "p_partkey": "uint64", + "p_name": "string", + "p_mfgr": "string", + "p_brand": "string", + "p_type": "string", + "p_size": "int32", + "p_container": "string", + "p_retailprice": "float64", + "p_comment": "string", + }, + "supplier": { + "s_suppkey": "uint64", + "s_name": "string", + "s_address": "string", + "s_nationkey": "uint64", + "s_phone": "string", + "s_acctbal": "float64", + "s_comment": "string", + }, + "partsupp": { + "ps_partkey": "uint64", + "ps_suppkey": "uint64", + "ps_availqty": "int32", + "ps_supplycost": "float64", + "ps_comment": "string", + }, + "nation": { + "n_nationkey": "uint64", + "n_name": "string", + "n_regionkey": "uint64", + "n_comment": "string", + }, + "region": { + "r_regionkey": "uint64", + "r_name": "string", + "r_comment": "string", + }, +} diff --git a/tests/test_build.py b/tests/test_build.py new file mode 100644 index 0000000..a4cffde --- /dev/null +++ b/tests/test_build.py @@ -0,0 +1,384 @@ +import unittest + +from sqlglot import and_, condition, exp, from_, not_, or_, parse_one, select + + +class TestBuild(unittest.TestCase): + def test_build(self): + for expression, sql, *dialect in [ + (lambda: select("x"), "SELECT x"), + (lambda: select("x", "y"), "SELECT x, y"), + (lambda: select("x").from_("tbl"), "SELECT x FROM tbl"), + (lambda: select("x", "y").from_("tbl"), "SELECT x, y FROM tbl"), + (lambda: select("x").select("y").from_("tbl"), "SELECT x, y FROM tbl"), + ( + lambda: select("x").select("y", append=False).from_("tbl"), + "SELECT y FROM tbl", + ), + (lambda: select("x").from_("tbl").from_("tbl2"), "SELECT x FROM tbl, tbl2"), + ( + lambda: select("x").from_("tbl, tbl2", "tbl3").from_("tbl4"), + "SELECT x FROM tbl, tbl2, tbl3, tbl4", + ), + ( + lambda: select("x").from_("tbl").from_("tbl2", append=False), + "SELECT x FROM tbl2", + ), + (lambda: select("SUM(x) AS y"), "SELECT SUM(x) AS y"), + ( + lambda: select("x").from_("tbl").where("x > 0"), + "SELECT x FROM tbl WHERE x > 0", + ), + ( + lambda: select("x").from_("tbl").where("x < 4 OR x > 5"), + "SELECT x FROM tbl WHERE x < 4 OR x > 5", + ), + ( + lambda: select("x").from_("tbl").where("x > 0").where("x < 9"), + "SELECT x FROM tbl WHERE x > 0 AND x < 9", + ), + ( + lambda: select("x").from_("tbl").where("x > 0", "x < 9"), + "SELECT x FROM tbl WHERE x > 0 AND x < 9", + ), + ( + lambda: select("x").from_("tbl").where(None).where(False, ""), + "SELECT x FROM tbl WHERE FALSE", + ), + ( + lambda: select("x") + .from_("tbl") + .where("x > 0") + .where("x < 9", append=False), + "SELECT x FROM tbl WHERE x < 9", + ), + ( + lambda: select("x", "y").from_("tbl").group_by("x"), + "SELECT x, y FROM tbl GROUP BY x", + ), + ( + lambda: select("x", "y").from_("tbl").group_by("x, y"), + "SELECT x, y FROM tbl GROUP BY x, y", + ), + ( + lambda: select("x", "y", "z", "a") + .from_("tbl") + .group_by("x, y", "z") + .group_by("a"), + "SELECT x, y, z, a FROM tbl GROUP BY x, y, z, a", + ), + ( + lambda: select("x").distinct(True).from_("tbl"), + "SELECT DISTINCT x FROM tbl", + ), + (lambda: select("x").distinct(False).from_("tbl"), "SELECT x FROM tbl"), + ( + lambda: select("x").lateral("OUTER explode(y) tbl2 AS z").from_("tbl"), + "SELECT x FROM tbl LATERAL VIEW OUTER EXPLODE(y) tbl2 AS z", + ), + ( + lambda: select("x").from_("tbl").join("tbl2 ON tbl.y = tbl2.y"), + "SELECT x FROM tbl JOIN tbl2 ON tbl.y = tbl2.y", 
+ ), + ( + lambda: select("x").from_("tbl").join("tbl2", on="tbl.y = tbl2.y"), + "SELECT x FROM tbl JOIN tbl2 ON tbl.y = tbl2.y", + ), + ( + lambda: select("x") + .from_("tbl") + .join("tbl2", on=["tbl.y = tbl2.y", "a = b"]), + "SELECT x FROM tbl JOIN tbl2 ON tbl.y = tbl2.y AND a = b", + ), + ( + lambda: select("x").from_("tbl").join("tbl2", join_type="left outer"), + "SELECT x FROM tbl LEFT OUTER JOIN tbl2", + ), + ( + lambda: select("x") + .from_("tbl") + .join(exp.Table(this="tbl2"), join_type="left outer"), + "SELECT x FROM tbl LEFT OUTER JOIN tbl2", + ), + ( + lambda: select("x") + .from_("tbl") + .join(exp.Table(this="tbl2"), join_type="left outer", join_alias="foo"), + "SELECT x FROM tbl LEFT OUTER JOIN tbl2 AS foo", + ), + ( + lambda: select("x") + .from_("tbl") + .join(select("y").from_("tbl2"), join_type="left outer"), + "SELECT x FROM tbl LEFT OUTER JOIN (SELECT y FROM tbl2)", + ), + ( + lambda: select("x") + .from_("tbl") + .join( + select("y").from_("tbl2").subquery("aliased"), + join_type="left outer", + ), + "SELECT x FROM tbl LEFT OUTER JOIN (SELECT y FROM tbl2) AS aliased", + ), + ( + lambda: select("x") + .from_("tbl") + .join( + select("y").from_("tbl2"), + join_type="left outer", + join_alias="aliased", + ), + "SELECT x FROM tbl LEFT OUTER JOIN (SELECT y FROM tbl2) AS aliased", + ), + ( + lambda: select("x") + .from_("tbl") + .join(parse_one("left join x", into=exp.Join), on="a=b"), + "SELECT x FROM tbl LEFT JOIN x ON a = b", + ), + ( + lambda: select("x").from_("tbl").join("left join x", on="a=b"), + "SELECT x FROM tbl LEFT JOIN x ON a = b", + ), + ( + lambda: select("x") + .from_("tbl") + .join("select b from tbl2", on="a=b", join_type="left"), + "SELECT x FROM tbl LEFT JOIN (SELECT b FROM tbl2) ON a = b", + ), + ( + lambda: select("x") + .from_("tbl") + .join( + "select b from tbl2", + on="a=b", + join_type="left", + join_alias="aliased", + ), + "SELECT x FROM tbl LEFT JOIN (SELECT b FROM tbl2) AS aliased ON a = b", + ), + ( + lambda: select("x", "COUNT(y)") + .from_("tbl") + .group_by("x") + .having("COUNT(y) > 0"), + "SELECT x, COUNT(y) FROM tbl GROUP BY x HAVING COUNT(y) > 0", + ), + ( + lambda: select("x").from_("tbl").order_by("y"), + "SELECT x FROM tbl ORDER BY y", + ), + ( + lambda: select("x").from_("tbl").cluster_by("y"), + "SELECT x FROM tbl CLUSTER BY y", + ), + ( + lambda: select("x").from_("tbl").sort_by("y"), + "SELECT x FROM tbl SORT BY y", + ), + ( + lambda: select("x").from_("tbl").order_by("x, y DESC"), + "SELECT x FROM tbl ORDER BY x, y DESC", + ), + ( + lambda: select("x").from_("tbl").cluster_by("x, y DESC"), + "SELECT x FROM tbl CLUSTER BY x, y DESC", + ), + ( + lambda: select("x").from_("tbl").sort_by("x, y DESC"), + "SELECT x FROM tbl SORT BY x, y DESC", + ), + ( + lambda: select("x", "y", "z", "a") + .from_("tbl") + .order_by("x, y", "z") + .order_by("a"), + "SELECT x, y, z, a FROM tbl ORDER BY x, y, z, a", + ), + ( + lambda: select("x", "y", "z", "a") + .from_("tbl") + .cluster_by("x, y", "z") + .cluster_by("a"), + "SELECT x, y, z, a FROM tbl CLUSTER BY x, y, z, a", + ), + ( + lambda: select("x", "y", "z", "a") + .from_("tbl") + .sort_by("x, y", "z") + .sort_by("a"), + "SELECT x, y, z, a FROM tbl SORT BY x, y, z, a", + ), + (lambda: select("x").from_("tbl").limit(10), "SELECT x FROM tbl LIMIT 10"), + ( + lambda: select("x").from_("tbl").offset(10), + "SELECT x FROM tbl OFFSET 10", + ), + ( + lambda: select("x").from_("tbl").with_("tbl", as_="SELECT x FROM tbl2"), + "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl", + ), + ( + lambda: 
select("x") + .from_("tbl") + .with_("tbl", as_="SELECT x FROM tbl2", recursive=True), + "WITH RECURSIVE tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl", + ), + ( + lambda: select("x") + .from_("tbl") + .with_("tbl", as_=select("x").from_("tbl2")), + "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl", + ), + ( + lambda: select("x") + .from_("tbl") + .with_("tbl (x, y)", as_=select("x", "y").from_("tbl2")), + "WITH tbl(x, y) AS (SELECT x, y FROM tbl2) SELECT x FROM tbl", + ), + ( + lambda: select("x") + .from_("tbl") + .with_("tbl", as_=select("x").from_("tbl2")) + .with_("tbl2", as_=select("x").from_("tbl3")), + "WITH tbl AS (SELECT x FROM tbl2), tbl2 AS (SELECT x FROM tbl3) SELECT x FROM tbl", + ), + ( + lambda: select("x") + .from_("tbl") + .with_("tbl", as_=select("x", "y").from_("tbl2")) + .select("y"), + "WITH tbl AS (SELECT x, y FROM tbl2) SELECT x, y FROM tbl", + ), + ( + lambda: select("x") + .with_("tbl", as_=select("x").from_("tbl2")) + .from_("tbl"), + "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl", + ), + ( + lambda: select("x") + .with_("tbl", as_=select("x").from_("tbl2")) + .from_("tbl") + .group_by("x"), + "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl GROUP BY x", + ), + ( + lambda: select("x") + .with_("tbl", as_=select("x").from_("tbl2")) + .from_("tbl") + .order_by("x"), + "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl ORDER BY x", + ), + ( + lambda: select("x") + .with_("tbl", as_=select("x").from_("tbl2")) + .from_("tbl") + .limit(10), + "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl LIMIT 10", + ), + ( + lambda: select("x") + .with_("tbl", as_=select("x").from_("tbl2")) + .from_("tbl") + .offset(10), + "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl OFFSET 10", + ), + ( + lambda: select("x") + .with_("tbl", as_=select("x").from_("tbl2")) + .from_("tbl") + .join("tbl3"), + "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl JOIN tbl3", + ), + ( + lambda: select("x") + .with_("tbl", as_=select("x").from_("tbl2")) + .from_("tbl") + .distinct(), + "WITH tbl AS (SELECT x FROM tbl2) SELECT DISTINCT x FROM tbl", + ), + ( + lambda: select("x") + .with_("tbl", as_=select("x").from_("tbl2")) + .from_("tbl") + .where("x > 10"), + "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl WHERE x > 10", + ), + ( + lambda: select("x") + .with_("tbl", as_=select("x").from_("tbl2")) + .from_("tbl") + .having("x > 20"), + "WITH tbl AS (SELECT x FROM tbl2) SELECT x FROM tbl HAVING x > 20", + ), + (lambda: select("x").from_("tbl").subquery(), "(SELECT x FROM tbl)"), + ( + lambda: select("x").from_("tbl").subquery("y"), + "(SELECT x FROM tbl) AS y", + ), + ( + lambda: select("x").from_(select("x").from_("tbl").subquery()), + "SELECT x FROM (SELECT x FROM tbl)", + ), + (lambda: from_("tbl").select("x"), "SELECT x FROM tbl"), + ( + lambda: parse_one("SELECT a FROM tbl") + .assert_is(exp.Select) + .select("b"), + "SELECT a, b FROM tbl", + ), + ( + lambda: parse_one("SELECT * FROM y").assert_is(exp.Select).ctas("x"), + "CREATE TABLE x AS SELECT * FROM y", + ), + ( + lambda: parse_one("SELECT * FROM y") + .assert_is(exp.Select) + .ctas("foo.x", properties={"format": "parquet", "y": "2"}), + "CREATE TABLE foo.x STORED AS PARQUET TBLPROPERTIES ('y' = '2') AS SELECT * FROM y", + "hive", + ), + (lambda: and_("x=1", "y=1"), "x = 1 AND y = 1"), + (lambda: condition("x").and_("y['a']").and_("1"), "(x AND y['a']) AND 1"), + (lambda: condition("x=1").and_("y=1"), "x = 1 AND y = 1"), + (lambda: and_("x=1", "y=1", "z=1"), "x = 1 AND y = 1 AND z = 1"), + (lambda: 
condition("x=1").and_("y=1", "z=1"), "x = 1 AND y = 1 AND z = 1"), + (lambda: and_("x=1", and_("y=1", "z=1")), "x = 1 AND (y = 1 AND z = 1)"), + ( + lambda: condition("x=1").and_("y=1").and_("z=1"), + "(x = 1 AND y = 1) AND z = 1", + ), + (lambda: or_(and_("x=1", "y=1"), "z=1"), "(x = 1 AND y = 1) OR z = 1"), + ( + lambda: condition("x=1").and_("y=1").or_("z=1"), + "(x = 1 AND y = 1) OR z = 1", + ), + (lambda: or_("z=1", and_("x=1", "y=1")), "z = 1 OR (x = 1 AND y = 1)"), + ( + lambda: or_("z=1 OR a=1", and_("x=1", "y=1")), + "(z = 1 OR a = 1) OR (x = 1 AND y = 1)", + ), + (lambda: not_("x=1"), "NOT x = 1"), + (lambda: condition("x=1").not_(), "NOT x = 1"), + (lambda: condition("x=1").and_("y=1").not_(), "NOT (x = 1 AND y = 1)"), + ( + lambda: select("*").from_("x").where(condition("y=1").and_("z=1")), + "SELECT * FROM x WHERE y = 1 AND z = 1", + ), + ( + lambda: exp.subquery("select x from tbl", "foo") + .select("x") + .where("x > 0"), + "SELECT x FROM (SELECT x FROM tbl) AS foo WHERE x > 0", + ), + ( + lambda: exp.subquery( + "select x from tbl UNION select x from bar", "unioned" + ).select("x"), + "SELECT x FROM (SELECT x FROM tbl UNION SELECT x FROM bar) AS unioned", + ), + ]: + with self.subTest(sql): + self.assertEqual(expression().sql(dialect[0] if dialect else None), sql) diff --git a/tests/test_diff.py b/tests/test_diff.py new file mode 100644 index 0000000..cbd53b3 --- /dev/null +++ b/tests/test_diff.py @@ -0,0 +1,137 @@ +import unittest + +from sqlglot import parse_one +from sqlglot.diff import Insert, Keep, Move, Remove, Update, diff +from sqlglot.expressions import Join, to_identifier + + +class TestDiff(unittest.TestCase): + def test_simple(self): + self._validate_delta_only( + diff(parse_one("SELECT a + b"), parse_one("SELECT a - b")), + [ + Remove(parse_one("a + b")), # the Add node + Insert(parse_one("a - b")), # the Sub node + ], + ) + + self._validate_delta_only( + diff(parse_one("SELECT a, b, c"), parse_one("SELECT a, c")), + [ + Remove(to_identifier("b", quoted=False)), # the Identifier node + Remove(parse_one("b")), # the Column node + ], + ) + + self._validate_delta_only( + diff(parse_one("SELECT a, b"), parse_one("SELECT a, b, c")), + [ + Insert(to_identifier("c", quoted=False)), # the Identifier node + Insert(parse_one("c")), # the Column node + ], + ) + + self._validate_delta_only( + diff( + parse_one("SELECT a FROM table_one"), + parse_one("SELECT a FROM table_two"), + ), + [ + Update( + to_identifier("table_one", quoted=False), + to_identifier("table_two", quoted=False), + ), # the Identifier node + ], + ) + + def test_node_position_changed(self): + self._validate_delta_only( + diff(parse_one("SELECT a, b, c"), parse_one("SELECT c, a, b")), + [ + Move(parse_one("c")), # the Column node + ], + ) + + self._validate_delta_only( + diff(parse_one("SELECT a + b"), parse_one("SELECT b + a")), + [ + Move(parse_one("a")), # the Column node + ], + ) + + self._validate_delta_only( + diff(parse_one("SELECT aaaa AND bbbb"), parse_one("SELECT bbbb AND aaaa")), + [ + Move(parse_one("aaaa")), # the Column node + ], + ) + + self._validate_delta_only( + diff( + parse_one("SELECT aaaa OR bbbb OR cccc"), + parse_one("SELECT cccc OR bbbb OR aaaa"), + ), + [ + Move(parse_one("aaaa")), # the Column node + Move(parse_one("cccc")), # the Column node + ], + ) + + def test_cte(self): + expr_src = """ + WITH + cte1 AS (SELECT a, b, LOWER(c) AS c FROM table_one WHERE d = 'filter'), + cte2 AS (SELECT d, e, f FROM table_two) + SELECT a, b, d, e FROM cte1 JOIN cte2 ON f = c + """ + expr_tgt = 
""" + WITH + cte1 AS (SELECT a, b, c FROM table_one WHERE d = 'different_filter'), + cte2 AS (SELECT d, e, f FROM table_two) + SELECT a, b, d, e FROM cte1 JOIN cte2 ON f = c + """ + + self._validate_delta_only( + diff(parse_one(expr_src), parse_one(expr_tgt)), + [ + Remove(parse_one("LOWER(c) AS c")), # the Alias node + Remove(to_identifier("c", quoted=False)), # the Identifier node + Remove(parse_one("LOWER(c)")), # the Lower node + Remove(parse_one("'filter'")), # the Literal node + Insert(parse_one("'different_filter'")), # the Literal node + ], + ) + + def test_join(self): + expr_src = "SELECT a, b FROM t1 LEFT JOIN t2 ON t1.key = t2.key" + expr_tgt = "SELECT a, b FROM t1 RIGHT JOIN t2 ON t1.key = t2.key" + + changes = diff(parse_one(expr_src), parse_one(expr_tgt)) + changes = _delta_only(changes) + + self.assertEqual(len(changes), 2) + self.assertTrue(isinstance(changes[0], Remove)) + self.assertTrue(isinstance(changes[1], Insert)) + self.assertTrue(all(isinstance(c.expression, Join) for c in changes)) + + def test_window_functions(self): + expr_src = parse_one("SELECT ROW_NUMBER() OVER (PARTITION BY a ORDER BY b)") + expr_tgt = parse_one("SELECT RANK() OVER (PARTITION BY a ORDER BY b)") + + self._validate_delta_only(diff(expr_src, expr_src), []) + + self._validate_delta_only( + diff(expr_src, expr_tgt), + [ + Remove(parse_one("ROW_NUMBER()")), # the Anonymous node + Insert(parse_one("RANK()")), # the Anonymous node + ], + ) + + def _validate_delta_only(self, actual_diff, expected_delta): + actual_delta = _delta_only(actual_diff) + self.assertEqual(set(actual_delta), set(expected_delta)) + + +def _delta_only(changes): + return [d for d in changes if not isinstance(d, Keep)] diff --git a/tests/test_docs.py b/tests/test_docs.py new file mode 100644 index 0000000..95aa814 --- /dev/null +++ b/tests/test_docs.py @@ -0,0 +1,30 @@ +import doctest +import inspect +import unittest + +import sqlglot +import sqlglot.optimizer +import sqlglot.transforms + + +def load_tests(loader, tests, ignore): + """ + This finds and runs all the doctests + """ + + modules = { + mod + for module in [sqlglot, sqlglot.transforms, sqlglot.optimizer] + for _, mod in inspect.getmembers(module, inspect.ismodule) + } + + assert len(modules) >= 20 + + for module in modules: + tests.addTests(doctest.DocTestSuite(module)) + + return tests + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_executor.py b/tests/test_executor.py new file mode 100644 index 0000000..9afa225 --- /dev/null +++ b/tests/test_executor.py @@ -0,0 +1,72 @@ +import unittest + +import duckdb +import pandas as pd +from pandas.testing import assert_frame_equal + +from sqlglot import exp, parse_one +from sqlglot.executor import execute +from sqlglot.executor.python import Python +from tests.helpers import FIXTURES_DIR, TPCH_SCHEMA, load_sql_fixture_pairs + +DIR = FIXTURES_DIR + "/optimizer/tpc-h/" + + +class TestExecutor(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.conn = duckdb.connect() + + for table in TPCH_SCHEMA: + cls.conn.execute( + f""" + CREATE VIEW {table} AS + SELECT * + FROM READ_CSV_AUTO('{DIR}{table}.csv.gz') + """ + ) + + cls.cache = {} + cls.sqls = [ + (sql, expected) + for _, sql, expected in load_sql_fixture_pairs("optimizer/tpc-h/tpc-h.sql") + ] + + @classmethod + def tearDownClass(cls): + cls.conn.close() + + def cached_execute(self, sql): + if sql not in self.cache: + self.cache[sql] = self.conn.execute(sql).fetchdf() + return self.cache[sql] + + def rename_anonymous(self, source, target): + 
for i, column in enumerate(source.columns): + if "_col_" in column: + source.rename(columns={column: target.columns[i]}, inplace=True) + + def test_py_dialect(self): + self.assertEqual(Python().generate(parse_one("'x '''")), r"'x \''") + + def test_optimized_tpch(self): + for sql, optimized in self.sqls[0:20]: + a = self.cached_execute(sql) + b = self.conn.execute(optimized).fetchdf() + self.rename_anonymous(b, a) + assert_frame_equal(a, b) + + def test_execute_tpch(self): + def to_csv(expression): + if isinstance(expression, exp.Table): + return parse_one( + f"READ_CSV('{DIR}{expression.name}.csv.gz', 'delimiter', '|') AS {expression.name}" + ) + return expression + + for sql, _ in self.sqls[0:3]: + a = self.cached_execute(sql) + sql = parse_one(sql).transform(to_csv).sql(pretty=True) + table = execute(sql, TPCH_SCHEMA) + b = pd.DataFrame(table.rows, columns=table.columns) + assert_frame_equal(a, b, check_dtype=False) diff --git a/tests/test_expressions.py b/tests/test_expressions.py new file mode 100644 index 0000000..eaef022 --- /dev/null +++ b/tests/test_expressions.py @@ -0,0 +1,415 @@ +import unittest + +from sqlglot import alias, exp, parse_one + + +class TestExpressions(unittest.TestCase): + def test_arg_key(self): + self.assertEqual(parse_one("sum(1)").find(exp.Literal).arg_key, "this") + + def test_depth(self): + self.assertEqual(parse_one("x(1)").find(exp.Literal).depth, 1) + + def test_eq(self): + self.assertEqual(parse_one("`a`", read="hive"), parse_one('"a"')) + self.assertEqual(parse_one("`a`", read="hive"), parse_one('"a" ')) + self.assertEqual(parse_one("`a`.b", read="hive"), parse_one('"a"."b"')) + self.assertEqual(parse_one("select a, b+1"), parse_one("SELECT a, b + 1")) + self.assertEqual(parse_one("`a`.`b`.`c`", read="hive"), parse_one("a.b.c")) + self.assertNotEqual(parse_one("a.b.c.d", read="hive"), parse_one("a.b.c")) + self.assertEqual(parse_one("a.b.c.d", read="hive"), parse_one("a.b.c.d")) + self.assertEqual(parse_one("a + b * c - 1.0"), parse_one("a+b*c-1.0")) + self.assertNotEqual(parse_one("a + b * c - 1.0"), parse_one("a + b * c + 1.0")) + self.assertEqual(parse_one("a as b"), parse_one("a AS b")) + self.assertNotEqual(parse_one("a as b"), parse_one("a")) + self.assertEqual( + parse_one("ROW() OVER(Partition by y)"), + parse_one("ROW() OVER (partition BY y)"), + ) + self.assertEqual( + parse_one("TO_DATE(x)", read="hive"), parse_one("ts_or_ds_to_date(x)") + ) + + def test_find(self): + expression = parse_one("CREATE TABLE x STORED AS PARQUET AS SELECT * FROM y") + self.assertTrue(expression.find(exp.Create)) + self.assertFalse(expression.find(exp.Group)) + self.assertEqual( + [table.name for table in expression.find_all(exp.Table)], + ["x", "y"], + ) + + def test_find_all(self): + expression = parse_one( + """ + SELECT * + FROM ( + SELECT b.* + FROM a.b b + ) x + JOIN ( + SELECT c.foo + FROM a.c c + WHERE foo = 1 + ) y + ON x.c = y.foo + CROSS JOIN ( + SELECT * + FROM ( + SELECT d.bar + FROM d + ) nested + ) z + ON x.c = y.foo + """ + ) + + self.assertEqual( + [table.name for table in expression.find_all(exp.Table)], + ["b", "c", "d"], + ) + + expression = parse_one("select a + b + c + d") + + self.assertEqual( + [column.name for column in expression.find_all(exp.Column)], + ["d", "c", "a", "b"], + ) + self.assertEqual( + [column.name for column in expression.find_all(exp.Column, bfs=False)], + ["a", "b", "c", "d"], + ) + + def test_find_ancestor(self): + column = parse_one("select * from foo where (a + 1 > 2)").find(exp.Column) + 
self.assertIsInstance(column, exp.Column) + self.assertIsInstance(column.parent_select, exp.Select) + self.assertIsNone(column.find_ancestor(exp.Join)) + + def test_alias_or_name(self): + expression = parse_one( + "SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz" + ) + self.assertEqual( + [e.alias_or_name for e in expression.expressions], + ["a", "B", "e", "*", "zz", "z"], + ) + self.assertEqual( + [e.alias_or_name for e in expression.args["from"].expressions], + ["bar", "baz"], + ) + + expression = parse_one( + """ + WITH first AS (SELECT * FROM foo), + second AS (SELECT * FROM bar) + SELECT * FROM first, second, (SELECT * FROM baz) AS third + """ + ) + + self.assertEqual( + [e.alias_or_name for e in expression.args["with"].expressions], + ["first", "second"], + ) + + self.assertEqual( + [e.alias_or_name for e in expression.args["from"].expressions], + ["first", "second", "third"], + ) + + def test_named_selects(self): + expression = parse_one( + "SELECT a, b AS B, c + d AS e, *, 'zz', 'zz' AS z FROM foo as bar, baz" + ) + self.assertEqual(expression.named_selects, ["a", "B", "e", "*", "zz", "z"]) + + expression = parse_one( + """ + WITH first AS (SELECT * FROM foo) + SELECT foo.bar, foo.baz as bazz, SUM(x) FROM first + """ + ) + self.assertEqual(expression.named_selects, ["bar", "bazz"]) + + expression = parse_one( + """ + SELECT foo, bar FROM first + UNION SELECT "ss" as foo, bar FROM second + UNION ALL SELECT foo, bazz FROM third + """ + ) + self.assertEqual(expression.named_selects, ["foo", "bar"]) + + def test_selects(self): + expression = parse_one("SELECT FROM x") + self.assertEqual(expression.selects, []) + + expression = parse_one("SELECT a FROM x") + self.assertEqual([s.sql() for s in expression.selects], ["a"]) + + expression = parse_one("SELECT a, b FROM x") + self.assertEqual([s.sql() for s in expression.selects], ["a", "b"]) + + def test_alias_column_names(self): + expression = parse_one("SELECT * FROM (SELECT * FROM x) AS y") + subquery = expression.find(exp.Subquery) + self.assertEqual(subquery.alias_column_names, []) + + expression = parse_one("SELECT * FROM (SELECT * FROM x) AS y(a)") + subquery = expression.find(exp.Subquery) + self.assertEqual(subquery.alias_column_names, ["a"]) + + expression = parse_one("SELECT * FROM (SELECT * FROM x) AS y(a, b)") + subquery = expression.find(exp.Subquery) + self.assertEqual(subquery.alias_column_names, ["a", "b"]) + + expression = parse_one("WITH y AS (SELECT * FROM x) SELECT * FROM y") + cte = expression.find(exp.CTE) + self.assertEqual(cte.alias_column_names, []) + + expression = parse_one("WITH y(a, b) AS (SELECT * FROM x) SELECT * FROM y") + cte = expression.find(exp.CTE) + self.assertEqual(cte.alias_column_names, ["a", "b"]) + + def test_ctes(self): + expression = parse_one("SELECT a FROM x") + self.assertEqual(expression.ctes, []) + + expression = parse_one("WITH x AS (SELECT a FROM y) SELECT a FROM x") + self.assertEqual([s.sql() for s in expression.ctes], ["x AS (SELECT a FROM y)"]) + + def test_hash(self): + self.assertEqual( + { + parse_one("select a.b"), + parse_one("1+2"), + parse_one('"a".b'), + parse_one("a.b.c.d"), + }, + { + parse_one("select a.b"), + parse_one("1+2"), + parse_one('"a"."b"'), + parse_one("a.b.c.d"), + }, + ) + + def test_sql(self): + self.assertEqual(parse_one("x + y * 2").sql(), "x + y * 2") + self.assertEqual( + parse_one('select "x"').sql(dialect="hive", pretty=True), "SELECT\n `x`" + ) + self.assertEqual( + parse_one("X + y").sql(identify=True, normalize=True), '"x" + "y"' 
+ ) + self.assertEqual( + parse_one("SUM(X)").sql(identify=True, normalize=True), 'SUM("x")' + ) + + def test_transform_with_arguments(self): + expression = parse_one("a") + + def fun(node, alias_=True): + if alias_: + return parse_one("a AS a") + return node + + transformed_expression = expression.transform(fun) + self.assertEqual(transformed_expression.sql(dialect="presto"), "a AS a") + + transformed_expression_2 = expression.transform(fun, alias_=False) + self.assertEqual(transformed_expression_2.sql(dialect="presto"), "a") + + def test_transform_simple(self): + expression = parse_one("IF(a > 0, a, b)") + + def fun(node): + if isinstance(node, exp.Column) and node.name == "a": + return parse_one("c - 2") + return node + + actual_expression_1 = expression.transform(fun) + self.assertEqual( + actual_expression_1.sql(dialect="presto"), "IF(c - 2 > 0, c - 2, b)" + ) + self.assertIsNot(actual_expression_1, expression) + + actual_expression_2 = expression.transform(fun, copy=False) + self.assertEqual( + actual_expression_2.sql(dialect="presto"), "IF(c - 2 > 0, c - 2, b)" + ) + self.assertIs(actual_expression_2, expression) + + with self.assertRaises(ValueError): + parse_one("a").transform(lambda n: None) + + def test_transform_no_infinite_recursion(self): + expression = parse_one("a") + + def fun(node): + if isinstance(node, exp.Column) and node.name == "a": + return parse_one("FUN(a)") + return node + + self.assertEqual(expression.transform(fun).sql(), "FUN(a)") + + def test_transform_multiple_children(self): + expression = parse_one("SELECT * FROM x") + + def fun(node): + if isinstance(node, exp.Star): + return [parse_one(c) for c in ["a", "b"]] + return node + + self.assertEqual(expression.transform(fun).sql(), "SELECT a, b FROM x") + + def test_replace(self): + expression = parse_one("SELECT a, b FROM x") + expression.find(exp.Column).replace(parse_one("c")) + self.assertEqual(expression.sql(), "SELECT c, b FROM x") + expression.find(exp.Table).replace(parse_one("y")) + self.assertEqual(expression.sql(), "SELECT c, b FROM y") + + def test_walk(self): + expression = parse_one("SELECT * FROM (SELECT * FROM x)") + self.assertEqual(len(list(expression.walk())), 9) + self.assertEqual(len(list(expression.walk(bfs=False))), 9) + self.assertTrue( + all(isinstance(e, exp.Expression) for e, _, _ in expression.walk()) + ) + self.assertTrue( + all(isinstance(e, exp.Expression) for e, _, _ in expression.walk(bfs=False)) + ) + + def test_functions(self): + self.assertIsInstance(parse_one("ABS(a)"), exp.Abs) + self.assertIsInstance(parse_one("APPROX_DISTINCT(a)"), exp.ApproxDistinct) + self.assertIsInstance(parse_one("ARRAY(a)"), exp.Array) + self.assertIsInstance(parse_one("ARRAY_AGG(a)"), exp.ArrayAgg) + self.assertIsInstance(parse_one("ARRAY_CONTAINS(a, 'a')"), exp.ArrayContains) + self.assertIsInstance(parse_one("ARRAY_SIZE(a)"), exp.ArraySize) + self.assertIsInstance(parse_one("AVG(a)"), exp.Avg) + self.assertIsInstance(parse_one("CEIL(a)"), exp.Ceil) + self.assertIsInstance(parse_one("CEILING(a)"), exp.Ceil) + self.assertIsInstance(parse_one("COALESCE(a, b)"), exp.Coalesce) + self.assertIsInstance(parse_one("COUNT(a)"), exp.Count) + self.assertIsInstance(parse_one("DATE_ADD(a, 1)"), exp.DateAdd) + self.assertIsInstance(parse_one("DATE_DIFF(a, 2)"), exp.DateDiff) + self.assertIsInstance(parse_one("DATE_STR_TO_DATE(a)"), exp.DateStrToDate) + self.assertIsInstance(parse_one("DAY(a)"), exp.Day) + self.assertIsInstance(parse_one("EXP(a)"), exp.Exp) + self.assertIsInstance(parse_one("FLOOR(a)"), 
exp.Floor) + self.assertIsInstance(parse_one("GREATEST(a, b)"), exp.Greatest) + self.assertIsInstance(parse_one("IF(a, b, c)"), exp.If) + self.assertIsInstance(parse_one("INITCAP(a)"), exp.Initcap) + self.assertIsInstance(parse_one("JSON_EXTRACT(a, '$.name')"), exp.JSONExtract) + self.assertIsInstance( + parse_one("JSON_EXTRACT_SCALAR(a, '$.name')"), exp.JSONExtractScalar + ) + self.assertIsInstance(parse_one("LEAST(a, b)"), exp.Least) + self.assertIsInstance(parse_one("LN(a)"), exp.Ln) + self.assertIsInstance(parse_one("LOG10(a)"), exp.Log10) + self.assertIsInstance(parse_one("MAX(a)"), exp.Max) + self.assertIsInstance(parse_one("MIN(a)"), exp.Min) + self.assertIsInstance(parse_one("MONTH(a)"), exp.Month) + self.assertIsInstance(parse_one("POW(a, 2)"), exp.Pow) + self.assertIsInstance(parse_one("POWER(a, 2)"), exp.Pow) + self.assertIsInstance(parse_one("QUANTILE(a, 0.90)"), exp.Quantile) + self.assertIsInstance(parse_one("REGEXP_LIKE(a, 'test')"), exp.RegexpLike) + self.assertIsInstance(parse_one("REGEXP_SPLIT(a, 'test')"), exp.RegexpSplit) + self.assertIsInstance(parse_one("ROUND(a)"), exp.Round) + self.assertIsInstance(parse_one("ROUND(a, 2)"), exp.Round) + self.assertIsInstance(parse_one("SPLIT(a, 'test')"), exp.Split) + self.assertIsInstance(parse_one("STR_POSITION(a, 'test')"), exp.StrPosition) + self.assertIsInstance(parse_one("STR_TO_UNIX(a, 'format')"), exp.StrToUnix) + self.assertIsInstance(parse_one("STRUCT_EXTRACT(a, 'test')"), exp.StructExtract) + self.assertIsInstance(parse_one("SUM(a)"), exp.Sum) + self.assertIsInstance(parse_one("SQRT(a)"), exp.Sqrt) + self.assertIsInstance(parse_one("STDDEV(a)"), exp.Stddev) + self.assertIsInstance(parse_one("STDDEV_POP(a)"), exp.StddevPop) + self.assertIsInstance(parse_one("STDDEV_SAMP(a)"), exp.StddevSamp) + self.assertIsInstance(parse_one("TIME_TO_STR(a, 'format')"), exp.TimeToStr) + self.assertIsInstance(parse_one("TIME_TO_TIME_STR(a)"), exp.Cast) + self.assertIsInstance(parse_one("TIME_TO_UNIX(a)"), exp.TimeToUnix) + self.assertIsInstance(parse_one("TIME_STR_TO_DATE(a)"), exp.TimeStrToDate) + self.assertIsInstance(parse_one("TIME_STR_TO_TIME(a)"), exp.TimeStrToTime) + self.assertIsInstance(parse_one("TIME_STR_TO_UNIX(a)"), exp.TimeStrToUnix) + self.assertIsInstance(parse_one("TS_OR_DS_ADD(a, 1, 'day')"), exp.TsOrDsAdd) + self.assertIsInstance(parse_one("TS_OR_DS_TO_DATE(a)"), exp.TsOrDsToDate) + self.assertIsInstance(parse_one("TS_OR_DS_TO_DATE_STR(a)"), exp.Substring) + self.assertIsInstance(parse_one("UNIX_TO_STR(a, 'format')"), exp.UnixToStr) + self.assertIsInstance(parse_one("UNIX_TO_TIME(a)"), exp.UnixToTime) + self.assertIsInstance(parse_one("UNIX_TO_TIME_STR(a)"), exp.UnixToTimeStr) + self.assertIsInstance(parse_one("VARIANCE(a)"), exp.Variance) + self.assertIsInstance(parse_one("VARIANCE_POP(a)"), exp.VariancePop) + self.assertIsInstance(parse_one("YEAR(a)"), exp.Year) + + def test_column(self): + dot = parse_one("a.b.c") + column = dot.this + self.assertEqual(column.table, "a") + self.assertEqual(column.name, "b") + self.assertEqual(dot.text("expression"), "c") + + column = parse_one("a") + self.assertEqual(column.name, "a") + self.assertEqual(column.table, "") + + fields = parse_one("a.b.c.d") + self.assertIsInstance(fields, exp.Dot) + self.assertEqual(fields.text("expression"), "d") + self.assertEqual(fields.this.text("expression"), "c") + column = fields.find(exp.Column) + self.assertEqual(column.name, "b") + self.assertEqual(column.table, "a") + + column = parse_one("a[0].b") + self.assertIsInstance(column, exp.Dot) + 
self.assertIsInstance(column.this, exp.Bracket) + self.assertIsInstance(column.this.this, exp.Column) + + column = parse_one("a.*") + self.assertIsInstance(column, exp.Column) + self.assertIsInstance(column.this, exp.Star) + self.assertIsInstance(column.args["table"], exp.Identifier) + self.assertEqual(column.table, "a") + + self.assertIsInstance(parse_one("*"), exp.Star) + + def test_text(self): + column = parse_one("a.b.c") + self.assertEqual(column.text("expression"), "c") + self.assertEqual(column.text("y"), "") + self.assertEqual(parse_one("select * from x.y").find(exp.Table).text("db"), "x") + self.assertEqual(parse_one("select *").text("this"), "") + self.assertEqual(parse_one("1 + 1").text("this"), "1") + self.assertEqual(parse_one("'a'").text("this"), "a") + + def test_alias(self): + self.assertEqual(alias("foo", "bar").sql(), "foo AS bar") + self.assertEqual(alias("foo", "bar-1").sql(), 'foo AS "bar-1"') + self.assertEqual(alias("foo", "bar_1").sql(), "foo AS bar_1") + self.assertEqual(alias("foo * 2", "2bar").sql(), 'foo * 2 AS "2bar"') + self.assertEqual(alias('"foo"', "_bar").sql(), '"foo" AS "_bar"') + self.assertEqual(alias("foo", "bar", quoted=True).sql(), 'foo AS "bar"') + + def test_unit(self): + unit = parse_one("timestamp_trunc(current_timestamp, week(thursday))") + self.assertIsNotNone(unit.find(exp.CurrentTimestamp)) + week = unit.find(exp.Week) + self.assertEqual(week.this, exp.Var(this="thursday")) + + def test_identifier(self): + self.assertTrue(exp.to_identifier('"x"').quoted) + self.assertFalse(exp.to_identifier("x").quoted) + + def test_function_normalizer(self): + self.assertEqual( + parse_one("HELLO()").sql(normalize_functions="lower"), "hello()" + ) + self.assertEqual( + parse_one("hello()").sql(normalize_functions="upper"), "HELLO()" + ) + self.assertEqual(parse_one("heLLO()").sql(normalize_functions=None), "heLLO()") + self.assertEqual(parse_one("SUM(x)").sql(normalize_functions="lower"), "sum(x)") + self.assertEqual(parse_one("sum(x)").sql(normalize_functions="upper"), "SUM(x)") diff --git a/tests/test_generator.py b/tests/test_generator.py new file mode 100644 index 0000000..d64a818 --- /dev/null +++ b/tests/test_generator.py @@ -0,0 +1,30 @@ +import unittest + +from sqlglot.expressions import Func +from sqlglot.parser import Parser +from sqlglot.tokens import Tokenizer + + +class TestGenerator(unittest.TestCase): + def test_fallback_function_sql(self): + class SpecialUDF(Func): + arg_types = {"a": True, "b": False} + + class NewParser(Parser): + FUNCTIONS = SpecialUDF.default_parser_mappings() + + tokens = Tokenizer().tokenize("SELECT SPECIAL_UDF(a) FROM x") + expression = NewParser().parse(tokens)[0] + self.assertEqual(expression.sql(), "SELECT SPECIAL_UDF(a) FROM x") + + def test_fallback_function_var_args_sql(self): + class SpecialUDF(Func): + arg_types = {"a": True, "expressions": False} + is_var_len_args = True + + class NewParser(Parser): + FUNCTIONS = SpecialUDF.default_parser_mappings() + + tokens = Tokenizer().tokenize("SELECT SPECIAL_UDF(a, b, c, d + 1) FROM x") + expression = NewParser().parse(tokens)[0] + self.assertEqual(expression.sql(), "SELECT SPECIAL_UDF(a, b, c, d + 1) FROM x") diff --git a/tests/test_helper.py b/tests/test_helper.py new file mode 100644 index 0000000..d37c03a --- /dev/null +++ b/tests/test_helper.py @@ -0,0 +1,31 @@ +import unittest + +from sqlglot.helper import tsort + + +class TestHelper(unittest.TestCase): + def test_tsort(self): + self.assertEqual(tsort({"a": []}), ["a"]) + self.assertEqual(tsort({"a": ["b", 
"b"]}), ["b", "a"]) + self.assertEqual(tsort({"a": ["b"]}), ["b", "a"]) + self.assertEqual(tsort({"a": ["c"], "b": [], "c": []}), ["c", "a", "b"]) + self.assertEqual( + tsort( + { + "a": ["b", "c"], + "b": ["c"], + "c": [], + "d": ["a"], + } + ), + ["c", "b", "a", "d"], + ) + + with self.assertRaises(ValueError): + tsort( + { + "a": ["b", "c"], + "b": ["a"], + "c": [], + } + ) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py new file mode 100644 index 0000000..40540b3 --- /dev/null +++ b/tests/test_optimizer.py @@ -0,0 +1,276 @@ +import unittest + +from sqlglot import optimizer, parse_one, table +from sqlglot.errors import OptimizeError +from sqlglot.optimizer.schema import MappingSchema, ensure_schema +from sqlglot.optimizer.scope import traverse_scope +from tests.helpers import TPCH_SCHEMA, load_sql_fixture_pairs, load_sql_fixtures + + +class TestOptimizer(unittest.TestCase): + maxDiff = None + + def setUp(self): + self.schema = { + "x": { + "a": "INT", + "b": "INT", + }, + "y": { + "b": "INT", + "c": "INT", + }, + "z": { + "b": "INT", + "c": "INT", + }, + } + + def check_file(self, file, func, pretty=False, **kwargs): + for meta, sql, expected in load_sql_fixture_pairs(f"optimizer/{file}.sql"): + dialect = meta.get("dialect") + with self.subTest(sql): + self.assertEqual( + func(parse_one(sql, read=dialect), **kwargs).sql( + pretty=pretty, dialect=dialect + ), + expected, + ) + + def test_optimize(self): + schema = { + "x": {"a": "INT", "b": "INT"}, + "y": {"a": "INT", "b": "INT"}, + "z": {"a": "INT", "c": "INT"}, + } + + self.check_file("optimizer", optimizer.optimize, pretty=True, schema=schema) + + def test_isolate_table_selects(self): + self.check_file( + "isolate_table_selects", + optimizer.isolate_table_selects.isolate_table_selects, + ) + + def test_qualify_tables(self): + self.check_file( + "qualify_tables", + optimizer.qualify_tables.qualify_tables, + db="db", + catalog="c", + ) + + def test_normalize(self): + self.assertEqual( + optimizer.normalize.normalize( + parse_one("x AND (y OR z)"), + dnf=True, + ).sql(), + "(x AND y) OR (x AND z)", + ) + + self.check_file( + "normalize", + optimizer.normalize.normalize, + ) + + def test_qualify_columns(self): + def qualify_columns(expression, **kwargs): + expression = optimizer.qualify_tables.qualify_tables(expression) + expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs) + return expression + + self.check_file("qualify_columns", qualify_columns, schema=self.schema) + + def test_qualify_columns__invalid(self): + for sql in load_sql_fixtures("optimizer/qualify_columns__invalid.sql"): + with self.subTest(sql): + with self.assertRaises(OptimizeError): + optimizer.qualify_columns.qualify_columns( + parse_one(sql), schema=self.schema + ) + + def test_quote_identities(self): + self.check_file("quote_identities", optimizer.quote_identities.quote_identities) + + def test_pushdown_projection(self): + def pushdown_projections(expression, **kwargs): + expression = optimizer.qualify_tables.qualify_tables(expression) + expression = optimizer.qualify_columns.qualify_columns(expression, **kwargs) + expression = optimizer.pushdown_projections.pushdown_projections(expression) + return expression + + self.check_file( + "pushdown_projections", pushdown_projections, schema=self.schema + ) + + def test_simplify(self): + self.check_file("simplify", optimizer.simplify.simplify) + + def test_unnest_subqueries(self): + self.check_file( + "unnest_subqueries", + optimizer.unnest_subqueries.unnest_subqueries, + pretty=True, + 
) + + def test_pushdown_predicates(self): + self.check_file( + "pushdown_predicates", optimizer.pushdown_predicates.pushdown_predicates + ) + + def test_expand_multi_table_selects(self): + self.check_file( + "expand_multi_table_selects", + optimizer.expand_multi_table_selects.expand_multi_table_selects, + ) + + def test_optimize_joins(self): + self.check_file( + "optimize_joins", + optimizer.optimize_joins.optimize_joins, + ) + + def test_eliminate_subqueries(self): + self.check_file( + "eliminate_subqueries", + optimizer.eliminate_subqueries.eliminate_subqueries, + pretty=True, + ) + + def test_tpch(self): + self.check_file( + "tpc-h/tpc-h", optimizer.optimize, schema=TPCH_SCHEMA, pretty=True + ) + + def test_schema(self): + schema = ensure_schema( + { + "x": { + "a": "uint64", + } + } + ) + self.assertEqual( + schema.column_names( + table( + "x", + ) + ), + ["a"], + ) + with self.assertRaises(ValueError): + schema.column_names(table("x", db="db", catalog="c")) + with self.assertRaises(ValueError): + schema.column_names(table("x", db="db")) + with self.assertRaises(ValueError): + schema.column_names(table("x2")) + + schema = ensure_schema( + { + "db": { + "x": { + "a": "uint64", + } + } + } + ) + self.assertEqual(schema.column_names(table("x", db="db")), ["a"]) + with self.assertRaises(ValueError): + schema.column_names(table("x", db="db", catalog="c")) + with self.assertRaises(ValueError): + schema.column_names(table("x")) + with self.assertRaises(ValueError): + schema.column_names(table("x", db="db2")) + with self.assertRaises(ValueError): + schema.column_names(table("x2", db="db")) + + schema = ensure_schema( + { + "c": { + "db": { + "x": { + "a": "uint64", + } + } + } + } + ) + self.assertEqual(schema.column_names(table("x", db="db", catalog="c")), ["a"]) + with self.assertRaises(ValueError): + schema.column_names(table("x", db="db")) + with self.assertRaises(ValueError): + schema.column_names(table("x")) + with self.assertRaises(ValueError): + schema.column_names(table("x", db="db", catalog="c2")) + with self.assertRaises(ValueError): + schema.column_names(table("x", db="db2")) + with self.assertRaises(ValueError): + schema.column_names(table("x2", db="db")) + + schema = ensure_schema( + MappingSchema( + { + "x": { + "a": "uint64", + } + } + ) + ) + self.assertEqual(schema.column_names(table("x")), ["a"]) + + with self.assertRaises(OptimizeError): + ensure_schema({}) + + def test_file_schema(self): + expression = parse_one( + """ + SELECT * + FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') + """ + ) + self.assertEqual( + """ +SELECT + "_q_0"."n_nationkey" AS "n_nationkey", + "_q_0"."n_name" AS "n_name", + "_q_0"."n_regionkey" AS "n_regionkey", + "_q_0"."n_comment" AS "n_comment" +FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') AS "_q_0" +""".strip(), + optimizer.optimize(expression).sql(pretty=True), + ) + + def test_scope(self): + sql = """ + WITH q AS ( + SELECT x.b FROM x + ), r AS ( + SELECT y.b FROM y + ) + SELECT + r.b, + s.b + FROM r + JOIN ( + SELECT y.c AS b FROM y + ) s + ON s.b = r.b + WHERE s.b > (SELECT MAX(x.a) FROM x WHERE x.b = s.b) + """ + scopes = traverse_scope(parse_one(sql)) + self.assertEqual(len(scopes), 5) + self.assertEqual(scopes[0].expression.sql(), "SELECT x.b FROM x") + self.assertEqual(scopes[1].expression.sql(), "SELECT y.b FROM y") + self.assertEqual( + scopes[2].expression.sql(), "SELECT MAX(x.a) FROM x WHERE x.b = s.b" + ) + self.assertEqual(scopes[3].expression.sql(), "SELECT y.c AS b 
FROM y") + self.assertEqual(scopes[4].expression.sql(), parse_one(sql).sql()) + + self.assertEqual(set(scopes[4].sources), {"q", "r", "s"}) + self.assertEqual(len(scopes[4].columns), 6) + self.assertEqual(set(c.table for c in scopes[4].columns), {"r", "s"}) + self.assertEqual(scopes[4].source_columns("q"), []) + self.assertEqual(len(scopes[4].source_columns("r")), 2) + self.assertEqual(set(c.table for c in scopes[4].source_columns("r")), {"r"}) diff --git a/tests/test_parser.py b/tests/test_parser.py new file mode 100644 index 0000000..779083d --- /dev/null +++ b/tests/test_parser.py @@ -0,0 +1,195 @@ +import unittest +from unittest.mock import patch + +from sqlglot import Parser, exp, parse, parse_one +from sqlglot.errors import ErrorLevel, ParseError +from tests.helpers import assert_logger_contains + + +class TestParser(unittest.TestCase): + def test_parse_empty(self): + self.assertIsNone(parse_one("")) + + def test_parse_into(self): + self.assertIsInstance(parse_one("left join foo", into=exp.Join), exp.Join) + self.assertIsInstance(parse_one("int", into=exp.DataType), exp.DataType) + self.assertIsInstance(parse_one("array<int>", into=exp.DataType), exp.DataType) + + def test_column(self): + columns = parse_one("select a, ARRAY[1] b, case when 1 then 1 end").find_all( + exp.Column + ) + assert len(list(columns)) == 1 + + self.assertIsNotNone(parse_one("date").find(exp.Column)) + + def test_table(self): + tables = [ + t.sql() for t in parse_one("select * from a, b.c, .d").find_all(exp.Table) + ] + self.assertEqual(tables, ["a", "b.c", "d"]) + + def test_select(self): + self.assertIsNotNone( + parse_one("select * from (select 1) x order by x.y").args["order"] + ) + self.assertIsNotNone( + parse_one("select * from x where a = (select 1) order by x.y").args["order"] + ) + self.assertEqual( + len(parse_one("select * from (select 1) x cross join y").args["joins"]), 1 + ) + + def test_command(self): + expressions = parse("SET x = 1; ADD JAR s3://a; SELECT 1") + self.assertEqual(len(expressions), 3) + self.assertEqual(expressions[0].sql(), "SET x = 1") + self.assertEqual(expressions[1].sql(), "ADD JAR s3://a") + self.assertEqual(expressions[2].sql(), "SELECT 1") + + def test_identify(self): + expression = parse_one( + """ + SELECT a, "b", c AS c, d AS "D", e AS "y|z'" + FROM y."z" + """ + ) + + assert expression.expressions[0].text("this") == "a" + assert expression.expressions[1].text("this") == "b" + assert expression.expressions[2].text("alias") == "c" + assert expression.expressions[3].text("alias") == "D" + assert expression.expressions[4].text("alias") == "y|z'" + table = expression.args["from"].expressions[0] + assert table.args["this"].args["this"] == "z" + assert table.args["db"].args["this"] == "y" + + def test_multi(self): + expressions = parse( + """ + SELECT * FROM a; SELECT * FROM b; + """ + ) + + assert len(expressions) == 2 + assert ( + expressions[0].args["from"].expressions[0].args["this"].args["this"] == "a" + ) + assert ( + expressions[1].args["from"].expressions[0].args["this"].args["this"] == "b" + ) + + def test_expression(self): + ignore = Parser(error_level=ErrorLevel.IGNORE) + self.assertIsInstance(ignore.expression(exp.Hint, expressions=[""]), exp.Hint) + self.assertIsInstance(ignore.expression(exp.Hint, y=""), exp.Hint) + self.assertIsInstance(ignore.expression(exp.Hint), exp.Hint) + + default = Parser() + self.assertIsInstance(default.expression(exp.Hint, expressions=[""]), exp.Hint) + default.expression(exp.Hint, y="") + default.expression(exp.Hint) + 
self.assertEqual(len(default.errors), 3) + + warn = Parser(error_level=ErrorLevel.WARN) + warn.expression(exp.Hint, y="") + self.assertEqual(len(warn.errors), 2) + + def test_parse_errors(self): + with self.assertRaises(ParseError): + parse_one("IF(a > 0, a, b, c)") + + with self.assertRaises(ParseError): + parse_one("IF(a > 0)") + + with self.assertRaises(ParseError): + parse_one("WITH cte AS (SELECT * FROM x)") + + def test_space(self): + self.assertEqual( + parse_one("SELECT ROW() OVER(PARTITION BY x) FROM x GROUP BY y").sql(), + "SELECT ROW() OVER (PARTITION BY x) FROM x GROUP BY y", + ) + + self.assertEqual( + parse_one( + """SELECT * FROM x GROUP + BY y""" + ).sql(), + "SELECT * FROM x GROUP BY y", + ) + + def test_missing_by(self): + with self.assertRaises(ParseError): + parse_one("SELECT FROM x ORDER BY") + + def test_annotations(self): + expression = parse_one( + """ + SELECT + a #annotation1, + b as B #annotation2:testing , + "test#annotation",c#annotation3, d #annotation4, + e #, + f # space + FROM foo + """ + ) + + assert expression.expressions[0].name == "annotation1" + assert expression.expressions[1].name == "annotation2:testing " + assert expression.expressions[2].name == "test#annotation" + assert expression.expressions[3].name == "annotation3" + assert expression.expressions[4].name == "annotation4" + assert expression.expressions[5].name == "" + assert expression.expressions[6].name == " space" + + def test_pretty_config_override(self): + self.assertEqual(parse_one("SELECT col FROM x").sql(), "SELECT col FROM x") + with patch("sqlglot.pretty", True): + self.assertEqual( + parse_one("SELECT col FROM x").sql(), "SELECT\n col\nFROM x" + ) + + self.assertEqual( + parse_one("SELECT col FROM x").sql(pretty=True), "SELECT\n col\nFROM x" + ) + + @patch("sqlglot.parser.logger") + def test_comment_error_n(self, logger): + parse_one( + """CREATE TABLE x +( +-- test +)""", + error_level=ErrorLevel.WARN, + ) + + assert_logger_contains( + "Required keyword: 'expressions' missing for <class 'sqlglot.expressions.Schema'>. Line 4, Col: 1.", + logger, + ) + + @patch("sqlglot.parser.logger") + def test_comment_error_r(self, logger): + parse_one( + """CREATE TABLE x (-- test\r)""", + error_level=ErrorLevel.WARN, + ) + + assert_logger_contains( + "Required keyword: 'expressions' missing for <class 'sqlglot.expressions.Schema'>. 
Line 2, Col: 1.", + logger, + ) + + @patch("sqlglot.parser.logger") + def test_create_table_error(self, logger): + parse_one( + """CREATE TABLE PARTITION""", + error_level=ErrorLevel.WARN, + ) + + assert_logger_contains( + "Expected table name", + logger, + ) diff --git a/tests/test_time.py b/tests/test_time.py new file mode 100644 index 0000000..17821c2 --- /dev/null +++ b/tests/test_time.py @@ -0,0 +1,14 @@ +import unittest + +from sqlglot.time import format_time + + +class TestTime(unittest.TestCase): + def test_format_time(self): + self.assertEqual(format_time("", {}), "") + self.assertEqual(format_time(" ", {}), " ") + mapping = {"a": "b", "aa": "c"} + self.assertEqual(format_time("a", mapping), "b") + self.assertEqual(format_time("aa", mapping), "c") + self.assertEqual(format_time("aaada", mapping), "cbdb") + self.assertEqual(format_time("da", mapping), "db") diff --git a/tests/test_transforms.py b/tests/test_transforms.py new file mode 100644 index 0000000..2030109 --- /dev/null +++ b/tests/test_transforms.py @@ -0,0 +1,16 @@ +import unittest + +from sqlglot import parse_one +from sqlglot.transforms import unalias_group + + +class TestTime(unittest.TestCase): + def validate(self, transform, sql, target): + self.assertEqual(parse_one(sql).transform(transform).sql(), target) + + def test_unalias_group(self): + self.validate( + unalias_group, + "SELECT a, b AS b, c AS c, 4 FROM x GROUP BY a, b, x.c, 4", + "SELECT a, b AS b, c AS c, 4 FROM x GROUP BY a, 2, x.c, 4", + ) diff --git a/tests/test_transpile.py b/tests/test_transpile.py new file mode 100644 index 0000000..28bcc7a --- /dev/null +++ b/tests/test_transpile.py @@ -0,0 +1,349 @@ +import os +import unittest +from unittest import mock + +from sqlglot import parse_one, transpile +from sqlglot.errors import ErrorLevel, ParseError, UnsupportedError +from tests.helpers import ( + assert_logger_contains, + load_sql_fixture_pairs, + load_sql_fixtures, +) + + +class TestTranspile(unittest.TestCase): + file_dir = os.path.dirname(__file__) + fixtures_dir = os.path.join(file_dir, "fixtures") + maxDiff = None + + def validate(self, sql, target, **kwargs): + self.assertEqual(transpile(sql, **kwargs)[0], target) + + def test_alias(self): + for key in ("union", "filter", "over", "from", "join"): + with self.subTest(f"alias {key}"): + self.validate(f"SELECT x AS {key}", f"SELECT x AS {key}") + self.validate(f'SELECT x "{key}"', f'SELECT x AS "{key}"') + + with self.assertRaises(ParseError): + self.validate(f"SELECT x {key}", "") + + def test_asc(self): + self.validate("SELECT x FROM y ORDER BY x ASC", "SELECT x FROM y ORDER BY x") + + def test_paren(self): + with self.assertRaises(ParseError): + transpile("1 + (2 + 3") + transpile("select f(") + + def test_some(self): + self.validate( + "SELECT * FROM x WHERE a = SOME (SELECT 1)", + "SELECT * FROM x WHERE a = ANY (SELECT 1)", + ) + + def test_space(self): + self.validate("SELECT MIN(3)>MIN(2)", "SELECT MIN(3) > MIN(2)") + self.validate("SELECT MIN(3)>=MIN(2)", "SELECT MIN(3) >= MIN(2)") + self.validate("SELECT 1>0", "SELECT 1 > 0") + self.validate("SELECT 3>=3", "SELECT 3 >= 3") + + def test_comments(self): + self.validate("SELECT 1 FROM foo -- comment", "SELECT 1 FROM foo") + self.validate("SELECT 1 /* inline */ FROM foo -- comment", "SELECT 1 FROM foo") + + self.validate( + """ + SELECT 1 -- comment + FROM foo -- comment + """, + "SELECT 1 FROM foo", + ) + + self.validate( + """ + SELECT 1 /* big comment + like this */ + FROM foo -- comment + """, + "SELECT 1 FROM foo", + ) + + def 
test_types(self): + self.validate("INT x", "CAST(x AS INT)") + self.validate("VARCHAR x y", "CAST(x AS VARCHAR) AS y") + self.validate("STRING x y", "CAST(x AS TEXT) AS y") + self.validate("x::INT", "CAST(x AS INT)") + self.validate("x::INTEGER", "CAST(x AS INT)") + self.validate("x::INT y", "CAST(x AS INT) AS y") + self.validate("x::INT AS y", "CAST(x AS INT) AS y") + self.validate("x::INT::BOOLEAN", "CAST(CAST(x AS INT) AS BOOLEAN)") + self.validate("CAST(x::INT AS BOOLEAN)", "CAST(CAST(x AS INT) AS BOOLEAN)") + self.validate("CAST(x AS INT)::BOOLEAN", "CAST(CAST(x AS INT) AS BOOLEAN)") + + with self.assertRaises(ParseError): + transpile("x::z") + + def test_not_range(self): + self.validate("a NOT LIKE b", "NOT a LIKE b") + self.validate("a NOT BETWEEN b AND c", "NOT a BETWEEN b AND c") + self.validate("a NOT IN (1, 2)", "NOT a IN (1, 2)") + self.validate("a IS NOT NULL", "NOT a IS NULL") + self.validate("a LIKE TEXT y", "a LIKE CAST(y AS TEXT)") + + def test_extract(self): + self.validate( + "EXTRACT(day FROM '2020-01-01'::TIMESTAMP)", + "EXTRACT(day FROM CAST('2020-01-01' AS TIMESTAMP))", + ) + self.validate( + "EXTRACT(timezone FROM '2020-01-01'::TIMESTAMP)", + "EXTRACT(timezone FROM CAST('2020-01-01' AS TIMESTAMP))", + ) + self.validate( + "EXTRACT(year FROM '2020-01-01'::TIMESTAMP WITH TIME ZONE)", + "EXTRACT(year FROM CAST('2020-01-01' AS TIMESTAMPTZ))", + ) + self.validate( + "extract(month from '2021-01-31'::timestamp without time zone)", + "EXTRACT(month FROM CAST('2021-01-31' AS TIMESTAMP))", + ) + + def test_if(self): + self.validate( + "SELECT IF(a > 1, 1, 0) FROM foo", + "SELECT CASE WHEN a > 1 THEN 1 ELSE 0 END FROM foo", + ) + self.validate( + "SELECT IF a > 1 THEN b END", + "SELECT CASE WHEN a > 1 THEN b END", + ) + self.validate( + "SELECT IF a > 1 THEN b ELSE c END", + "SELECT CASE WHEN a > 1 THEN b ELSE c END", + ) + self.validate( + "SELECT IF(a > 1, 1) FROM foo", "SELECT CASE WHEN a > 1 THEN 1 END FROM foo" + ) + + def test_ignore_nulls(self): + self.validate("SELECT COUNT(x RESPECT NULLS)", "SELECT COUNT(x)") + + def test_time(self): + self.validate("TIMESTAMP '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMP)") + self.validate( + "TIMESTAMP WITH TIME ZONE '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMPTZ)" + ) + self.validate( + "TIMESTAMP(9) WITH TIME ZONE '2020-01-01'", + "CAST('2020-01-01' AS TIMESTAMPTZ(9))", + ) + self.validate( + "TIMESTAMP WITHOUT TIME ZONE '2020-01-01'", + "CAST('2020-01-01' AS TIMESTAMP)", + ) + self.validate("'2020-01-01'::TIMESTAMP", "CAST('2020-01-01' AS TIMESTAMP)") + self.validate( + "'2020-01-01'::TIMESTAMP WITHOUT TIME ZONE", + "CAST('2020-01-01' AS TIMESTAMP)", + ) + self.validate( + "'2020-01-01'::TIMESTAMP WITH TIME ZONE", + "CAST('2020-01-01' AS TIMESTAMPTZ)", + ) + self.validate( + "timestamp with time zone '2025-11-20 00:00:00+00' AT TIME ZONE 'Africa/Cairo'", + "CAST('2025-11-20 00:00:00+00' AS TIMESTAMPTZ) AT TIME ZONE 'Africa/Cairo'", + ) + + self.validate("DATE '2020-01-01'", "CAST('2020-01-01' AS DATE)") + self.validate("'2020-01-01'::DATE", "CAST('2020-01-01' AS DATE)") + self.validate("STR_TO_TIME('x', 'y')", "STRPTIME('x', 'y')", write="duckdb") + self.validate( + "STR_TO_UNIX('x', 'y')", "EPOCH(STRPTIME('x', 'y'))", write="duckdb" + ) + self.validate("TIME_TO_STR(x, 'y')", "STRFTIME(x, 'y')", write="duckdb") + self.validate("TIME_TO_UNIX(x)", "EPOCH(x)", write="duckdb") + self.validate( + "UNIX_TO_STR(123, 'y')", + "STRFTIME(TO_TIMESTAMP(CAST(123 AS BIGINT)), 'y')", + write="duckdb", + ) + self.validate( + 
"UNIX_TO_TIME(123)", + "TO_TIMESTAMP(CAST(123 AS BIGINT))", + write="duckdb", + ) + + self.validate( + "STR_TO_TIME(x, 'y')", + "CAST(FROM_UNIXTIME(UNIX_TIMESTAMP(x, 'y')) AS TIMESTAMP)", + write="hive", + ) + self.validate( + "STR_TO_TIME(x, 'yyyy-MM-dd HH:mm:ss')", + "CAST(x AS TIMESTAMP)", + write="hive", + ) + self.validate( + "STR_TO_TIME(x, 'yyyy-MM-dd')", + "CAST(x AS TIMESTAMP)", + write="hive", + ) + + self.validate( + "STR_TO_UNIX('x', 'y')", + "UNIX_TIMESTAMP('x', 'y')", + write="hive", + ) + self.validate("TIME_TO_STR(x, 'y')", "DATE_FORMAT(x, 'y')", write="hive") + + self.validate("TIME_STR_TO_TIME(x)", "TIME_STR_TO_TIME(x)", write=None) + self.validate("TIME_STR_TO_UNIX(x)", "TIME_STR_TO_UNIX(x)", write=None) + self.validate("TIME_TO_TIME_STR(x)", "CAST(x AS TEXT)", write=None) + self.validate("TIME_TO_STR(x, 'y')", "TIME_TO_STR(x, 'y')", write=None) + self.validate("TIME_TO_UNIX(x)", "TIME_TO_UNIX(x)", write=None) + self.validate("UNIX_TO_STR(x, 'y')", "UNIX_TO_STR(x, 'y')", write=None) + self.validate("UNIX_TO_TIME(x)", "UNIX_TO_TIME(x)", write=None) + self.validate("UNIX_TO_TIME_STR(x)", "UNIX_TO_TIME_STR(x)", write=None) + self.validate("TIME_STR_TO_DATE(x)", "TIME_STR_TO_DATE(x)", write=None) + + self.validate("TIME_STR_TO_DATE(x)", "TO_DATE(x)", write="hive") + self.validate( + "UNIX_TO_STR(x, 'yyyy-MM-dd HH:mm:ss')", "FROM_UNIXTIME(x)", write="hive" + ) + self.validate( + "STR_TO_UNIX(x, 'yyyy-MM-dd HH:mm:ss')", "UNIX_TIMESTAMP(x)", write="hive" + ) + self.validate("IF(x > 1, x + 1)", "IF(x > 1, x + 1)", write="presto") + self.validate("IF(x > 1, 1 + 1)", "IF(x > 1, 1 + 1)", write="hive") + self.validate("IF(x > 1, 1, 0)", "IF(x > 1, 1, 0)", write="hive") + + self.validate( + "TIME_TO_UNIX(x)", + "UNIX_TIMESTAMP(x)", + write="hive", + ) + self.validate("UNIX_TO_STR(123, 'y')", "FROM_UNIXTIME(123, 'y')", write="hive") + self.validate( + "UNIX_TO_TIME(123)", + "FROM_UNIXTIME(123)", + write="hive", + ) + + self.validate("STR_TO_TIME('x', 'y')", "DATE_PARSE('x', 'y')", write="presto") + self.validate( + "STR_TO_UNIX('x', 'y')", "TO_UNIXTIME(DATE_PARSE('x', 'y'))", write="presto" + ) + self.validate("TIME_TO_STR(x, 'y')", "DATE_FORMAT(x, 'y')", write="presto") + self.validate("TIME_TO_UNIX(x)", "TO_UNIXTIME(x)", write="presto") + self.validate( + "UNIX_TO_STR(123, 'y')", + "DATE_FORMAT(FROM_UNIXTIME(123), 'y')", + write="presto", + ) + self.validate("UNIX_TO_TIME(123)", "FROM_UNIXTIME(123)", write="presto") + + self.validate("STR_TO_TIME('x', 'y')", "TO_TIMESTAMP('x', 'y')", write="spark") + self.validate( + "STR_TO_UNIX('x', 'y')", "UNIX_TIMESTAMP('x', 'y')", write="spark" + ) + self.validate("TIME_TO_STR(x, 'y')", "DATE_FORMAT(x, 'y')", write="spark") + + self.validate( + "TIME_TO_UNIX(x)", + "UNIX_TIMESTAMP(x)", + write="spark", + ) + self.validate("UNIX_TO_STR(123, 'y')", "FROM_UNIXTIME(123, 'y')", write="spark") + self.validate( + "UNIX_TO_TIME(123)", + "FROM_UNIXTIME(123)", + write="spark", + ) + self.validate( + "CREATE TEMPORARY TABLE test AS SELECT 1", + "CREATE TEMPORARY VIEW test AS SELECT 1", + write="spark", + ) + + @mock.patch("sqlglot.helper.logger") + def test_index_offset(self, mock_logger): + self.validate("x[0]", "x[1]", write="presto", identity=False) + self.validate("x[1]", "x[0]", read="presto", identity=False) + mock_logger.warning.assert_any_call("Applying array index offset (%s)", 1) + mock_logger.warning.assert_any_call("Applying array index offset (%s)", -1) + + def test_identity(self): + self.assertEqual(transpile("")[0], "") + for sql in 
load_sql_fixtures("identity.sql"): + with self.subTest(sql): + self.assertEqual(transpile(sql)[0], sql.strip()) + + def test_partial(self): + for sql in load_sql_fixtures("partial.sql"): + with self.subTest(sql): + self.assertEqual( + transpile(sql, error_level=ErrorLevel.IGNORE)[0], sql.strip() + ) + + def test_pretty(self): + for _, sql, pretty in load_sql_fixture_pairs("pretty.sql"): + with self.subTest(sql[:100]): + generated = transpile(sql, pretty=True)[0] + self.assertEqual(generated, pretty) + self.assertEqual(parse_one(sql), parse_one(pretty)) + + @mock.patch("sqlglot.parser.logger") + def test_error_level(self, logger): + invalid = "x + 1. (" + errors = [ + "Required keyword: 'expressions' missing for <class 'sqlglot.expressions.Aliases'>. Line 1, Col: 8.\n x + 1. \033[4m(\033[0m", + "Expecting ). Line 1, Col: 8.\n x + 1. \033[4m(\033[0m", + ] + + transpile(invalid, error_level=ErrorLevel.WARN) + for error in errors: + assert_logger_contains(error, logger) + + with self.assertRaises(ParseError) as ctx: + transpile(invalid, error_level=ErrorLevel.IMMEDIATE) + self.assertEqual(str(ctx.exception), errors[0]) + + with self.assertRaises(ParseError) as ctx: + transpile(invalid, error_level=ErrorLevel.RAISE) + self.assertEqual(str(ctx.exception), "\n\n".join(errors)) + + more_than_max_errors = "((((" + expected = ( + "Expecting ). Line 1, Col: 4.\n (((\033[4m(\033[0m\n\n" + "Required keyword: 'this' missing for <class 'sqlglot.expressions.Paren'>. Line 1, Col: 4.\n (((\033[4m(\033[0m\n\n" + "Expecting ). Line 1, Col: 4.\n (((\033[4m(\033[0m\n\n" + "... and 2 more" + ) + with self.assertRaises(ParseError) as ctx: + transpile(more_than_max_errors, error_level=ErrorLevel.RAISE) + self.assertEqual(str(ctx.exception), expected) + + @mock.patch("sqlglot.generator.logger") + def test_unsupported_level(self, logger): + def unsupported(level): + transpile( + "SELECT MAP(a, b), MAP(a, b), MAP(a, b), MAP(a, b)", + read="presto", + write="hive", + unsupported_level=level, + ) + + error = "Cannot convert array columns into map use SparkSQL instead." + + unsupported(ErrorLevel.WARN) + assert_logger_contains("\n".join([error] * 4), logger, level="warning") + + with self.assertRaises(UnsupportedError) as ctx: + unsupported(ErrorLevel.RAISE) + self.assertEqual(str(ctx.exception).count(error), 3) + + with self.assertRaises(UnsupportedError) as ctx: + unsupported(ErrorLevel.IMMEDIATE) + self.assertEqual(str(ctx.exception).count(error), 1) |
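As a quick orientation for readers of this diff, the sketch below strings together a few of the public API calls exercised by tests/test_transpile.py. The inputs and expected outputs are copied from the assertions above; the snippet itself is an illustrative editorial addition, not part of the upstream 6.0.4 sources.

# Illustrative only: mirrors assertions from TestTranspile above.
from sqlglot import transpile
from sqlglot.errors import ErrorLevel

# Default dialect: IF(...) is rewritten to a CASE expression (see test_if).
assert (
    transpile("SELECT IF(a > 1, 1, 0) FROM foo")[0]
    == "SELECT CASE WHEN a > 1 THEN 1 ELSE 0 END FROM foo"
)

# Dialect-specific rendering, here targeting DuckDB (see test_time).
assert transpile("STR_TO_TIME('x', 'y')", write="duckdb")[0] == "STRPTIME('x', 'y')"

# Invalid SQL with ErrorLevel.WARN logs the parse errors instead of raising (see test_error_level).
transpile("x + 1. (", error_level=ErrorLevel.WARN)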