diff options
Diffstat (limited to '')
88 files changed, 1637 insertions, 436 deletions
diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 3b6fdc6..8cd3634 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -20,7 +20,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install -r dev-requirements.txt + make install-dev - name: Run checks (linter, code style, tests) run: | - ./run_checks.sh + make check @@ -130,3 +130,8 @@ dmypy.json # PyCharm .idea/ + +# Visual Studio Code +.vscode + +.DS_STORE diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..17ffb61 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,31 @@ +repos: + - repo: local + hooks: + - id: autoflake + name: autoflake + entry: autoflake -i -r + language: system + types: [ python ] + require_serial: true + files: ^(sqlglot/|tests/|setup.py) + - id: isort + name: isort + entry: isort + language: system + types: [ python ] + files: ^(sqlglot/|tests/|setup.py) + require_serial: true + - id: black + name: black + entry: black --line-length 100 + language: system + types: [ python ] + require_serial: true + files: ^(sqlglot/|tests/|setup.py) + - id: mypy + name: mypy + entry: mypy + language: system + types: [ python ] + files: ^(sqlglot/|tests/) + require_serial: true diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index 500bc70..0000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "python.linting.pylintEnabled": true -}
\ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 7dfca94..bf8699e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,37 @@ Changelog ========= +v10.4.0 +------ + +Changes: + +- Breaking: Removed the quote_identities optimizer rule. + +- New: ARRAYAGG, SUM, ARRAYANY support in the engine. SQLGlot is now able to execute all TPC-H queries. + +- Improvement: Transpile DATEDIFF to postgres. + +- Improvement: Right join pushdown fixes. + +- Improvement: Have Snowflake generate VALUES columns without quotes. + +- Improvement: Support NaN values in convert. + +- Improvement: Recursive CTE scope [fixes](https://github.com/tobymao/sqlglot/commit/bec36391d85152fa478222403d06beffa8d6ddfb). + + +v10.3.0 +------ + +Changes: + +- Breaking: Json ops changed to binary expressions. + +- New: Jinja tokenization. + +- Improvement: More robust type inference. + v10.2.0 ------ diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..2da2493 --- /dev/null +++ b/Makefile @@ -0,0 +1,24 @@ +.PHONY: install install-dev install-pre-commit test style check docs docs-serve + +install: + pip install -e . + +install-dev: + pip install -e ".[dev]" + +install-pre-commit: + pre-commit install + +test: + python -m unittest + +style: + pre-commit run --all-files + +check: style test + +docs: + pdoc/cli.py -o pdoc/docs + +docs-serve: + pdoc/cli.py @@ -1,8 +1,8 @@ # SQLGlot -SQLGlot is a no dependency Python SQL parser, transpiler, and optimizer. It can be used to format SQL or translate between different dialects like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), [Snowflake](https://www.snowflake.com/en/), and [BigQuery](https://cloud.google.com/bigquery/). It aims to read a wide variety of SQL inputs and output syntactically correct SQL in the targeted dialects. +SQLGlot is a no dependency Python SQL parser, transpiler, optimizer, and engine. It can be used to format SQL or translate between different dialects like [DuckDB](https://duckdb.org/), [Presto](https://prestodb.io/), [Spark](https://spark.apache.org/), [Snowflake](https://www.snowflake.com/en/), and [BigQuery](https://cloud.google.com/bigquery/). It aims to read a wide variety of SQL inputs and output syntactically correct SQL in the targeted dialects. -It is a very comprehensive generic SQL parser with a robust [test suite](tests). It is also quite [performant](#benchmarks) while being written purely in Python. +It is a very comprehensive generic SQL parser with a robust [test suite](https://github.com/tobymao/sqlglot/blob/main/tests/). It is also quite [performant](#benchmarks) while being written purely in Python. You can easily [customize](#custom-dialects) the parser, [analyze](#metadata) queries, traverse expression trees, and programmatically [build](#build-and-modify-sql) SQL. @@ -13,8 +13,7 @@ Contributions are very welcome in SQLGlot; read the [contribution guide](https:/ ## Table of Contents * [Install](#install) -* [Documentation](#documentation) -* [Run Tests and Lint](#run-tests-and-lint) +* [Get in Touch](#get-in-touch) * [Examples](#examples) * [Formatting and Transpiling](#formatting-and-transpiling) * [Metadata](#metadata) @@ -26,6 +25,8 @@ Contributions are very welcome in SQLGlot; read the [contribution guide](https:/ * [AST Diff](#ast-diff) * [Custom Dialects](#custom-dialects) * [SQL Execution](#sql-execution) +* [Documentation](#documentation) +* [Run Tests and Lint](#run-tests-and-lint) * [Benchmarks](#benchmarks) * [Optional Dependencies](#optional-dependencies) @@ -40,30 +41,17 @@ pip3 install sqlglot Or with a local checkout: ``` -pip3 install -e . +make install ``` Requirements for development (optional): ``` -pip3 install -r dev-requirements.txt -``` - -## Documentation - -SQLGlot uses [pdocs](https://pdoc.dev/) to serve its API documentation: - -``` -pdoc sqlglot --docformat google -``` - -## Run Tests and Lint - -``` -# set `SKIP_INTEGRATION=1` to skip integration tests -./run_checks.sh +make install-dev ``` +## Get in Touch +We'd love to hear from you. Join our community [Slack channel](https://join.slack.com/t/tobiko-data/shared_invite/zt-1ma66d79v-a4dbf4DUpLAQJ8ptQrJygg)! ## Examples @@ -163,16 +151,16 @@ from sqlglot import parse_one, exp # print all column references (a and b) for column in parse_one("SELECT a, b + 1 AS c FROM d").find_all(exp.Column): - print(column.alias_or_name) + print(column.alias_or_name) # find all projections in select statements (a and c) for select in parse_one("SELECT a, b + 1 AS c FROM d").find_all(exp.Select): - for projection in select.expressions: - print(projection.alias_or_name) + for projection in select.expressions: + print(projection.alias_or_name) # find all tables (x, y, z) for table in parse_one("SELECT * FROM x JOIN y JOIN z").find_all(exp.Table): - print(table.name) + print(table.name) ``` ### Parser Errors @@ -274,7 +262,7 @@ transformed_tree.sql() ### SQL Optimizer -SQLGlot can rewrite queries into an "optimized" form. It performs a variety of [techniques](sqlglot/optimizer/optimizer.py) to create a new canonical AST. This AST can be used to standardize queries or provide the foundations for implementing an actual engine. For example: +SQLGlot can rewrite queries into an "optimized" form. It performs a variety of [techniques](https://github.com/tobymao/sqlglot/blob/main/sqlglot/optimizer/optimizer.py) to create a new canonical AST. This AST can be used to standardize queries or provide the foundations for implementing an actual engine. For example: ```python import sqlglot @@ -292,7 +280,7 @@ print( ) ``` -``` +```sql SELECT ( "x"."A" OR "x"."B" OR "x"."C" @@ -351,9 +339,11 @@ diff(parse_one("SELECT a + b, c, d"), parse_one("SELECT c, a - b, d")) ] ``` +See also: [Semantic Diff for SQL](https://github.com/tobymao/sqlglot/blob/main/posts/sql_diff.md). + ### Custom Dialects -[Dialects](sqlglot/dialects) can be added by subclassing `Dialect`: +[Dialects](https://github.com/tobymao/sqlglot/tree/main/sqlglot/dialects) can be added by subclassing `Dialect`: ```python from sqlglot import exp @@ -391,7 +381,7 @@ class Custom(Dialect): print(Dialect["custom"]) ``` -```python +``` <class '__main__.Custom'> ``` @@ -442,9 +432,23 @@ user_id price 2 3.0 ``` +## Documentation + +SQLGlot uses [pdocs](https://pdoc.dev/) to serve its API documentation: + +``` +make docs-serve +``` + +## Run Tests and Lint + +``` +make check # Set SKIP_INTEGRATION=1 to skip integration tests +``` + ## Benchmarks -[Benchmarks](benchmarks) run on Python 3.10.5 in seconds. +[Benchmarks](https://github.com/tobymao/sqlglot/blob/main/benchmarks/bench.py) run on Python 3.10.5 in seconds. | Query | sqlglot | sqlfluff | sqltree | sqlparse | moz_sql_parser | sqloxide | | --------------- | --------------- | --------------- | --------------- | --------------- | --------------- | --------------- | diff --git a/dev-requirements.txt b/dev-requirements.txt deleted file mode 100644 index aa7d31f..0000000 --- a/dev-requirements.txt +++ /dev/null @@ -1,9 +0,0 @@ -autoflake -black -duckdb -isort -mypy -pandas -pyspark -python-dateutil -pdoc diff --git a/pdoc/cli.py b/pdoc/cli.py new file mode 100755 index 0000000..72a986d --- /dev/null +++ b/pdoc/cli.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python3 + +from importlib import import_module +from pathlib import Path +from unittest import mock + +from pdoc.__main__ import cli, parser + +# Need this import or else import_module doesn't work +import sqlglot + + +def mocked_import(*args, **kwargs): + """Return a MagicMock if import fails for any reason""" + try: + return import_module(*args, **kwargs) + except Exception: + mocked_module = mock.MagicMock() + mocked_module.__name__ = args[0] + return mocked_module + + +if __name__ == "__main__": + # Mock uninstalled dependencies so pdoc can still work + with mock.patch("importlib.import_module", side_effect=mocked_import): + opts = parser.parse_args() + opts.docformat = "google" + opts.modules = ["sqlglot"] + opts.footer_text = "Copyright (c) 2022 Toby Mao" + opts.template_directory = Path(__file__).parent.joinpath("templates").absolute() + opts.edit_url = ["sqlglot=https://github.com/tobymao/sqlglot/"] + + with mock.patch("pdoc.__main__.parser", **{"parse_args.return_value": opts}): + cli() diff --git a/pdoc/docs/expressions.md b/pdoc/docs/expressions.md new file mode 100644 index 0000000..c82674b --- /dev/null +++ b/pdoc/docs/expressions.md @@ -0,0 +1,41 @@ +# Expressions + +Every AST node in SQLGlot is represented by a subclass of `Expression`. Each such expression encapsulates any necessary context, such as its child expressions, their names, or arg keys, and whether each child expression is optional or not. + +Furthermore, the following attributes are common across all expressions: + +#### key + +A unique key for each class in the `Expression` hierarchy. This is useful for hashing and representing expressions as strings. + +#### args + +A dictionary used for mapping child arg keys, to the corresponding expressions. A value in this mapping is usually either a single or a list of `Expression` instances, but SQLGlot doesn't impose any constraints on the actual type of the value. + +#### arg_types + +A dictionary used for mapping arg keys to booleans that determine whether the corresponding expressions are optional or not. Consider the following example: + +```python +class Limit(Expression): + arg_types = {"this": False, "expression": True} + +``` + +Here, `Limit` declares that it expects to have one optional and one required child expression, which can be referenced through `this` and `expression`, respectively. The arg keys are generally arbitrary, but there are helper methods for keys like `this`, `expression` and `expressions` that abstract away dictionary lookups and related checks. For this reason, these keys are common throughout SQLGlot's codebase. + +#### parent + +A reference to the parent expression (may be `None`). + +#### arg_key + +The arg key an expression is associated with, i.e. the name its parent expression uses to refer to it. + +#### comments + +A list of comments that are associated with a given expression. This is used in order to preserve comments when transpiling SQL code. + +#### type + +The data type of an expression, as inferred by SQLGlot's optimizer. diff --git a/pdoc/templates/module.html.jinja2 b/pdoc/templates/module.html.jinja2 new file mode 100644 index 0000000..e37ae01 --- /dev/null +++ b/pdoc/templates/module.html.jinja2 @@ -0,0 +1,6 @@ +{% extends "default/module.html.jinja2" %} + +{% if module.docstring %} + {% macro module_name() %} + {% endmacro %} +{% endif %} diff --git a/posts/python_sql_engine.md b/posts/python_sql_engine.md new file mode 100644 index 0000000..1c74680 --- /dev/null +++ b/posts/python_sql_engine.md @@ -0,0 +1,208 @@ +# Writing a Python SQL engine from scratch +[Toby Mao](https://www.linkedin.com/in/toby-mao/) + +## Introduction + +When I first started writing SQLGlot in early 2021, my goal was just to translate SQL queries from SparkSQL to Presto and vice versa. However, over the last year and a half, I've ended up with a full-fledged SQL engine. SQLGlot can now parse and transpile between [18 SQL dialects](https://github.com/tobymao/sqlglot/blob/main/sqlglot/dialects/__init__.py) and can execute all 24 [TPC-H](https://www.tpc.org/tpch/) SQL queries. The parser and engine are all written from scratch using Python. + +This post will cover [why](#why) I went through the effort of creating a Python SQL engine and [how](#how) a simple query goes from a string to actually transforming data. The following steps are briefly summarized: + +* [Tokenizing](#tokenizing) +* [Parsing](#parsing) +* [Optimizing](#optimizing) +* [Planning](#planning) +* [Executing](#executing) + +## Why? +I started working on SQLGlot because of my work on the [experimentation and metrics platform](https://netflixtechblog.com/reimagining-experimentation-analysis-at-netflix-71356393af21) at Netflix, where I built tools that allowed data scientists to define and compute SQL-based metrics. Netflix relied on multiple engines to query data (Spark, Presto, and Druid), so my team built the metrics platform around [PyPika](https://github.com/kayak/pypika), a Python SQL query builder. This way, definitions could be reused across multiple engines. However, it became quickly apparent that writing python code to programatically generate SQL was challenging for data scientists, especially those with academic backgrounds, since they were mostly familiar with R and SQL. At the time, the only Python SQL parser was [sqlparse]([https://github.com/andialbrecht/sqlparse), which is not actually a parser but a tokenizer, so having users write raw SQL into the platform wasn't really an option. Some time later, I randomly stumbled across [Crafting Interpreters](https://craftinginterpreters.com/) and realized that I could use it as a guide towards creating my own SQL parser/transpiler. + +Why did I do this? Isn't a Python SQL engine going to be extremely slow? + +The main reason why I ended up building a SQL engine was...just for **entertainment**. It's been fun learning about all the things required to actually run a SQL query, and seeing it actually work is extremely rewarding. Before SQLGlot, I had zero experience with lexers, parsers, or compilers. + +In terms of practical use cases, I planned to use the Python SQL engine for unit testing SQL pipelines. Big data pipelines are tough to test because many of the engines are not open source and cannot be run locally. With SQLGlot, you can take a SQL query targeting a warehouse such as [Snowflake](https://www.snowflake.com/en/) and seamlessly run it in CI on mock Python data. It's easy to mock data and create arbitrary [UDFs](https://en.wikipedia.org/wiki/User-defined_function) because everything is just Python. Although the implementation is slow and unsuitable for large amounts of data (> 1 millon rows), there's very little overhead/startup and you can run queries on test data in a couple of milliseconds. + +Finally, the components that have been built to support execution can be used as a **foundation** for a faster engine. I'm inspired by what [Apache Calcite](https://github.com/apache/calcite) has done for the JVM world. Even though Python is commonly used for data, there hasn't been a Calcite for Python. So, you could say that SQLGlot aims to be that framework. For example, it wouldn't take much work to replace the Python execution engine with numpy/pandas/arrow to become a respectably-performing query engine. The implementation would be able to leverage the parser, optimizer, and logical planner, only needing to implement physical execution. There is a lot of work in the Python ecosystem around high performance vectorized computation, which I think could benefit from a pure Python-based [AST](https://en.wikipedia.org/wiki/Abstract_syntax_tree)/[plan](https://en.wikipedia.org/wiki/Query_plan). Parsing and planning doesn't have to be fast when the bottleneck of running queries is processing terabytes of data. So, having a Python-based ecosystem around SQL is beneficial given the ease of development in Python, despite not having bare metal performance. + +Parts of SQLGlot's toolkit are being used today by the following: + +* [Ibis](https://github.com/ibis-project/ibis): A Python library that provides a lightweight, universal interface for data wrangling. + - Uses the Python SQL expression builder and leverages the optimizer/planner to convert SQL into dataframe operations. +* [mysql-mimic](https://github.com/kelsin/mysql-mimic): Pure-Python implementation of the MySQL server wire protocol + - Parses / transforms SQL and executes INFORMATION_SCHEMA queries. +* [Quokka](https://github.com/marsupialtail/quokka): Push-based vectorized query engine + - Parse and optimizes SQL. +* [Splink](https://github.com/moj-analytical-services/splink): Fast, accurate and scalable probabilistic data linkage using your choice of SQL backend. + - Transpiles queries. + +## How? + +There are many steps involved with actually running a simple query like: + +```sql +SELECT + bar.a, + b + 1 AS b +FROM bar +JOIN baz + ON bar.a = baz.a +WHERE bar.a > 1 +``` + +In this post, I'll walk through all the steps SQLGlot takes to run this query over Python objects. + +## Tokenizing + +The first step is to convert the sql string into a list of tokens. SQLGlot's tokenizer is quite simple and can be found [here](https://github.com/tobymao/sqlglot/blob/main/sqlglot/tokens.py). In a while loop, it checks each character and either appends the character to the current token, or makes a new token. + +Running the SQLGlot tokenizer shows the output. + +![Tokenizer Output](python_sql_engine_images/tokenizer.png) + +Each keyword has been converted to a SQLGlot Token object. Each token has some metadata associated with it, like line/column information for error messages. Comments are also a part of the token, so that comments can be preserved. + +## Parsing + +Once a SQL statement is tokenized, we don't need to worry about white space and other formatting, so it's easier to work with. We can now convert the list of tokens into an AST. The SQLGlot [parser](https://github.com/tobymao/sqlglot/blob/main/sqlglot/parser.py) is a handwritten [recursive descent](https://en.wikipedia.org/wiki/Recursive_descent_parser) parser. + +Similar to the tokenizer, it consumes the tokens sequentially, but it instead uses a recursive algorithm. The tokens are converted into a single AST node that presents the SQL query. The SQLGlot parser was designed to support various dialects, so it contains many options for overriding parsing functionality. + +![Parser Output](python_sql_engine_images/parser.png) + +The AST is a generic representation of a given SQL query. Each dialect can override or implement its own generator, which can convert an AST object into syntatically-correct SQL. + +## Optimizing + +Once we have our AST, we can transform it into an equivalent query that produces the same results more efficiently. When optimizing queries, most engines first convert the AST into a logical plan and then optimize the plan. However, I chose to **optimize the AST directly** for the following reasons: + +1. It's easier to debug and [validate](https://github.com/tobymao/sqlglot/blob/main/tests/fixtures/optimizer) the optimizations when the input and output are both SQL. + +2. Rules can be applied a la carte to transform SQL into a more desireable form. + +3. I wanted a way to generate 'canonical sql'. Having a canonical representation of SQL is useful for understanding if two queries are semantically equivalent (e.g. `SELECT 1 + 1` and `SELECT 2`). + +I've yet to find another engine that takes this approach, but I'm quite happy with this decision. The optimizer currently does not perform any "physical optimizations" such as join reordering. Those are left to the execution layer, as additional statistics and information could become relevant. + +![Optimizer Output](python_sql_engine_images/optimizer.png) + +The optimizer currently has [17 rules](https://github.com/tobymao/sqlglot/tree/main/sqlglot/optimizer). Each of these rules is applied, transforming the AST in place. The combination of these rules creates "canonical" sql that can then be more easily converted into a logical plan and executed. + +Some example rules are: + +### qualify\_tables and qualify_columns +- Adds all db/catalog qualifiers to tables and forces an alias. +- Ensure each column is unambiguous and expand stars. + +```sql +SELECT * FROM x; + +SELECT "db"."x" AS "x"; +``` + +### simplify +Boolean and math simplification. Check out all the [test cases](https://github.com/tobymao/sqlglot/blob/main/tests/fixtures/optimizer/simplify.sql). + +```sql +((NOT FALSE) AND (x = x)) AND (TRUE OR 1 <> 3); +x = x; + +1 + 1; +2; +``` + +### normalize +Attempts to convert all predicates into [conjunctive normal form](https://en.wikipedia.org/wiki/Conjunctive_normal_form). + +```sql +-- DNF +(A AND B) OR (B AND C AND D); + +-- CNF +(A OR C) AND (A OR D) AND B; +``` + +### unnest\_subqueries +Converts subqueries in predicates into joins. + +```sql +-- The subquery can be converted into a left join +SELECT * +FROM x AS x +WHERE ( + SELECT y.a AS a + FROM y AS y + WHERE x.a = y.a +) = 1; + +SELECT * +FROM x AS x +LEFT JOIN ( + SELECT y.a AS a + FROM y AS y + WHERE TRUE + GROUP BY y.a +) AS "_u_0" + ON x.a = "_u_0".a +WHERE ("_u_0".a = 1 AND NOT "_u_0".a IS NULL) +``` + +### pushdown_predicates +Push down filters into the innermost query. +```sql +SELECT * +FROM ( + SELECT * + FROM x AS x +) AS y +WHERE y.a = 1; + +SELECT * +FROM ( + SELECT * + FROM x AS x + WHERE y.a = 1 +) AS y WHERE TRUE +``` + +### annotate_types +Infer all types throughout the AST given schema information and function type definitions. + +## Planning +After the SQL AST has been "optimized", it's much easier to [convert into a logical plan](https://github.com/tobymao/sqlglot/blob/main/sqlglot/planner.py). The AST is traversed and converted into a [DAG](https://en.wikipedia.org/wiki/Directed_acyclic_graph) consisting of one of five steps. The different steps are: + +### Scan +Selects columns from a table, applies projections, and finally filters the table. + +### Sort +Sorts a table for order by expressions. + +### Set +Applies the operators union/union all/except/intersect. + +### Aggregate +Applies an aggregation/group by. + +### Join +Joins multiple tables together. + +![Planner Output](python_sql_engine_images/planner.png) + +The logical plan is quite simple and contains the information required to convert it into a physical plan (execution). + +## Executing +Finally, we can actually execute the SQL query. The [Python engine](https://github.com/tobymao/sqlglot/blob/main/sqlglot/executor/python.py) is not fast, but it's very small (~400 LOC)! It iterates the DAG with a queue and runs each step, passing each intermediary table to the next step. + +In order to keep things simple, it evaluates expressions with `eval`. Because SQLGlot was built primarily to be a transpiler, it was simple to create a "Python SQL" dialect. So a SQL expression `x + 1` can just be converted into `scope['x'] + 1`. + +![Executor Output](python_sql_engine_images/executor.png) + +## What's next +SQLGlot's main focus will always be on parsing/transpiling, but I plan to continue development on the execution engine. I'd like to pass [TPC-DS](https://www.tpc.org/tpcds/). If someone doesn't beat me to it, I may even take a stab at writing a Pandas/Arrow execution engine. + +I'm hoping that over time, SQLGlot will spark the Python SQL ecosystem just like Calcite has for Java. + +## Special thanks +SQLGlot would not be what it is without it's core contributors. In particular, the execution engine would not exist without [Barak Alon](https://github.com/barakalon) and [George Sittas](https://github.com/GeorgeSittas). + +## Get in touch +If you'd like to chat more about SQLGlot, please join my [Slack Channel](https://join.slack.com/t/tobiko-data/shared_invite/zt-1ma66d79v-a4dbf4DUpLAQJ8ptQrJygg)! diff --git a/posts/python_sql_engine_images/executor.png b/posts/python_sql_engine_images/executor.png Binary files differnew file mode 100644 index 0000000..e6e2c9c --- /dev/null +++ b/posts/python_sql_engine_images/executor.png diff --git a/posts/python_sql_engine_images/optimizer.png b/posts/python_sql_engine_images/optimizer.png Binary files differnew file mode 100644 index 0000000..879d7b1 --- /dev/null +++ b/posts/python_sql_engine_images/optimizer.png diff --git a/posts/python_sql_engine_images/parser.png b/posts/python_sql_engine_images/parser.png Binary files differnew file mode 100644 index 0000000..15e3c3a --- /dev/null +++ b/posts/python_sql_engine_images/parser.png diff --git a/posts/python_sql_engine_images/planner.png b/posts/python_sql_engine_images/planner.png Binary files differnew file mode 100644 index 0000000..c4a44b7 --- /dev/null +++ b/posts/python_sql_engine_images/planner.png diff --git a/posts/python_sql_engine_images/tokenizer.png b/posts/python_sql_engine_images/tokenizer.png Binary files differnew file mode 100644 index 0000000..b1147df --- /dev/null +++ b/posts/python_sql_engine_images/tokenizer.png diff --git a/run_checks.sh b/run_checks.sh deleted file mode 100755 index 187e6b9..0000000 --- a/run_checks.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash -e -[[ -z "${GITHUB_ACTIONS}" ]] && RETURN_ERROR_CODE='' || RETURN_ERROR_CODE='--check' -TARGETS="sqlglot/ tests/" -python -m mypy $TARGETS -python -m autoflake -i -r ${RETURN_ERROR_CODE} $TARGETS -python -m isort $TARGETS -python -m black --line-length 100 ${RETURN_ERROR_CODE} $TARGETS -python -m unittest @@ -22,6 +22,20 @@ setup( license="MIT", packages=find_packages(include=["sqlglot", "sqlglot.*"]), package_data={"sqlglot": ["py.typed"]}, + extras_require={ + "dev": [ + "autoflake", + "black", + "duckdb", + "isort", + "mypy", + "pandas", + "pyspark", + "python-dateutil", + "pdoc", + "pre-commit", + ], + }, classifiers=[ "Development Status :: 5 - Production/Stable", "Intended Audience :: Developers", diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index e829517..04c3195 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -1,4 +1,6 @@ -"""## Python SQL parser, transpiler and optimizer.""" +""" +.. include:: ../README.md +""" from __future__ import annotations @@ -30,7 +32,7 @@ from sqlglot.parser import Parser from sqlglot.schema import MappingSchema from sqlglot.tokens import Tokenizer, TokenType -__version__ = "10.2.9" +__version__ = "10.4.2" pretty = False diff --git a/sqlglot/__main__.py b/sqlglot/__main__.py index 42a54bc..f9613b2 100644 --- a/sqlglot/__main__.py +++ b/sqlglot/__main__.py @@ -1,9 +1,15 @@ import argparse +import sys import sqlglot parser = argparse.ArgumentParser(description="Transpile SQL") -parser.add_argument("sql", metavar="sql", type=str, help="SQL string to transpile") +parser.add_argument( + "sql", + metavar="sql", + type=str, + help="SQL statement(s) to transpile, or - to parse stdin.", +) parser.add_argument( "--read", dest="read", @@ -48,14 +54,20 @@ parser.add_argument( args = parser.parse_args() error_level = sqlglot.ErrorLevel[args.error_level.upper()] +sql = sys.stdin.read() if args.sql == "-" else args.sql + if args.parse: sqls = [ repr(expression) - for expression in sqlglot.parse(args.sql, read=args.read, error_level=error_level) + for expression in sqlglot.parse( + sql, + read=args.read, + error_level=error_level, + ) ] else: sqls = sqlglot.transpile( - args.sql, + sql, read=args.read, write=args.write, identify=args.identify, diff --git a/sqlglot/dataframe/__init__.py b/sqlglot/dataframe/__init__.py index e69de29..a57e990 100644 --- a/sqlglot/dataframe/__init__.py +++ b/sqlglot/dataframe/__init__.py @@ -0,0 +1,3 @@ +""" +.. include:: ./README.md +""" diff --git a/sqlglot/dataframe/sql/_typing.pyi b/sqlglot/dataframe/sql/_typing.pyi index 67c8c09..1682ec1 100644 --- a/sqlglot/dataframe/sql/_typing.pyi +++ b/sqlglot/dataframe/sql/_typing.pyi @@ -9,18 +9,8 @@ if t.TYPE_CHECKING: from sqlglot.dataframe.sql.column import Column from sqlglot.dataframe.sql.types import StructType -ColumnLiterals = t.TypeVar( - "ColumnLiterals", - bound=t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime], -) -ColumnOrName = t.TypeVar("ColumnOrName", bound=t.Union[Column, str]) -ColumnOrLiteral = t.TypeVar( - "ColumnOrLiteral", - bound=t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime], -) -SchemaInput = t.TypeVar( - "SchemaInput", bound=t.Union[str, t.List[str], StructType, t.Dict[str, str]] -) -OutputExpressionContainer = t.TypeVar( - "OutputExpressionContainer", bound=t.Union[exp.Select, exp.Create, exp.Insert] -) +ColumnLiterals = t.Union[str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime] +ColumnOrName = t.Union[Column, str] +ColumnOrLiteral = t.Union[Column, str, float, int, bool, t.List, t.Tuple, datetime.date, datetime.datetime] +SchemaInput = t.Union[str, t.List[str], StructType, t.Dict[str, t.Optional[str]]] +OutputExpressionContainer = t.Union[exp.Select, exp.Create, exp.Insert] diff --git a/sqlglot/dataframe/sql/dataframe.py b/sqlglot/dataframe/sql/dataframe.py index 3c45741..a17bb9d 100644 --- a/sqlglot/dataframe/sql/dataframe.py +++ b/sqlglot/dataframe/sql/dataframe.py @@ -634,7 +634,7 @@ class DataFrame: all_columns = self._get_outer_select_columns(new_df.expression) all_column_mapping = {column.alias_or_name: column for column in all_columns} if isinstance(value, dict): - values = value.values() + values = list(value.values()) columns = self._ensure_and_normalize_cols(list(value)) if not columns: columns = self._ensure_and_normalize_cols(subset) if subset else all_columns diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 6be68ac..d10cc54 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -1,11 +1,15 @@ +"""Supports BigQuery Standard SQL.""" + from __future__ import annotations from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import ( Dialect, + datestrtodate_sql, inline_array_sql, no_ilike_sql, rename_func, + timestrtotime_sql, ) from sqlglot.helper import seq_get from sqlglot.tokens import TokenType @@ -120,13 +124,12 @@ class BigQuery(Dialect): "NOT DETERMINISTIC": TokenType.VOLATILE, "QUALIFY": TokenType.QUALIFY, "UNKNOWN": TokenType.NULL, - "WINDOW": TokenType.WINDOW, } KEYWORDS.pop("DIV") class Parser(parser.Parser): FUNCTIONS = { - **parser.Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, # type: ignore "DATE_TRUNC": _date_trunc, "DATE_ADD": _date_add(exp.DateAdd), "DATETIME_ADD": _date_add(exp.DatetimeAdd), @@ -144,31 +147,33 @@ class BigQuery(Dialect): } FUNCTION_PARSERS = { - **parser.Parser.FUNCTION_PARSERS, + **parser.Parser.FUNCTION_PARSERS, # type: ignore "ARRAY": lambda self: self.expression(exp.Array, expressions=[self._parse_statement()]), } FUNCTION_PARSERS.pop("TRIM") NO_PAREN_FUNCTIONS = { - **parser.Parser.NO_PAREN_FUNCTIONS, + **parser.Parser.NO_PAREN_FUNCTIONS, # type: ignore TokenType.CURRENT_DATETIME: exp.CurrentDatetime, TokenType.CURRENT_TIME: exp.CurrentTime, } NESTED_TYPE_TOKENS = { - *parser.Parser.NESTED_TYPE_TOKENS, + *parser.Parser.NESTED_TYPE_TOKENS, # type: ignore TokenType.TABLE, } class Generator(generator.Generator): TRANSFORMS = { - **generator.Generator.TRANSFORMS, + **generator.Generator.TRANSFORMS, # type: ignore exp.ArraySize: rename_func("ARRAY_LENGTH"), exp.DateAdd: _date_add_sql("DATE", "ADD"), exp.DateSub: _date_add_sql("DATE", "SUB"), exp.DatetimeAdd: _date_add_sql("DATETIME", "ADD"), exp.DatetimeSub: _date_add_sql("DATETIME", "SUB"), exp.DateDiff: lambda self, e: f"DATE_DIFF({self.sql(e, 'this')}, {self.sql(e, 'expression')}, {self.sql(e.args.get('unit', 'DAY'))})", + exp.DateStrToDate: datestrtodate_sql, + exp.GroupConcat: rename_func("STRING_AGG"), exp.ILike: no_ilike_sql, exp.IntDiv: rename_func("DIV"), exp.StrToTime: lambda self, e: f"PARSE_TIMESTAMP({self.format_time(e)}, {self.sql(e, 'this')})", @@ -176,6 +181,7 @@ class BigQuery(Dialect): exp.TimeSub: _date_add_sql("TIME", "SUB"), exp.TimestampAdd: _date_add_sql("TIMESTAMP", "ADD"), exp.TimestampSub: _date_add_sql("TIMESTAMP", "SUB"), + exp.TimeStrToTime: timestrtotime_sql, exp.VariancePop: rename_func("VAR_POP"), exp.Values: _derived_table_values_to_unnest, exp.ReturnsProperty: _returnsproperty_sql, @@ -188,7 +194,7 @@ class BigQuery(Dialect): } TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, + **generator.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.TINYINT: "INT64", exp.DataType.Type.SMALLINT: "INT64", exp.DataType.Type.INT: "INT64", diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index cbed72e..7136340 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -35,13 +35,13 @@ class ClickHouse(Dialect): class Parser(parser.Parser): FUNCTIONS = { - **parser.Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, # type: ignore "MAP": parse_var_map, } - JOIN_KINDS = {*parser.Parser.JOIN_KINDS, TokenType.ANY, TokenType.ASOF} + JOIN_KINDS = {*parser.Parser.JOIN_KINDS, TokenType.ANY, TokenType.ASOF} # type: ignore - TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY} + TABLE_ALIAS_TOKENS = {*parser.Parser.TABLE_ALIAS_TOKENS} - {TokenType.ANY} # type: ignore def _parse_table(self, schema=False): this = super()._parse_table(schema) @@ -55,7 +55,7 @@ class ClickHouse(Dialect): STRUCT_DELIMITER = ("(", ")") TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, + **generator.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.NULLABLE: "Nullable", exp.DataType.Type.DATETIME: "DateTime64", exp.DataType.Type.MAP: "Map", @@ -70,7 +70,7 @@ class ClickHouse(Dialect): } TRANSFORMS = { - **generator.Generator.TRANSFORMS, + **generator.Generator.TRANSFORMS, # type: ignore exp.Array: inline_array_sql, exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})", exp.Final: lambda self, e: f"{self.sql(e, 'this')} FINAL", diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index c87f8d8..e788852 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -198,7 +198,7 @@ class Dialect(metaclass=_Dialect): def rename_func(name): def _rename(self, expression): args = flatten(expression.args.values()) - return f"{name}({self.format_args(*args)})" + return f"{self.normalize_func(name)}({self.format_args(*args)})" return _rename @@ -217,11 +217,11 @@ def if_sql(self, expression): def arrow_json_extract_sql(self, expression): - return f"{self.sql(expression, 'this')}->{self.sql(expression, 'path')}" + return self.binary(expression, "->") def arrow_json_extract_scalar_sql(self, expression): - return f"{self.sql(expression, 'this')}->>{self.sql(expression, 'path')}" + return self.binary(expression, "->>") def inline_array_sql(self, expression): @@ -373,3 +373,11 @@ def strposition_to_local_sql(self, expression): expression.args.get("substr"), expression.this, expression.args.get("position") ) return f"LOCATE({args})" + + +def timestrtotime_sql(self, expression: exp.TimeStrToTime) -> str: + return f"CAST({self.sql(expression, 'this')} AS TIMESTAMP)" + + +def datestrtodate_sql(self, expression: exp.DateStrToDate) -> str: + return f"CAST({self.sql(expression, 'this')} AS DATE)" diff --git a/sqlglot/dialects/drill.py b/sqlglot/dialects/drill.py index 358eced..4e3c0e1 100644 --- a/sqlglot/dialects/drill.py +++ b/sqlglot/dialects/drill.py @@ -6,13 +6,14 @@ from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import ( Dialect, create_with_partitions_sql, + datestrtodate_sql, format_time_lambda, no_pivot_sql, no_trycast_sql, rename_func, str_position_sql, + timestrtotime_sql, ) -from sqlglot.dialects.postgres import _lateral_sql def _to_timestamp(args): @@ -117,14 +118,14 @@ class Drill(Dialect): STRICT_CAST = False FUNCTIONS = { - **parser.Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, # type: ignore "TO_TIMESTAMP": exp.TimeStrToTime.from_arg_list, "TO_CHAR": format_time_lambda(exp.TimeToStr, "drill"), } class Generator(generator.Generator): TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, + **generator.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.INT: "INTEGER", exp.DataType.Type.SMALLINT: "INTEGER", exp.DataType.Type.TINYINT: "INTEGER", @@ -139,14 +140,13 @@ class Drill(Dialect): ROOT_PROPERTIES = {exp.PartitionedByProperty} TRANSFORMS = { - **generator.Generator.TRANSFORMS, + **generator.Generator.TRANSFORMS, # type: ignore exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", - exp.Lateral: _lateral_sql, exp.ArrayContains: rename_func("REPEATED_CONTAINS"), exp.ArraySize: rename_func("REPEATED_COUNT"), exp.Create: create_with_partitions_sql, exp.DateAdd: _date_add_sql("ADD"), - exp.DateStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)", + exp.DateStrToDate: datestrtodate_sql, exp.DateSub: _date_add_sql("SUB"), exp.DateToDi: lambda self, e: f"CAST(TO_DATE({self.sql(e, 'this')}, {Drill.dateint_format}) AS INT)", exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS VARCHAR), {Drill.dateint_format})", @@ -160,7 +160,7 @@ class Drill(Dialect): exp.StrToDate: _str_to_date, exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)", - exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)", + exp.TimeStrToTime: timestrtotime_sql, exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"), exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"), diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index f1da72b..81941f7 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -6,6 +6,7 @@ from sqlglot.dialects.dialect import ( approx_count_distinct_sql, arrow_json_extract_scalar_sql, arrow_json_extract_sql, + datestrtodate_sql, format_time_lambda, no_pivot_sql, no_properties_sql, @@ -13,6 +14,7 @@ from sqlglot.dialects.dialect import ( no_tablesample_sql, rename_func, str_position_sql, + timestrtotime_sql, ) from sqlglot.helper import seq_get from sqlglot.tokens import TokenType @@ -83,11 +85,12 @@ class DuckDB(Dialect): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, ":=": TokenType.EQ, + "CHARACTER VARYING": TokenType.VARCHAR, } class Parser(parser.Parser): FUNCTIONS = { - **parser.Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, # type: ignore "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list, "ARRAY_LENGTH": exp.ArraySize.from_arg_list, "ARRAY_SORT": exp.SortArray.from_arg_list, @@ -119,16 +122,18 @@ class DuckDB(Dialect): STRUCT_DELIMITER = ("(", ")") TRANSFORMS = { - **generator.Generator.TRANSFORMS, + **generator.Generator.TRANSFORMS, # type: ignore exp.ApproxDistinct: approx_count_distinct_sql, - exp.Array: rename_func("LIST_VALUE"), + exp.Array: lambda self, e: f"{self.normalize_func('ARRAY')}({self.sql(e.expressions[0])})" + if isinstance(seq_get(e.expressions, 0), exp.Select) + else rename_func("LIST_VALUE")(self, e), exp.ArraySize: rename_func("ARRAY_LENGTH"), exp.ArraySort: _array_sort_sql, exp.ArraySum: rename_func("LIST_SUM"), exp.DataType: _datatype_sql, exp.DateAdd: _date_add, exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.format_args(e.args.get("unit") or "'day'", e.expression, e.this)})""", - exp.DateStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)", + exp.DateStrToDate: datestrtodate_sql, exp.DateToDi: lambda self, e: f"CAST(STRFTIME({self.sql(e, 'this')}, {DuckDB.dateint_format}) AS INT)", exp.DiToDate: lambda self, e: f"CAST(STRPTIME(CAST({self.sql(e, 'this')} AS TEXT), {DuckDB.dateint_format}) AS DATE)", exp.Explode: rename_func("UNNEST"), @@ -136,6 +141,7 @@ class DuckDB(Dialect): exp.JSONExtractScalar: arrow_json_extract_scalar_sql, exp.JSONBExtract: arrow_json_extract_sql, exp.JSONBExtractScalar: arrow_json_extract_scalar_sql, + exp.LogicalOr: rename_func("BOOL_OR"), exp.Pivot: no_pivot_sql, exp.Properties: no_properties_sql, exp.RegexpLike: rename_func("REGEXP_MATCHES"), @@ -150,7 +156,7 @@ class DuckDB(Dialect): exp.Struct: _struct_pack_sql, exp.TableSample: no_tablesample_sql, exp.TimeStrToDate: lambda self, e: f"CAST({self.sql(e, 'this')} AS DATE)", - exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)", + exp.TimeStrToTime: timestrtotime_sql, exp.TimeStrToUnix: lambda self, e: f"EPOCH(CAST({self.sql(e, 'this')} AS TIMESTAMP))", exp.TimeToStr: lambda self, e: f"STRFTIME({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimeToUnix: rename_func("EPOCH"), @@ -163,7 +169,7 @@ class DuckDB(Dialect): } TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, + **generator.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.VARCHAR: "TEXT", exp.DataType.Type.NVARCHAR: "TEXT", } diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 8d6e1ae..088555c 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -15,6 +15,7 @@ from sqlglot.dialects.dialect import ( rename_func, strposition_to_local_sql, struct_extract_sql, + timestrtotime_sql, var_map_sql, ) from sqlglot.helper import seq_get @@ -197,7 +198,7 @@ class Hive(Dialect): STRICT_CAST = False FUNCTIONS = { - **parser.Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, # type: ignore "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list, "COLLECT_LIST": exp.ArrayAgg.from_arg_list, "DATE_ADD": lambda args: exp.TsOrDsAdd( @@ -217,7 +218,12 @@ class Hive(Dialect): ), unit=exp.Literal.string("DAY"), ), - "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "hive"), + "DATE_FORMAT": lambda args: format_time_lambda(exp.TimeToStr, "hive")( + [ + exp.TimeStrToTime(this=seq_get(args, 0)), + seq_get(args, 1), + ] + ), "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))), "FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True), "GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list, @@ -240,7 +246,7 @@ class Hive(Dialect): } PROPERTY_PARSERS = { - **parser.Parser.PROPERTY_PARSERS, + **parser.Parser.PROPERTY_PARSERS, # type: ignore TokenType.SERDE_PROPERTIES: lambda self: exp.SerdeProperties( expressions=self._parse_wrapped_csv(self._parse_property) ), @@ -248,14 +254,14 @@ class Hive(Dialect): class Generator(generator.Generator): TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, + **generator.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.TEXT: "STRING", exp.DataType.Type.DATETIME: "TIMESTAMP", exp.DataType.Type.VARBINARY: "BINARY", } TRANSFORMS = { - **generator.Generator.TRANSFORMS, + **generator.Generator.TRANSFORMS, # type: ignore **transforms.UNALIAS_GROUP, # type: ignore exp.Property: _property_sql, exp.ApproxDistinct: approx_count_distinct_sql, @@ -294,7 +300,7 @@ class Hive(Dialect): exp.StructExtract: struct_extract_sql, exp.TableFormatProperty: lambda self, e: f"USING {self.sql(e, 'this')}", exp.TimeStrToDate: rename_func("TO_DATE"), - exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)", + exp.TimeStrToTime: timestrtotime_sql, exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"), exp.TimeToStr: _time_to_str, exp.TimeToUnix: rename_func("UNIX_TIMESTAMP"), diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index 7627b6e..0fd7992 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -161,8 +161,6 @@ class MySQL(Dialect): "_UCS2": TokenType.INTRODUCER, "_UJIS": TokenType.INTRODUCER, # https://dev.mysql.com/doc/refman/8.0/en/string-literals.html - "N": TokenType.INTRODUCER, - "n": TokenType.INTRODUCER, "_UTF8": TokenType.INTRODUCER, "_UTF16": TokenType.INTRODUCER, "_UTF16LE": TokenType.INTRODUCER, @@ -175,10 +173,10 @@ class MySQL(Dialect): COMMANDS = tokens.Tokenizer.COMMANDS - {TokenType.SET, TokenType.SHOW} class Parser(parser.Parser): - FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA} + FUNC_TOKENS = {*parser.Parser.FUNC_TOKENS, TokenType.SCHEMA} # type: ignore FUNCTIONS = { - **parser.Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, # type: ignore "DATE_ADD": _date_add(exp.DateAdd), "DATE_SUB": _date_add(exp.DateSub), "STR_TO_DATE": _str_to_date, @@ -190,7 +188,7 @@ class MySQL(Dialect): } FUNCTION_PARSERS = { - **parser.Parser.FUNCTION_PARSERS, + **parser.Parser.FUNCTION_PARSERS, # type: ignore "GROUP_CONCAT": lambda self: self.expression( exp.GroupConcat, this=self._parse_lambda(), @@ -199,12 +197,12 @@ class MySQL(Dialect): } PROPERTY_PARSERS = { - **parser.Parser.PROPERTY_PARSERS, + **parser.Parser.PROPERTY_PARSERS, # type: ignore TokenType.ENGINE: lambda self: self._parse_property_assignment(exp.EngineProperty), } STATEMENT_PARSERS = { - **parser.Parser.STATEMENT_PARSERS, + **parser.Parser.STATEMENT_PARSERS, # type: ignore TokenType.SHOW: lambda self: self._parse_show(), TokenType.SET: lambda self: self._parse_set(), } @@ -429,7 +427,7 @@ class MySQL(Dialect): NULL_ORDERING_SUPPORTED = False TRANSFORMS = { - **generator.Generator.TRANSFORMS, + **generator.Generator.TRANSFORMS, # type: ignore exp.CurrentDate: no_paren_current_date_sql, exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", exp.ILike: no_ilike_sql, diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index f507513..af3d353 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -39,13 +39,13 @@ class Oracle(Dialect): class Parser(parser.Parser): FUNCTIONS = { - **parser.Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, # type: ignore "DECODE": exp.Matches.from_arg_list, } class Generator(generator.Generator): TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, + **generator.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.TINYINT: "NUMBER", exp.DataType.Type.SMALLINT: "NUMBER", exp.DataType.Type.INT: "NUMBER", @@ -60,7 +60,7 @@ class Oracle(Dialect): } TRANSFORMS = { - **generator.Generator.TRANSFORMS, + **generator.Generator.TRANSFORMS, # type: ignore **transforms.UNALIAS_GROUP, # type: ignore exp.ILike: no_ilike_sql, exp.Limit: _limit_sql, diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index f276af1..a092cad 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -11,9 +11,19 @@ from sqlglot.dialects.dialect import ( no_trycast_sql, str_position_sql, ) +from sqlglot.helper import seq_get from sqlglot.tokens import TokenType from sqlglot.transforms import delegate, preprocess +DATE_DIFF_FACTOR = { + "MICROSECOND": " * 1000000", + "MILLISECOND": " * 1000", + "SECOND": "", + "MINUTE": " / 60", + "HOUR": " / 3600", + "DAY": " / 86400", +} + def _date_add_sql(kind): def func(self, expression): @@ -34,16 +44,30 @@ def _date_add_sql(kind): return func -def _lateral_sql(self, expression): - this = self.sql(expression, "this") - if isinstance(expression.this, exp.Subquery): - return f"LATERAL{self.sep()}{this}" - alias = expression.args["alias"] - table = alias.name - table = f" {table}" if table else table - columns = self.expressions(alias, key="columns", flat=True) - columns = f" AS {columns}" if columns else "" - return f"LATERAL{self.sep()}{this}{table}{columns}" +def _date_diff_sql(self, expression): + unit = expression.text("unit").upper() + factor = DATE_DIFF_FACTOR.get(unit) + + end = f"CAST({expression.this} AS TIMESTAMP)" + start = f"CAST({expression.expression} AS TIMESTAMP)" + + if factor is not None: + return f"CAST(EXTRACT(epoch FROM {end} - {start}){factor} AS BIGINT)" + + age = f"AGE({end}, {start})" + + if unit == "WEEK": + extract = f"EXTRACT(year FROM {age}) * 48 + EXTRACT(month FROM {age}) * 4 + EXTRACT(day FROM {age}) / 7" + elif unit == "MONTH": + extract = f"EXTRACT(year FROM {age}) * 12 + EXTRACT(month FROM {age})" + elif unit == "QUARTER": + extract = f"EXTRACT(year FROM {age}) * 4 + EXTRACT(month FROM {age}) / 3" + elif unit == "YEAR": + extract = f"EXTRACT(year FROM {age})" + else: + self.unsupported(f"Unsupported DATEDIFF unit {unit}") + + return f"CAST({extract} AS BIGINT)" def _substring_sql(self, expression): @@ -141,7 +165,7 @@ def _serial_to_generated(expression): def _to_timestamp(args): # TO_TIMESTAMP accepts either a single double argument or (text, text) - if len(args) == 1 and args[0].is_number: + if len(args) == 1: # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TABLE return exp.UnixToTime.from_arg_list(args) # https://www.postgresql.org/docs/current/functions-formatting.html @@ -211,11 +235,16 @@ class Postgres(Dialect): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, + "~~": TokenType.LIKE, + "~~*": TokenType.ILIKE, + "~*": TokenType.IRLIKE, + "~": TokenType.RLIKE, "ALWAYS": TokenType.ALWAYS, "BEGIN": TokenType.COMMAND, "BEGIN TRANSACTION": TokenType.BEGIN, "BIGSERIAL": TokenType.BIGSERIAL, "BY DEFAULT": TokenType.BY_DEFAULT, + "CHARACTER VARYING": TokenType.VARCHAR, "COMMENT ON": TokenType.COMMAND, "DECLARE": TokenType.COMMAND, "DO": TokenType.COMMAND, @@ -233,6 +262,7 @@ class Postgres(Dialect): "SMALLSERIAL": TokenType.SMALLSERIAL, "TEMP": TokenType.TEMPORARY, "UUID": TokenType.UUID, + "CSTRING": TokenType.PSEUDO_TYPE, **{f"CREATE {kind}": TokenType.COMMAND for kind in CREATABLES}, **{f"DROP {kind}": TokenType.COMMAND for kind in CREATABLES}, } @@ -244,17 +274,16 @@ class Postgres(Dialect): class Parser(parser.Parser): STRICT_CAST = False - LATERAL_FUNCTION_AS_VIEW = True FUNCTIONS = { - **parser.Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, # type: ignore "TO_TIMESTAMP": _to_timestamp, "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"), } class Generator(generator.Generator): TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, + **generator.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.TINYINT: "SMALLINT", exp.DataType.Type.FLOAT: "REAL", exp.DataType.Type.DOUBLE: "DOUBLE PRECISION", @@ -264,7 +293,7 @@ class Postgres(Dialect): } TRANSFORMS = { - **generator.Generator.TRANSFORMS, + **generator.Generator.TRANSFORMS, # type: ignore exp.ColumnDef: preprocess( [ _auto_increment_to_serial, @@ -274,13 +303,16 @@ class Postgres(Dialect): ), exp.JSONExtract: arrow_json_extract_sql, exp.JSONExtractScalar: arrow_json_extract_scalar_sql, - exp.JSONBExtract: lambda self, e: f"{self.sql(e, 'this')}#>{self.sql(e, 'path')}", - exp.JSONBExtractScalar: lambda self, e: f"{self.sql(e, 'this')}#>>{self.sql(e, 'path')}", + exp.JSONBExtract: lambda self, e: self.binary(e, "#>"), + exp.JSONBExtractScalar: lambda self, e: self.binary(e, "#>>"), + exp.JSONBContains: lambda self, e: self.binary(e, "?"), exp.CurrentDate: no_paren_current_date_sql, exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", exp.DateAdd: _date_add_sql("+"), exp.DateSub: _date_add_sql("-"), - exp.Lateral: _lateral_sql, + exp.DateDiff: _date_diff_sql, + exp.RegexpLike: lambda self, e: self.binary(e, "~"), + exp.RegexpILike: lambda self, e: self.binary(e, "~*"), exp.StrPosition: str_position_sql, exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.Substring: _substring_sql, @@ -291,5 +323,7 @@ class Postgres(Dialect): exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})", exp.DataType: _datatype_sql, exp.GroupConcat: _string_agg_sql, - exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]", + exp.Array: lambda self, e: f"{self.normalize_func('ARRAY')}({self.sql(e.expressions[0])})" + if isinstance(seq_get(e.expressions, 0), exp.Select) + else f"{self.normalize_func('ARRAY')}[{self.expressions(e, flat=True)}]", } diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 1a09037..e16ea1d 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -10,6 +10,7 @@ from sqlglot.dialects.dialect import ( rename_func, str_position_sql, struct_extract_sql, + timestrtotime_sql, ) from sqlglot.dialects.mysql import MySQL from sqlglot.errors import UnsupportedError @@ -38,10 +39,6 @@ def _datatype_sql(self, expression): return sql -def _date_parse_sql(self, expression): - return f"DATE_PARSE({self.sql(expression, 'this')}, '%Y-%m-%d %H:%i:%s')" - - def _explode_to_unnest_sql(self, expression): if isinstance(expression.this, (exp.Explode, exp.Posexplode)): return self.sql( @@ -137,7 +134,7 @@ class Presto(Dialect): class Parser(parser.Parser): FUNCTIONS = { - **parser.Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, # type: ignore "APPROX_DISTINCT": exp.ApproxDistinct.from_arg_list, "CARDINALITY": exp.ArraySize.from_arg_list, "CONTAINS": exp.ArrayContains.from_arg_list, @@ -174,7 +171,7 @@ class Presto(Dialect): ROOT_PROPERTIES = {exp.SchemaCommentProperty} TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, + **generator.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.INT: "INTEGER", exp.DataType.Type.FLOAT: "REAL", exp.DataType.Type.BINARY: "VARBINARY", @@ -184,7 +181,7 @@ class Presto(Dialect): } TRANSFORMS = { - **generator.Generator.TRANSFORMS, + **generator.Generator.TRANSFORMS, # type: ignore **transforms.UNALIAS_GROUP, # type: ignore exp.ApproxDistinct: _approx_distinct_sql, exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]", @@ -224,8 +221,8 @@ class Presto(Dialect): exp.StructExtract: struct_extract_sql, exp.TableFormatProperty: lambda self, e: f"TABLE_FORMAT='{e.name.upper()}'", exp.FileFormatProperty: lambda self, e: f"FORMAT='{e.name.upper()}'", - exp.TimeStrToDate: _date_parse_sql, - exp.TimeStrToTime: _date_parse_sql, + exp.TimeStrToDate: timestrtotime_sql, + exp.TimeStrToTime: timestrtotime_sql, exp.TimeStrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {Presto.time_format}))", exp.TimeToStr: lambda self, e: f"DATE_FORMAT({self.sql(e, 'this')}, {self.format_time(e)})", exp.TimeToUnix: rename_func("TO_UNIXTIME"), diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index 55ed0a6..27dfb93 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -36,7 +36,6 @@ class Redshift(Postgres): "TIMETZ": TokenType.TIMESTAMPTZ, "UNLOAD": TokenType.COMMAND, "VARBYTE": TokenType.VARBINARY, - "SIMILAR TO": TokenType.SIMILAR_TO, } class Generator(Postgres.Generator): diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 75dc9dc..77b09e9 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -3,13 +3,15 @@ from __future__ import annotations from sqlglot import exp, generator, parser, tokens from sqlglot.dialects.dialect import ( Dialect, + datestrtodate_sql, format_time_lambda, inline_array_sql, rename_func, + timestrtotime_sql, var_map_sql, ) from sqlglot.expressions import Literal -from sqlglot.helper import seq_get +from sqlglot.helper import flatten, seq_get from sqlglot.tokens import TokenType @@ -183,7 +185,7 @@ class Snowflake(Dialect): class Tokenizer(tokens.Tokenizer): QUOTES = ["'", "$$"] - ESCAPES = ["\\"] + ESCAPES = ["\\", "'"] SINGLE_TOKENS = { **tokens.Tokenizer.SINGLE_TOKENS, @@ -206,9 +208,10 @@ class Snowflake(Dialect): CREATE_TRANSIENT = True TRANSFORMS = { - **generator.Generator.TRANSFORMS, + **generator.Generator.TRANSFORMS, # type: ignore exp.Array: inline_array_sql, exp.ArrayConcat: rename_func("ARRAY_CAT"), + exp.DateStrToDate: datestrtodate_sql, exp.DataType: _datatype_sql, exp.If: rename_func("IFF"), exp.Map: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"), @@ -218,13 +221,14 @@ class Snowflake(Dialect): exp.Matches: rename_func("DECODE"), exp.StrPosition: rename_func("POSITION"), exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.TimeStrToTime: timestrtotime_sql, exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})", exp.Trim: lambda self, e: f"TRIM({self.format_args(e.this, e.expression)})", exp.UnixToTime: _unix_to_time_sql, } TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, + **generator.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ", } @@ -246,3 +250,47 @@ class Snowflake(Dialect): if not expression.args.get("distinct", False): self.unsupported("INTERSECT with All is not supported in Snowflake") return super().intersect_op(expression) + + def values_sql(self, expression: exp.Values) -> str: + """Due to a bug in Snowflake we want to make sure that all columns in a VALUES table alias are unquoted. + + We also want to make sure that after we find matches where we need to unquote a column that we prevent users + from adding quotes to the column by using the `identify` argument when generating the SQL. + """ + alias = expression.args.get("alias") + if alias and alias.args.get("columns"): + expression = expression.transform( + lambda node: exp.Identifier(**{**node.args, "quoted": False}) + if isinstance(node, exp.Identifier) + and isinstance(node.parent, exp.TableAlias) + and node.arg_key == "columns" + else node, + ) + return self.no_identify(lambda: super(self.__class__, self).values_sql(expression)) + return super().values_sql(expression) + + def select_sql(self, expression: exp.Select) -> str: + """Due to a bug in Snowflake we want to make sure that all columns in a VALUES table alias are unquoted and also + that all columns in a SELECT are unquoted. We also want to make sure that after we find matches where we need + to unquote a column that we prevent users from adding quotes to the column by using the `identify` argument when + generating the SQL. + + Note: We make an assumption that any columns referenced in a VALUES expression should be unquoted throughout the + expression. This might not be true in a case where the same column name can be sourced from another table that can + properly quote but should be true in most cases. + """ + values_expressions = expression.find_all(exp.Values) + values_identifiers = set( + flatten( + v.args.get("alias", exp.Alias()).args.get("columns", []) + for v in values_expressions + ) + ) + if values_identifiers: + expression = expression.transform( + lambda node: exp.Identifier(**{**node.args, "quoted": False}) + if isinstance(node, exp.Identifier) and node in values_identifiers + else node, + ) + return self.no_identify(lambda: super(self.__class__, self).select_sql(expression)) + return super().select_sql(expression) diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index 16083d1..7f05dea 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -76,7 +76,7 @@ class Spark(Hive): } FUNCTION_PARSERS = { - **parser.Parser.FUNCTION_PARSERS, + **parser.Parser.FUNCTION_PARSERS, # type: ignore "BROADCAST": lambda self: self._parse_join_hint("BROADCAST"), "BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"), "MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"), @@ -87,6 +87,16 @@ class Spark(Hive): "SHUFFLE_REPLICATE_NL": lambda self: self._parse_join_hint("SHUFFLE_REPLICATE_NL"), } + def _parse_add_column(self): + return self._match_text_seq("ADD", "COLUMNS") and self._parse_schema() + + def _parse_drop_column(self): + return self._match_text_seq("DROP", "COLUMNS") and self.expression( + exp.Drop, + this=self._parse_schema(), + kind="COLUMNS", + ) + class Generator(Hive.Generator): TYPE_MAPPING = { **Hive.Generator.TYPE_MAPPING, # type: ignore diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index bbb752b..a0c4942 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -42,13 +42,13 @@ class SQLite(Dialect): class Parser(parser.Parser): FUNCTIONS = { - **parser.Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, # type: ignore "EDITDIST3": exp.Levenshtein.from_arg_list, } class Generator(generator.Generator): TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, + **generator.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.BOOLEAN: "INTEGER", exp.DataType.Type.TINYINT: "INTEGER", exp.DataType.Type.SMALLINT: "INTEGER", @@ -70,7 +70,7 @@ class SQLite(Dialect): } TRANSFORMS = { - **generator.Generator.TRANSFORMS, + **generator.Generator.TRANSFORMS, # type: ignore exp.ILike: no_ilike_sql, exp.JSONExtract: arrow_json_extract_sql, exp.JSONExtractScalar: arrow_json_extract_scalar_sql, diff --git a/sqlglot/dialects/starrocks.py b/sqlglot/dialects/starrocks.py index 3519c09..01e6357 100644 --- a/sqlglot/dialects/starrocks.py +++ b/sqlglot/dialects/starrocks.py @@ -8,7 +8,7 @@ from sqlglot.dialects.mysql import MySQL class StarRocks(MySQL): class Generator(MySQL.Generator): # type: ignore TYPE_MAPPING = { - **MySQL.Generator.TYPE_MAPPING, + **MySQL.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.TEXT: "STRING", exp.DataType.Type.TIMESTAMP: "DATETIME", exp.DataType.Type.TIMESTAMPTZ: "DATETIME", diff --git a/sqlglot/dialects/tableau.py b/sqlglot/dialects/tableau.py index 63e7275..36c085f 100644 --- a/sqlglot/dialects/tableau.py +++ b/sqlglot/dialects/tableau.py @@ -30,7 +30,7 @@ class Tableau(Dialect): class Parser(parser.Parser): FUNCTIONS = { - **parser.Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, # type: ignore "IFNULL": exp.Coalesce.from_arg_list, "COUNTD": lambda args: exp.Count(this=exp.Distinct(expressions=args)), } diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index a552e7b..7f0f2d7 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -224,11 +224,7 @@ class TSQL(Dialect): class Tokenizer(tokens.Tokenizer): IDENTIFIERS = ['"', ("[", "]")] - QUOTES = [ - (prefix + quote, quote) if prefix else quote - for quote in ["'", '"'] - for prefix in ["", "n", "N"] - ] + QUOTES = ["'", '"'] KEYWORDS = { **tokens.Tokenizer.KEYWORDS, @@ -253,7 +249,7 @@ class TSQL(Dialect): class Parser(parser.Parser): FUNCTIONS = { - **parser.Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, # type: ignore "CHARINDEX": exp.StrPosition.from_arg_list, "ISNULL": exp.Coalesce.from_arg_list, "DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL), @@ -314,7 +310,7 @@ class TSQL(Dialect): class Generator(generator.Generator): TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, + **generator.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.BOOLEAN: "BIT", exp.DataType.Type.INT: "INTEGER", exp.DataType.Type.DECIMAL: "NUMERIC", diff --git a/sqlglot/diff.py b/sqlglot/diff.py index 758ad1b..fa8bc1b 100644 --- a/sqlglot/diff.py +++ b/sqlglot/diff.py @@ -1,3 +1,7 @@ +""" +.. include:: ../posts/sql_diff.md +""" + from __future__ import annotations import typing as t diff --git a/sqlglot/executor/context.py b/sqlglot/executor/context.py index e9ff75b..8a58287 100644 --- a/sqlglot/executor/context.py +++ b/sqlglot/executor/context.py @@ -29,10 +29,10 @@ class Context: self._table: t.Optional[Table] = None self.range_readers = {name: table.range_reader for name, table in self.tables.items()} self.row_readers = {name: table.reader for name, table in tables.items()} - self.env = {**(env or {}), "scope": self.row_readers} + self.env = {**ENV, **(env or {}), "scope": self.row_readers} def eval(self, code): - return eval(code, ENV, self.env) + return eval(code, self.env) def eval_tuple(self, codes): return tuple(self.eval(code) for code in codes) diff --git a/sqlglot/executor/env.py b/sqlglot/executor/env.py index ad9397e..04dc938 100644 --- a/sqlglot/executor/env.py +++ b/sqlglot/executor/env.py @@ -127,14 +127,16 @@ def interval(this, unit): ENV = { "exp": exp, # aggs - "SUM": filter_nulls(sum), + "ARRAYAGG": list, "AVG": filter_nulls(statistics.fmean if PYTHON_VERSION >= (3, 8) else statistics.mean), # type: ignore "COUNT": filter_nulls(lambda acc: sum(1 for _ in acc), False), "MAX": filter_nulls(max), "MIN": filter_nulls(min), + "SUM": filter_nulls(sum), # scalar functions "ABS": null_if_any(lambda this: abs(this)), "ADD": null_if_any(lambda e, this: e + this), + "ARRAYANY": null_if_any(lambda arr, func: any(func(e) for e in arr)), "BETWEEN": null_if_any(lambda this, low, high: low <= this and this <= high), "BITWISEAND": null_if_any(lambda this, e: this & e), "BITWISELEFTSHIFT": null_if_any(lambda this, e: this << e), diff --git a/sqlglot/executor/python.py b/sqlglot/executor/python.py index 9f22c45..29848c6 100644 --- a/sqlglot/executor/python.py +++ b/sqlglot/executor/python.py @@ -394,6 +394,18 @@ def _case_sql(self, expression): return chain +def _lambda_sql(self, e: exp.Lambda) -> str: + names = {e.name.lower() for e in e.expressions} + + e = e.transform( + lambda n: exp.Var(this=n.name) + if isinstance(n, exp.Identifier) and n.name.lower() in names + else n + ) + + return f"lambda {self.expressions(e, flat=True)}: {self.sql(e, 'this')}" + + class Python(Dialect): class Tokenizer(tokens.Tokenizer): ESCAPES = ["\\"] @@ -414,6 +426,7 @@ class Python(Dialect): exp.Extract: lambda self, e: f"EXTRACT('{e.name.lower()}', {self.sql(e, 'expression')})", exp.In: lambda self, e: f"{self.sql(e, 'this')} in ({self.expressions(e, flat=True)})", exp.Is: lambda self, e: self.binary(e, "is"), + exp.Lambda: _lambda_sql, exp.Not: lambda self, e: f"not {self.sql(e.this)}", exp.Null: lambda *_: "None", exp.Or: lambda self, e: self.binary(e, "or"), diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index aeed218..711ec4b 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -1,6 +1,11 @@ +""" +.. include:: ../pdoc/docs/expressions.md +""" + from __future__ import annotations import datetime +import math import numbers import re import typing as t @@ -682,6 +687,10 @@ class CharacterSet(Expression): class With(Expression): arg_types = {"expressions": True, "recursive": False} + @property + def recursive(self) -> bool: + return bool(self.args.get("recursive")) + class WithinGroup(Expression): arg_types = {"this": True, "expression": False} @@ -724,6 +733,18 @@ class ColumnDef(Expression): "this": True, "kind": True, "constraints": False, + "exists": False, + } + + +class AlterColumn(Expression): + arg_types = { + "this": True, + "dtype": False, + "collate": False, + "using": False, + "default": False, + "drop": False, } @@ -877,6 +898,11 @@ class Introducer(Expression): arg_types = {"this": True, "expression": True} +# national char, like n'utf8' +class National(Expression): + pass + + class LoadData(Expression): arg_types = { "this": True, @@ -894,7 +920,7 @@ class Partition(Expression): class Fetch(Expression): - arg_types = {"direction": False, "count": True} + arg_types = {"direction": False, "count": False} class Group(Expression): @@ -1316,7 +1342,7 @@ QUERY_MODIFIERS = { "group": False, "having": False, "qualify": False, - "window": False, + "windows": False, "distribute": False, "sort": False, "cluster": False, @@ -1353,7 +1379,7 @@ class Union(Subqueryable): Example: >>> select("1").union(select("1")).limit(1).sql() - 'SELECT * FROM (SELECT 1 UNION SELECT 1) AS "_l_0" LIMIT 1' + 'SELECT * FROM (SELECT 1 UNION SELECT 1) AS _l_0 LIMIT 1' Args: expression (str | int | Expression): the SQL code string to parse. @@ -1889,6 +1915,18 @@ class Select(Subqueryable): **opts, ) + def window(self, *expressions, append=True, dialect=None, copy=True, **opts) -> Select: + return _apply_list_builder( + *expressions, + instance=self, + arg="windows", + append=append, + into=Window, + dialect=dialect, + copy=copy, + **opts, + ) + def distinct(self, distinct=True, copy=True) -> Select: """ Set the OFFSET expression. @@ -2140,6 +2178,11 @@ class DataType(Expression): ) +# https://www.postgresql.org/docs/15/datatype-pseudo.html +class PseudoType(Expression): + pass + + class StructKwarg(Expression): arg_types = {"this": True, "expression": True} @@ -2167,18 +2210,26 @@ class Command(Expression): arg_types = {"this": True, "expression": False} -class Transaction(Command): +class Transaction(Expression): arg_types = {"this": False, "modes": False} -class Commit(Command): +class Commit(Expression): arg_types = {"chain": False} -class Rollback(Command): +class Rollback(Expression): arg_types = {"savepoint": False} +class AlterTable(Expression): + arg_types = { + "this": True, + "actions": True, + "exists": False, + } + + # Binary expressions like (ADD a b) class Binary(Expression): arg_types = {"this": True, "expression": True} @@ -2312,6 +2363,10 @@ class SimilarTo(Binary, Predicate): pass +class Slice(Binary): + arg_types = {"this": False, "expression": False} + + class Sub(Binary): pass @@ -2392,7 +2447,7 @@ class TimeUnit(Expression): class Interval(TimeUnit): - arg_types = {"this": True, "unit": False} + arg_types = {"this": False, "unit": False} class IgnoreNulls(Expression): @@ -2730,8 +2785,11 @@ class Initcap(Func): pass -class JSONExtract(Func): - arg_types = {"this": True, "path": True} +class JSONBContains(Binary): + _sql_names = ["JSONB_CONTAINS"] + + +class JSONExtract(Binary, Func): _sql_names = ["JSON_EXTRACT"] @@ -2776,6 +2834,10 @@ class Log10(Func): pass +class LogicalOr(AggFunc): + _sql_names = ["LOGICAL_OR", "BOOL_OR"] + + class Lower(Func): _sql_names = ["LOWER", "LCASE"] @@ -2846,6 +2908,10 @@ class RegexpLike(Func): arg_types = {"this": True, "expression": True, "flag": False} +class RegexpILike(Func): + arg_types = {"this": True, "expression": True, "flag": False} + + class RegexpSplit(Func): arg_types = {"this": True, "expression": True} @@ -3388,11 +3454,17 @@ def update(table, properties, where=None, from_=None, dialect=None, **opts) -> U ], ) if from_: - update.set("from", maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts)) + update.set( + "from", + maybe_parse(from_, into=From, dialect=dialect, prefix="FROM", **opts), + ) if isinstance(where, Condition): where = Where(this=where) if where: - update.set("where", maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts)) + update.set( + "where", + maybe_parse(where, into=Where, dialect=dialect, prefix="WHERE", **opts), + ) return update @@ -3522,7 +3594,7 @@ def paren(expression) -> Paren: return Paren(this=expression) -SAFE_IDENTIFIER_RE = re.compile(r"^[a-zA-Z][\w]*$") +SAFE_IDENTIFIER_RE = re.compile(r"^[_a-zA-Z][\w]*$") def to_identifier(alias, quoted=None) -> t.Optional[Identifier]: @@ -3724,6 +3796,8 @@ def convert(value) -> Expression: return Boolean(this=value) if isinstance(value, str): return Literal.string(value) + if isinstance(value, float) and math.isnan(value): + return NULL if isinstance(value, numbers.Number): return Literal.number(value) if isinstance(value, tuple): @@ -3732,11 +3806,13 @@ def convert(value) -> Expression: return Array(expressions=[convert(v) for v in value]) if isinstance(value, dict): return Map( - keys=[convert(k) for k in value.keys()], + keys=[convert(k) for k in value], values=[convert(v) for v in value.values()], ) if isinstance(value, datetime.datetime): - datetime_literal = Literal.string(value.strftime("%Y-%m-%d %H:%M:%S.%f%z")) + datetime_literal = Literal.string( + (value if value.tzinfo else value.replace(tzinfo=datetime.timezone.utc)).isoformat() + ) return TimeStrToTime(this=datetime_literal) if isinstance(value, datetime.date): date_literal = Literal.string(value.strftime("%Y-%m-%d")) diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 2b4c575..0c1578a 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -361,10 +361,11 @@ class Generator: column = self.sql(expression, "this") kind = self.sql(expression, "kind") constraints = self.expressions(expression, key="constraints", sep=" ", flat=True) + exists = "IF NOT EXISTS " if expression.args.get("exists") else "" if not constraints: - return f"{column} {kind}" - return f"{column} {kind} {constraints}" + return f"{exists}{column} {kind}" + return f"{exists}{column} {kind} {constraints}" def columnconstraint_sql(self, expression: exp.ColumnConstraint) -> str: this = self.sql(expression, "this") @@ -549,6 +550,9 @@ class Generator: text = f"{self.identifier_start}{text}{self.identifier_end}" return text + def national_sql(self, expression: exp.National) -> str: + return f"N{self.sql(expression, 'this')}" + def partition_sql(self, expression: exp.Partition) -> str: keys = csv( *[ @@ -633,6 +637,9 @@ class Generator: def introducer_sql(self, expression: exp.Introducer) -> str: return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}" + def pseudotype_sql(self, expression: exp.PseudoType) -> str: + return expression.name.upper() + def rowformatdelimitedproperty_sql(self, expression: exp.RowFormatDelimitedProperty) -> str: fields = expression.args.get("fields") fields = f" FIELDS TERMINATED BY {fields}" if fields else "" @@ -793,19 +800,17 @@ class Generator: if isinstance(expression.this, exp.Subquery): return f"LATERAL {this}" - alias = expression.args["alias"] - table = alias.name - columns = self.expressions(alias, key="columns", flat=True) - if expression.args.get("view"): - table = f" {table}" if table else table + alias = expression.args["alias"] + columns = self.expressions(alias, key="columns", flat=True) + table = f" {alias.name}" if alias.name else "" columns = f" AS {columns}" if columns else "" op_sql = self.seg(f"LATERAL VIEW{' OUTER' if expression.args.get('outer') else ''}") return f"{op_sql}{self.sep()}{this}{table}{columns}" - table = f" AS {table}" if table else table - columns = f"({columns})" if columns else "" - return f"LATERAL {this}{table}{columns}" + alias = self.sql(expression, "alias") + alias = f" AS {alias}" if alias else "" + return f"LATERAL {this}{alias}" def limit_sql(self, expression: exp.Limit) -> str: this = self.sql(expression, "this") @@ -891,13 +896,15 @@ class Generator: def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str: return csv( *sqls, - *[self.sql(sql) for sql in expression.args.get("joins", [])], - *[self.sql(sql) for sql in expression.args.get("laterals", [])], + *[self.sql(sql) for sql in expression.args.get("joins") or []], + *[self.sql(sql) for sql in expression.args.get("laterals") or []], self.sql(expression, "where"), self.sql(expression, "group"), self.sql(expression, "having"), self.sql(expression, "qualify"), - self.sql(expression, "window"), + self.seg("WINDOW ") + self.expressions(expression, "windows", flat=True) + if expression.args.get("windows") + else "", self.sql(expression, "distribute"), self.sql(expression, "sort"), self.sql(expression, "cluster"), @@ -1008,11 +1015,7 @@ class Generator: spec_sql = " " + self.window_spec_sql(spec) if spec else "" alias = self.sql(expression, "alias") - - if expression.arg_key == "window": - this = this = f"{self.seg('WINDOW')} {this} AS" - else: - this = f"{this} OVER" + this = f"{this} {'AS' if expression.arg_key == 'windows' else 'OVER'}" if not partition and not order and not spec and alias: return f"{this} {alias}" @@ -1141,9 +1144,11 @@ class Generator: return f"(SELECT {self.sql(unnest)})" def interval_sql(self, expression: exp.Interval) -> str: + this = self.sql(expression, "this") + this = f" {this}" if this else "" unit = self.sql(expression, "unit") unit = f" {unit}" if unit else "" - return f"INTERVAL {self.sql(expression, 'this')}{unit}" + return f"INTERVAL{this}{unit}" def reference_sql(self, expression: exp.Reference) -> str: this = self.sql(expression, "this") @@ -1245,6 +1250,43 @@ class Generator: savepoint = f" TO {savepoint}" if savepoint else "" return f"ROLLBACK{savepoint}" + def altercolumn_sql(self, expression: exp.AlterColumn) -> str: + this = self.sql(expression, "this") + + dtype = self.sql(expression, "dtype") + if dtype: + collate = self.sql(expression, "collate") + collate = f" COLLATE {collate}" if collate else "" + using = self.sql(expression, "using") + using = f" USING {using}" if using else "" + return f"ALTER COLUMN {this} TYPE {dtype}{collate}{using}" + + default = self.sql(expression, "default") + if default: + return f"ALTER COLUMN {this} SET DEFAULT {default}" + + if not expression.args.get("drop"): + self.unsupported("Unsupported ALTER COLUMN syntax") + + return f"ALTER COLUMN {this} DROP DEFAULT" + + def altertable_sql(self, expression: exp.AlterTable) -> str: + actions = expression.args["actions"] + + if isinstance(actions[0], exp.ColumnDef): + actions = self.expressions(expression, "actions", prefix="ADD COLUMN ") + elif isinstance(actions[0], exp.Schema): + actions = self.expressions(expression, "actions", prefix="ADD COLUMNS ") + elif isinstance(actions[0], exp.Drop): + actions = self.expressions(expression, "actions") + elif isinstance(actions[0], exp.AlterColumn): + actions = self.sql(actions[0]) + else: + self.unsupported(f"Unsupported ALTER TABLE action {actions[0].__class__.__name__}") + + exists = " IF EXISTS" if expression.args.get("exists") else "" + return f"ALTER TABLE{exists} {self.sql(expression, 'this')} {actions}" + def distinct_sql(self, expression: exp.Distinct) -> str: this = self.expressions(expression, flat=True) this = f" {this}" if this else "" @@ -1327,6 +1369,9 @@ class Generator: def or_sql(self, expression: exp.Or) -> str: return self.connector_sql(expression, "OR") + def slice_sql(self, expression: exp.Slice) -> str: + return self.binary(expression, ":") + def sub_sql(self, expression: exp.Sub) -> str: return self.binary(expression, "-") @@ -1369,6 +1414,7 @@ class Generator: flat: bool = False, indent: bool = True, sep: str = ", ", + prefix: str = "", ) -> str: expressions = expression.args.get(key or "expressions") @@ -1391,11 +1437,13 @@ class Generator: if self.pretty: if self._leading_comma: - result_sqls.append(f"{sep if i > 0 else pad}{sql}{comments}") + result_sqls.append(f"{sep if i > 0 else pad}{prefix}{sql}{comments}") else: - result_sqls.append(f"{sql}{stripped_sep if i + 1 < num_sqls else ''}{comments}") + result_sqls.append( + f"{prefix}{sql}{stripped_sep if i + 1 < num_sqls else ''}{comments}" + ) else: - result_sqls.append(f"{sql}{comments}{sep if i + 1 < num_sqls else ''}") + result_sqls.append(f"{prefix}{sql}{comments}{sep if i + 1 < num_sqls else ''}") result_sql = "\n".join(result_sqls) if self.pretty else "".join(result_sqls) return self.indent(result_sql, skip_first=False) if indent else result_sql diff --git a/sqlglot/optimizer/canonicalize.py b/sqlglot/optimizer/canonicalize.py index 33529a5..fc37a54 100644 --- a/sqlglot/optimizer/canonicalize.py +++ b/sqlglot/optimizer/canonicalize.py @@ -18,6 +18,9 @@ def canonicalize(expression: exp.Expression) -> exp.Expression: expression = coerce_type(expression) expression = remove_redundant_casts(expression) + if isinstance(expression, exp.Identifier): + expression.set("quoted", True) + return expression diff --git a/sqlglot/optimizer/eliminate_joins.py b/sqlglot/optimizer/eliminate_joins.py index de4e011..3b40710 100644 --- a/sqlglot/optimizer/eliminate_joins.py +++ b/sqlglot/optimizer/eliminate_joins.py @@ -129,10 +129,23 @@ def join_condition(join): """ name = join.this.alias_or_name on = (join.args.get("on") or exp.true()).copy() - on = on if isinstance(on, exp.And) else exp.and_(on, exp.true()) source_key = [] join_key = [] + def extract_condition(condition): + left, right = condition.unnest_operands() + left_tables = exp.column_table_names(left) + right_tables = exp.column_table_names(right) + + if name in left_tables and name not in right_tables: + join_key.append(left) + source_key.append(right) + condition.replace(exp.true()) + elif name in right_tables and name not in left_tables: + join_key.append(right) + source_key.append(left) + condition.replace(exp.true()) + # find the join keys # SELECT # FROM x @@ -141,20 +154,30 @@ def join_condition(join): # # should pull y.b as the join key and x.a as the source key if normalized(on): + on = on if isinstance(on, exp.And) else exp.and_(on, exp.true()) + for condition in on.flatten(): if isinstance(condition, exp.EQ): - left, right = condition.unnest_operands() - left_tables = exp.column_table_names(left) - right_tables = exp.column_table_names(right) - - if name in left_tables and name not in right_tables: - join_key.append(left) - source_key.append(right) - condition.replace(exp.true()) - elif name in right_tables and name not in left_tables: - join_key.append(right) - source_key.append(left) - condition.replace(exp.true()) + extract_condition(condition) + elif normalized(on, dnf=True): + conditions = None + + for condition in on.flatten(): + parts = [part for part in condition.flatten() if isinstance(part, exp.EQ)] + if conditions is None: + conditions = parts + else: + temp = [] + for p in parts: + cs = [c for c in conditions if p == c] + + if cs: + temp.append(p) + temp.extend(cs) + conditions = temp + + for condition in conditions: + extract_condition(condition) on = simplify(on) remaining_condition = None if on == exp.true() else on diff --git a/sqlglot/optimizer/eliminate_subqueries.py b/sqlglot/optimizer/eliminate_subqueries.py index 39e252c..2245cc2 100644 --- a/sqlglot/optimizer/eliminate_subqueries.py +++ b/sqlglot/optimizer/eliminate_subqueries.py @@ -58,7 +58,9 @@ def eliminate_subqueries(expression): existing_ctes = {} with_ = root.expression.args.get("with") + recursive = False if with_: + recursive = with_.args.get("recursive") for cte in with_.expressions: existing_ctes[cte.this] = cte.alias new_ctes = [] @@ -88,7 +90,7 @@ def eliminate_subqueries(expression): new_ctes.append(new_cte) if new_ctes: - expression.set("with", exp.With(expressions=new_ctes)) + expression.set("with", exp.With(expressions=new_ctes, recursive=recursive)) return expression diff --git a/sqlglot/optimizer/normalize.py b/sqlglot/optimizer/normalize.py index db538ef..f16f519 100644 --- a/sqlglot/optimizer/normalize.py +++ b/sqlglot/optimizer/normalize.py @@ -69,8 +69,9 @@ def _predicate_lengths(expression, dnf): left, right = expression.args.values() if isinstance(expression, exp.And if dnf else exp.Or): - x = [a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf)] - return x + return [ + a + b for a in _predicate_lengths(left, dnf) for b in _predicate_lengths(right, dnf) + ] return _predicate_lengths(left, dnf) + _predicate_lengths(right, dnf) diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py index 6819717..72e67d4 100644 --- a/sqlglot/optimizer/optimizer.py +++ b/sqlglot/optimizer/optimizer.py @@ -14,7 +14,6 @@ from sqlglot.optimizer.pushdown_predicates import pushdown_predicates from sqlglot.optimizer.pushdown_projections import pushdown_projections from sqlglot.optimizer.qualify_columns import qualify_columns from sqlglot.optimizer.qualify_tables import qualify_tables -from sqlglot.optimizer.quote_identities import quote_identities from sqlglot.optimizer.unnest_subqueries import unnest_subqueries RULES = ( @@ -34,7 +33,6 @@ RULES = ( eliminate_ctes, annotate_types, canonicalize, - quote_identities, ) diff --git a/sqlglot/optimizer/pushdown_predicates.py b/sqlglot/optimizer/pushdown_predicates.py index f92e5c3..ba5c8b5 100644 --- a/sqlglot/optimizer/pushdown_predicates.py +++ b/sqlglot/optimizer/pushdown_predicates.py @@ -27,7 +27,14 @@ def pushdown_predicates(expression): select = scope.expression where = select.args.get("where") if where: - pushdown(where.this, scope.selected_sources, scope_ref_count) + selected_sources = scope.selected_sources + # a right join can only push down to itself and not the source FROM table + for k, (node, source) in selected_sources.items(): + parent = node.find_ancestor(exp.Join, exp.From) + if isinstance(parent, exp.Join) and parent.side == "RIGHT": + selected_sources = {k: (node, source)} + break + pushdown(where.this, selected_sources, scope_ref_count) # joins should only pushdown into itself, not to other joins # so we limit the selected sources to only itself @@ -148,10 +155,13 @@ def nodes_for_predicate(predicate, sources, scope_ref_count): # a node can reference a CTE which should be pushed down if isinstance(node, exp.From) and not isinstance(source, exp.Table): + with_ = source.parent.expression.args.get("with") + if with_ and with_.recursive: + return {} node = source.expression if isinstance(node, exp.Join): - if node.side: + if node.side and node.side != "RIGHT": return {} nodes[table] = node elif isinstance(node, exp.Select) and len(tables) == 1: diff --git a/sqlglot/optimizer/pushdown_projections.py b/sqlglot/optimizer/pushdown_projections.py index abd9492..49789ac 100644 --- a/sqlglot/optimizer/pushdown_projections.py +++ b/sqlglot/optimizer/pushdown_projections.py @@ -6,7 +6,7 @@ from sqlglot.optimizer.scope import Scope, traverse_scope # Sentinel value that means an outer query selecting ALL columns SELECT_ALL = object() -# SELECTION TO USE IF SELECTION LIST IS EMPTY +# Selection to use if selection list is empty DEFAULT_SELECTION = alias("1", "_") @@ -91,7 +91,7 @@ def _remove_unused_selections(scope, parent_selections): # If there are no remaining selections, just select a single constant if not new_selections: - new_selections.append(DEFAULT_SELECTION) + new_selections.append(DEFAULT_SELECTION.copy()) scope.expression.set("expressions", new_selections) return removed_indexes @@ -102,5 +102,5 @@ def _remove_indexed_selections(scope, indexes_to_remove): selection for i, selection in enumerate(scope.selects) if i not in indexes_to_remove ] if not new_selections: - new_selections.append(DEFAULT_SELECTION) + new_selections.append(DEFAULT_SELECTION.copy()) scope.expression.set("expressions", new_selections) diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index e6e6dc9..e16a635 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -311,6 +311,9 @@ def _qualify_outputs(scope): alias_ = alias(exp.column(""), alias=selection.name) alias_.set("this", selection) selection = alias_ + elif isinstance(selection, exp.Subquery): + if not selection.alias: + selection.set("alias", exp.TableAlias(this=exp.to_identifier(f"_col_{i}"))) elif not isinstance(selection, exp.Alias): alias_ = alias(exp.column(""), f"_col_{i}") alias_.set("this", selection) diff --git a/sqlglot/optimizer/quote_identities.py b/sqlglot/optimizer/quote_identities.py deleted file mode 100644 index 17623cc..0000000 --- a/sqlglot/optimizer/quote_identities.py +++ /dev/null @@ -1,25 +0,0 @@ -from sqlglot import exp - - -def quote_identities(expression): - """ - Rewrite sqlglot AST to ensure all identities are quoted. - - Example: - >>> import sqlglot - >>> expression = sqlglot.parse_one("SELECT x.a AS a FROM db.x") - >>> quote_identities(expression).sql() - 'SELECT "x"."a" AS "a" FROM "db"."x"' - - Args: - expression (sqlglot.Expression): expression to quote - Returns: - sqlglot.Expression: quoted expression - """ - - def qualify(node): - if isinstance(node, exp.Identifier): - node.set("quoted", True) - return node - - return expression.transform(qualify, copy=False) diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index 18848f3..6125e4e 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -511,9 +511,20 @@ def _traverse_union(scope): def _traverse_derived_tables(derived_tables, scope, scope_type): sources = {} + is_cte = scope_type == ScopeType.CTE for derived_table in derived_tables: - top = None + recursive_scope = None + + # if the scope is a recursive cte, it must be in the form of + # base_case UNION recursive. thus the recursive scope is the first + # section of the union. + if is_cte and scope.expression.args["with"].recursive: + union = derived_table.this + + if isinstance(union, exp.Union): + recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE) + for child_scope in _traverse_scope( scope.branch( derived_table if isinstance(derived_table, exp.UDTF) else derived_table.this, @@ -523,16 +534,23 @@ def _traverse_derived_tables(derived_tables, scope, scope_type): ) ): yield child_scope - top = child_scope + # Tables without aliases will be set as "" # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything. # Until then, this means that only a single, unaliased derived table is allowed (rather, # the latest one wins. - sources[derived_table.alias] = child_scope - if scope_type == ScopeType.CTE: - scope.cte_scopes.append(top) + alias = derived_table.alias + sources[alias] = child_scope + + if recursive_scope: + child_scope.add_source(alias, recursive_scope) + + # append the final child_scope yielded + if is_cte: + scope.cte_scopes.append(child_scope) else: - scope.derived_table_scopes.append(top) + scope.derived_table_scopes.append(child_scope) + scope.sources.update(sources) diff --git a/sqlglot/optimizer/unnest_subqueries.py b/sqlglot/optimizer/unnest_subqueries.py index 2046917..8d78294 100644 --- a/sqlglot/optimizer/unnest_subqueries.py +++ b/sqlglot/optimizer/unnest_subqueries.py @@ -16,7 +16,7 @@ def unnest_subqueries(expression): >>> expression = sqlglot.parse_one("SELECT * FROM x AS x WHERE (SELECT y.a AS a FROM y AS y WHERE x.a = y.a) = 1 ") >>> unnest_subqueries(expression).sql() 'SELECT * FROM x AS x LEFT JOIN (SELECT y.a AS a FROM y AS y WHERE TRUE GROUP BY y.a)\ - AS "_u_0" ON x.a = "_u_0".a WHERE ("_u_0".a = 1 AND NOT "_u_0".a IS NULL)' + AS _u_0 ON x.a = _u_0.a WHERE (_u_0.a = 1 AND NOT _u_0.a IS NULL)' Args: expression (sqlglot.Expression): expression to unnest @@ -97,8 +97,8 @@ def decorrelate(select, parent_select, external_columns, sequence): table_alias = _alias(sequence) keys = [] - # for all external columns in the where statement, - # split out the relevant data to convert it into a join + # for all external columns in the where statement, find the relevant predicate + # keys to convert it into a join for column in external_columns: if column.find_ancestor(exp.Where) is not where: return @@ -122,6 +122,10 @@ def decorrelate(select, parent_select, external_columns, sequence): if not any(isinstance(predicate, exp.EQ) for *_, predicate in keys): return + is_subquery_projection = any( + node is select.parent for node in parent_select.selects if isinstance(node, exp.Subquery) + ) + value = select.selects[0] key_aliases = {} group_by = [] @@ -142,9 +146,14 @@ def decorrelate(select, parent_select, external_columns, sequence): parent_predicate = select.find_ancestor(exp.Predicate) # if the value of the subquery is not an agg or a key, we need to collect it into an array - # so that it can be grouped + # so that it can be grouped. For subquery projections, we use a MAX aggregation instead. + agg_func = exp.Max if is_subquery_projection else exp.ArrayAgg if not value.find(exp.AggFunc) and value.this not in group_by: - select.select(f"ARRAY_AGG({value.this}) AS {value.alias}", append=False, copy=False) + select.select( + exp.alias_(agg_func(this=value.this), value.alias, quoted=False), + append=False, + copy=False, + ) # exists queries should not have any selects as it only checks if there are any rows # all selects will be added by the optimizer and only used for join keys @@ -158,7 +167,7 @@ def decorrelate(select, parent_select, external_columns, sequence): if isinstance(parent_predicate, exp.Exists) or key != value.this: select.select(f"{key} AS {alias}", copy=False) else: - select.select(f"ARRAY_AGG({key}) AS {alias}", copy=False) + select.select(exp.alias_(agg_func(this=key.copy()), alias, quoted=False), copy=False) alias = exp.column(value.alias, table_alias) other = _other_operand(parent_predicate) @@ -186,12 +195,18 @@ def decorrelate(select, parent_select, external_columns, sequence): f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})", ) else: + if is_subquery_projection: + alias = exp.alias_(alias, select.parent.alias) select.parent.replace(alias) for key, column, predicate in keys: predicate.replace(exp.true()) nested = exp.column(key_aliases[key], table_alias) + if is_subquery_projection: + key.replace(nested) + continue + if key in group_by: key.replace(nested) parent_predicate = _replace( diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 29bc9c0..308f363 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -5,7 +5,7 @@ import typing as t from sqlglot import exp from sqlglot.errors import ErrorLevel, ParseError, concat_messages, merge_errors -from sqlglot.helper import apply_index_offset, ensure_collection, seq_get +from sqlglot.helper import apply_index_offset, ensure_collection, ensure_list, seq_get from sqlglot.tokens import Token, Tokenizer, TokenType from sqlglot.trie import in_trie, new_trie @@ -117,6 +117,7 @@ class Parser(metaclass=_Parser): TokenType.GEOMETRY, TokenType.HLLSKETCH, TokenType.HSTORE, + TokenType.PSEUDO_TYPE, TokenType.SUPER, TokenType.SERIAL, TokenType.SMALLSERIAL, @@ -153,6 +154,7 @@ class Parser(metaclass=_Parser): TokenType.CACHE, TokenType.CASCADE, TokenType.COLLATE, + TokenType.COLUMN, TokenType.COMMAND, TokenType.COMMIT, TokenType.COMPOUND, @@ -169,6 +171,7 @@ class Parser(metaclass=_Parser): TokenType.ESCAPE, TokenType.FALSE, TokenType.FIRST, + TokenType.FILTER, TokenType.FOLLOWING, TokenType.FORMAT, TokenType.FUNCTION, @@ -188,6 +191,7 @@ class Parser(metaclass=_Parser): TokenType.MERGE, TokenType.NATURAL, TokenType.NEXT, + TokenType.OFFSET, TokenType.ONLY, TokenType.OPTIONS, TokenType.ORDINALITY, @@ -222,12 +226,18 @@ class Parser(metaclass=_Parser): TokenType.PROPERTIES, TokenType.PROCEDURE, TokenType.VOLATILE, + TokenType.WINDOW, *SUBQUERY_PREDICATES, *TYPE_TOKENS, *NO_PAREN_FUNCTIONS, } - TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.NATURAL, TokenType.APPLY} + TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - { + TokenType.APPLY, + TokenType.NATURAL, + TokenType.OFFSET, + TokenType.WINDOW, + } UPDATE_ALIAS_TOKENS = TABLE_ALIAS_TOKENS - {TokenType.SET} @@ -257,6 +267,7 @@ class Parser(metaclass=_Parser): TokenType.TABLE, TokenType.TIMESTAMP, TokenType.TIMESTAMPTZ, + TokenType.WINDOW, *TYPE_TOKENS, *SUBQUERY_PREDICATES, } @@ -351,22 +362,27 @@ class Parser(metaclass=_Parser): TokenType.ARROW: lambda self, this, path: self.expression( exp.JSONExtract, this=this, - path=path, + expression=path, ), TokenType.DARROW: lambda self, this, path: self.expression( exp.JSONExtractScalar, this=this, - path=path, + expression=path, ), TokenType.HASH_ARROW: lambda self, this, path: self.expression( exp.JSONBExtract, this=this, - path=path, + expression=path, ), TokenType.DHASH_ARROW: lambda self, this, path: self.expression( exp.JSONBExtractScalar, this=this, - path=path, + expression=path, + ), + TokenType.PLACEHOLDER: lambda self, this, key: self.expression( + exp.JSONBContains, + this=this, + expression=key, ), } @@ -392,25 +408,27 @@ class Parser(metaclass=_Parser): exp.Ordered: lambda self: self._parse_ordered(), exp.Having: lambda self: self._parse_having(), exp.With: lambda self: self._parse_with(), + exp.Window: lambda self: self._parse_named_window(), "JOIN_TYPE": lambda self: self._parse_join_side_and_kind(), } STATEMENT_PARSERS = { + TokenType.ALTER: lambda self: self._parse_alter(), + TokenType.BEGIN: lambda self: self._parse_transaction(), + TokenType.CACHE: lambda self: self._parse_cache(), + TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(), TokenType.CREATE: lambda self: self._parse_create(), + TokenType.DELETE: lambda self: self._parse_delete(), TokenType.DESCRIBE: lambda self: self._parse_describe(), TokenType.DROP: lambda self: self._parse_drop(), + TokenType.END: lambda self: self._parse_commit_or_rollback(), TokenType.INSERT: lambda self: self._parse_insert(), TokenType.LOAD_DATA: lambda self: self._parse_load_data(), - TokenType.UPDATE: lambda self: self._parse_update(), - TokenType.DELETE: lambda self: self._parse_delete(), - TokenType.CACHE: lambda self: self._parse_cache(), + TokenType.MERGE: lambda self: self._parse_merge(), + TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(), TokenType.UNCACHE: lambda self: self._parse_uncache(), + TokenType.UPDATE: lambda self: self._parse_update(), TokenType.USE: lambda self: self.expression(exp.Use, this=self._parse_id_var()), - TokenType.BEGIN: lambda self: self._parse_transaction(), - TokenType.COMMIT: lambda self: self._parse_commit_or_rollback(), - TokenType.END: lambda self: self._parse_commit_or_rollback(), - TokenType.ROLLBACK: lambda self: self._parse_commit_or_rollback(), - TokenType.MERGE: lambda self: self._parse_merge(), } UNARY_PARSERS = { @@ -441,6 +459,7 @@ class Parser(metaclass=_Parser): TokenType.HEX_STRING: lambda self, token: self.expression(exp.HexString, this=token.text), TokenType.BYTE_STRING: lambda self, token: self.expression(exp.ByteString, this=token.text), TokenType.INTRODUCER: lambda self, token: self._parse_introducer(token), + TokenType.NATIONAL: lambda self, token: self._parse_national(token), TokenType.SESSION_PARAMETER: lambda self, _: self._parse_session_parameter(), } @@ -454,6 +473,9 @@ class Parser(metaclass=_Parser): TokenType.ILIKE: lambda self, this: self._parse_escape( self.expression(exp.ILike, this=this, expression=self._parse_bitwise()) ), + TokenType.IRLIKE: lambda self, this: self.expression( + exp.RegexpILike, this=this, expression=self._parse_bitwise() + ), TokenType.RLIKE: lambda self, this: self.expression( exp.RegexpLike, this=this, expression=self._parse_bitwise() ), @@ -535,8 +557,7 @@ class Parser(metaclass=_Parser): "group": lambda self: self._parse_group(), "having": lambda self: self._parse_having(), "qualify": lambda self: self._parse_qualify(), - "window": lambda self: self._match(TokenType.WINDOW) - and self._parse_window(self._parse_id_var(), alias=True), + "windows": lambda self: self._parse_window_clause(), "distribute": lambda self: self._parse_sort(TokenType.DISTRIBUTE_BY, exp.Distribute), "sort": lambda self: self._parse_sort(TokenType.SORT_BY, exp.Sort), "cluster": lambda self: self._parse_sort(TokenType.CLUSTER_BY, exp.Cluster), @@ -551,18 +572,18 @@ class Parser(metaclass=_Parser): MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table) CREATABLES = { - TokenType.TABLE, - TokenType.VIEW, + TokenType.COLUMN, TokenType.FUNCTION, TokenType.INDEX, TokenType.PROCEDURE, TokenType.SCHEMA, + TokenType.TABLE, + TokenType.VIEW, } TRANSACTION_KIND = {"DEFERRED", "IMMEDIATE", "EXCLUSIVE"} STRICT_CAST = True - LATERAL_FUNCTION_AS_VIEW = False __slots__ = ( "error_level", @@ -782,13 +803,16 @@ class Parser(metaclass=_Parser): self._parse_query_modifiers(expression) return expression - def _parse_drop(self): + def _parse_drop(self, default_kind=None): temporary = self._match(TokenType.TEMPORARY) materialized = self._match(TokenType.MATERIALIZED) kind = self._match_set(self.CREATABLES) and self._prev.text if not kind: - self.raise_error(f"Expected {self.CREATABLES}") - return + if default_kind: + kind = default_kind + else: + self.raise_error(f"Expected {self.CREATABLES}") + return return self.expression( exp.Drop, @@ -876,7 +900,7 @@ class Parser(metaclass=_Parser): ) or self._match_pair(TokenType.STRING, TokenType.EQ, advance=False) if assignment: - key = self._parse_var() or self._parse_string() + key = self._parse_var_or_string() self._match(TokenType.EQ) return self.expression(exp.Property, this=key, value=self._parse_column()) @@ -1152,18 +1176,32 @@ class Parser(metaclass=_Parser): elif (table or nested) and self._match(TokenType.L_PAREN): this = self._parse_table() if table else self._parse_select(nested=True) self._parse_query_modifiers(this) + this = self._parse_set_operations(this) self._match_r_paren() - this = self._parse_subquery(this) + # early return so that subquery unions aren't parsed again + # SELECT * FROM (SELECT 1) UNION ALL SELECT 1 + # Union ALL should be a property of the top select node, not the subquery + return self._parse_subquery(this) elif self._match(TokenType.VALUES): + if self._curr.token_type == TokenType.L_PAREN: + # We don't consume the left paren because it's consumed in _parse_value + expressions = self._parse_csv(self._parse_value) + else: + # In presto we can have VALUES 1, 2 which results in 1 column & 2 rows. + # Source: https://prestodb.io/docs/current/sql/values.html + expressions = self._parse_csv( + lambda: self.expression(exp.Tuple, expressions=[self._parse_conjunction()]) + ) + this = self.expression( exp.Values, - expressions=self._parse_csv(self._parse_value), + expressions=expressions, alias=self._parse_table_alias(), ) else: this = None - return self._parse_set_operations(this) if this else None + return self._parse_set_operations(this) def _parse_with(self, skip_with_token=False): if not skip_with_token and not self._match(TokenType.WITH): @@ -1201,11 +1239,12 @@ class Parser(metaclass=_Parser): alias = self._parse_id_var( any_token=any_token, tokens=alias_tokens or self.TABLE_ALIAS_TOKENS ) - columns = None if self._match(TokenType.L_PAREN): - columns = self._parse_csv(lambda: self._parse_id_var(any_token)) + columns = self._parse_csv(lambda: self._parse_column_def(self._parse_id_var())) self._match_r_paren() + else: + columns = None if not alias and not columns: return None @@ -1295,26 +1334,19 @@ class Parser(metaclass=_Parser): expression=self._parse_function() or self._parse_id_var(any_token=False), ) - columns = None - table_alias = None - if view or self.LATERAL_FUNCTION_AS_VIEW: - table_alias = self._parse_id_var(any_token=False) - if self._match(TokenType.ALIAS): - columns = self._parse_csv(self._parse_id_var) + if view: + table = self._parse_id_var(any_token=False) + columns = self._parse_csv(self._parse_id_var) if self._match(TokenType.ALIAS) else [] + table_alias = self.expression(exp.TableAlias, this=table, columns=columns) else: - self._match(TokenType.ALIAS) - table_alias = self._parse_id_var(any_token=False) - - if self._match(TokenType.L_PAREN): - columns = self._parse_csv(self._parse_id_var) - self._match_r_paren() + table_alias = self._parse_table_alias() expression = self.expression( exp.Lateral, this=this, view=view, outer=outer, - alias=self.expression(exp.TableAlias, this=table_alias, columns=columns), + alias=table_alias, ) if outer_apply or cross_apply: @@ -1693,6 +1725,9 @@ class Parser(metaclass=_Parser): if negate: this = self.expression(exp.Not, this=this) + if self._match(TokenType.IS): + this = self._parse_is(this) + return this def _parse_is(self, this): @@ -1796,6 +1831,10 @@ class Parser(metaclass=_Parser): return None type_token = self._prev.token_type + + if type_token == TokenType.PSEUDO_TYPE: + return self.expression(exp.PseudoType, this=self._prev.text) + nested = type_token in self.NESTED_TYPE_TOKENS is_struct = type_token == TokenType.STRUCT expressions = None @@ -1851,6 +1890,8 @@ class Parser(metaclass=_Parser): if value is None: value = exp.DataType(this=exp.DataType.Type.TIMESTAMP, expressions=expressions) + elif type_token == TokenType.INTERVAL: + value = self.expression(exp.Interval, unit=self._parse_var()) if maybe_func and check_func: index2 = self._index @@ -1924,7 +1965,16 @@ class Parser(metaclass=_Parser): def _parse_primary(self): if self._match_set(self.PRIMARY_PARSERS): - return self.PRIMARY_PARSERS[self._prev.token_type](self, self._prev) + token_type = self._prev.token_type + primary = self.PRIMARY_PARSERS[token_type](self, self._prev) + + if token_type == TokenType.STRING: + expressions = [primary] + while self._match(TokenType.STRING): + expressions.append(exp.Literal.string(self._prev.text)) + if len(expressions) > 1: + return self.expression(exp.Concat, expressions=expressions) + return primary if self._match_pair(TokenType.DOT, TokenType.NUMBER): return exp.Literal.number(f"0.{self._prev.text}") @@ -2027,6 +2077,9 @@ class Parser(metaclass=_Parser): return self.expression(exp.Identifier, this=token.text) + def _parse_national(self, token): + return self.expression(exp.National, this=exp.Literal.string(token.text)) + def _parse_session_parameter(self): kind = None this = self._parse_id_var() or self._parse_primary() @@ -2051,7 +2104,9 @@ class Parser(metaclass=_Parser): if self._match(TokenType.L_PAREN): expressions = self._parse_csv(self._parse_id_var) - self._match(TokenType.R_PAREN) + + if not self._match(TokenType.R_PAREN): + self._retreat(index) else: expressions = [self._parse_id_var()] @@ -2065,14 +2120,14 @@ class Parser(metaclass=_Parser): exp.Distinct, expressions=self._parse_csv(self._parse_conjunction) ) else: - this = self._parse_conjunction() + this = self._parse_select_or_expression() if self._match(TokenType.IGNORE_NULLS): this = self.expression(exp.IgnoreNulls, this=this) else: self._match(TokenType.RESPECT_NULLS) - return self._parse_alias(self._parse_limit(self._parse_order(this))) + return self._parse_limit(self._parse_order(this)) def _parse_schema(self, this=None): index = self._index @@ -2081,7 +2136,8 @@ class Parser(metaclass=_Parser): return this args = self._parse_csv( - lambda: self._parse_constraint() or self._parse_column_def(self._parse_field(True)) + lambda: self._parse_constraint() + or self._parse_column_def(self._parse_field(any_token=True)) ) self._match_r_paren() return self.expression(exp.Schema, this=this, expressions=args) @@ -2120,7 +2176,7 @@ class Parser(metaclass=_Parser): elif self._match(TokenType.ENCODE): kind = self.expression(exp.EncodeColumnConstraint, this=self._parse_var()) elif self._match(TokenType.DEFAULT): - kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_conjunction()) + kind = self.expression(exp.DefaultColumnConstraint, this=self._parse_bitwise()) elif self._match_pair(TokenType.NOT, TokenType.NULL): kind = exp.NotNullColumnConstraint() elif self._match(TokenType.NULL): @@ -2211,7 +2267,10 @@ class Parser(metaclass=_Parser): if not self._match(TokenType.L_BRACKET): return this - expressions = self._parse_csv(self._parse_conjunction) + if self._match(TokenType.COLON): + expressions = [self.expression(exp.Slice, expression=self._parse_conjunction())] + else: + expressions = self._parse_csv(lambda: self._parse_slice(self._parse_conjunction())) if not this or this.name.upper() == "ARRAY": this = self.expression(exp.Array, expressions=expressions) @@ -2225,6 +2284,11 @@ class Parser(metaclass=_Parser): this.comments = self._prev_comments return self._parse_bracket(this) + def _parse_slice(self, this): + if self._match(TokenType.COLON): + return self.expression(exp.Slice, this=this, expression=self._parse_conjunction()) + return this + def _parse_case(self): ifs = [] default = None @@ -2386,6 +2450,12 @@ class Parser(metaclass=_Parser): collation=collation, ) + def _parse_window_clause(self): + return self._match(TokenType.WINDOW) and self._parse_csv(self._parse_named_window) + + def _parse_named_window(self): + return self._parse_window(self._parse_id_var(), alias=True) + def _parse_window(self, this, alias=False): if self._match(TokenType.FILTER): where = self._parse_wrapped(self._parse_where) @@ -2501,11 +2571,9 @@ class Parser(metaclass=_Parser): if identifier: return identifier - if any_token and self._curr and self._curr.token_type not in self.RESERVED_KEYWORDS: - self._advance() - elif not self._match_set(tokens or self.ID_VAR_TOKENS): - return None - return exp.Identifier(this=self._prev.text, quoted=False) + if (any_token and self._advance_any()) or self._match_set(tokens or self.ID_VAR_TOKENS): + return exp.Identifier(this=self._prev.text, quoted=False) + return None def _parse_string(self): if self._match(TokenType.STRING): @@ -2522,11 +2590,17 @@ class Parser(metaclass=_Parser): return self.expression(exp.Identifier, this=self._prev.text, quoted=True) return self._parse_placeholder() - def _parse_var(self): - if self._match(TokenType.VAR): + def _parse_var(self, any_token=False): + if (any_token and self._advance_any()) or self._match(TokenType.VAR): return self.expression(exp.Var, this=self._prev.text) return self._parse_placeholder() + def _advance_any(self): + if self._curr and self._curr.token_type not in self.RESERVED_KEYWORDS: + self._advance() + return self._prev + return None + def _parse_var_or_string(self): return self._parse_var() or self._parse_string() @@ -2551,8 +2625,9 @@ class Parser(metaclass=_Parser): if self._match(TokenType.PLACEHOLDER): return self.expression(exp.Placeholder) elif self._match(TokenType.COLON): - self._advance() - return self.expression(exp.Placeholder, this=self._prev.text) + if self._match_set((TokenType.NUMBER, TokenType.VAR)): + return self.expression(exp.Placeholder, this=self._prev.text) + self._advance(-1) return None def _parse_except(self): @@ -2647,6 +2722,54 @@ class Parser(metaclass=_Parser): return self.expression(exp.Rollback, savepoint=savepoint) return self.expression(exp.Commit, chain=chain) + def _parse_add_column(self): + if not self._match_text_seq("ADD"): + return None + + self._match(TokenType.COLUMN) + exists_column = self._parse_exists(not_=True) + expression = self._parse_column_def(self._parse_field(any_token=True)) + expression.set("exists", exists_column) + return expression + + def _parse_drop_column(self): + return self._match(TokenType.DROP) and self._parse_drop(default_kind="COLUMN") + + def _parse_alter(self): + if not self._match(TokenType.TABLE): + return None + + exists = self._parse_exists() + this = self._parse_table(schema=True) + + actions = None + if self._match_text_seq("ADD", advance=False): + actions = self._parse_csv(self._parse_add_column) + elif self._match_text_seq("DROP", advance=False): + actions = self._parse_csv(self._parse_drop_column) + elif self._match_text_seq("ALTER"): + self._match(TokenType.COLUMN) + column = self._parse_field(any_token=True) + + if self._match_pair(TokenType.DROP, TokenType.DEFAULT): + actions = self.expression(exp.AlterColumn, this=column, drop=True) + elif self._match_pair(TokenType.SET, TokenType.DEFAULT): + actions = self.expression( + exp.AlterColumn, this=column, default=self._parse_conjunction() + ) + else: + self._match_text_seq("SET", "DATA") + actions = self.expression( + exp.AlterColumn, + this=column, + dtype=self._match_text_seq("TYPE") and self._parse_types(), + collate=self._match(TokenType.COLLATE) and self._parse_term(), + using=self._match(TokenType.USING) and self._parse_conjunction(), + ) + + actions = ensure_list(actions) + return self.expression(exp.AlterTable, this=this, exists=exists, actions=actions) + def _parse_show(self): parser = self._find_parser(self.SHOW_PARSERS, self._show_trie) if parser: @@ -2782,7 +2905,7 @@ class Parser(metaclass=_Parser): return True return False - def _match_text_seq(self, *texts): + def _match_text_seq(self, *texts, advance=True): index = self._index for text in texts: if self._curr and self._curr.text.upper() == text: @@ -2790,6 +2913,10 @@ class Parser(metaclass=_Parser): else: self._retreat(index) return False + + if not advance: + self._retreat(index) + return True def _replace_columns_with_dots(self, this): diff --git a/sqlglot/schema.py b/sqlglot/schema.py index c223ee0..d9a4004 100644 --- a/sqlglot/schema.py +++ b/sqlglot/schema.py @@ -160,9 +160,7 @@ class MappingSchema(AbstractMappingSchema[t.Dict[str, str]], Schema): super().__init__(schema) self.visible = visible or {} self.dialect = dialect - self._type_mapping_cache: t.Dict[str, exp.DataType] = { - "STR": exp.DataType.build("text"), - } + self._type_mapping_cache: t.Dict[str, exp.DataType] = {} @classmethod def from_mapping_schema(cls, mapping_schema: MappingSchema) -> MappingSchema: diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index b25ef8d..0efa7d0 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -48,6 +48,7 @@ class TokenType(AutoName): DOLLAR = auto() PARAMETER = auto() SESSION_PARAMETER = auto() + NATIONAL = auto() BLOCK_START = auto() BLOCK_END = auto() @@ -111,6 +112,7 @@ class TokenType(AutoName): # keywords ALIAS = auto() + ALTER = auto() ALWAYS = auto() ALL = auto() ANTI = auto() @@ -196,6 +198,7 @@ class TokenType(AutoName): INTERVAL = auto() INTO = auto() INTRODUCER = auto() + IRLIKE = auto() IS = auto() ISNULL = auto() JOIN = auto() @@ -241,6 +244,7 @@ class TokenType(AutoName): PRIMARY_KEY = auto() PROCEDURE = auto() PROPERTIES = auto() + PSEUDO_TYPE = auto() QUALIFY = auto() QUOTE = auto() RANGE = auto() @@ -346,7 +350,11 @@ class _Tokenizer(type): def __new__(cls, clsname, bases, attrs): # type: ignore klass = super().__new__(cls, clsname, bases, attrs) - klass._QUOTES = cls._delimeter_list_to_dict(klass.QUOTES) + klass._QUOTES = { + f"{prefix}{s}": e + for s, e in cls._delimeter_list_to_dict(klass.QUOTES).items() + for prefix in (("",) if s[0].isalpha() else ("", "n", "N")) + } klass._BIT_STRINGS = cls._delimeter_list_to_dict(klass.BIT_STRINGS) klass._HEX_STRINGS = cls._delimeter_list_to_dict(klass.HEX_STRINGS) klass._BYTE_STRINGS = cls._delimeter_list_to_dict(klass.BYTE_STRINGS) @@ -470,6 +478,7 @@ class Tokenizer(metaclass=_Tokenizer): "CHECK": TokenType.CHECK, "CLUSTER BY": TokenType.CLUSTER_BY, "COLLATE": TokenType.COLLATE, + "COLUMN": TokenType.COLUMN, "COMMENT": TokenType.SCHEMA_COMMENT, "COMMIT": TokenType.COMMIT, "COMPOUND": TokenType.COMPOUND, @@ -587,6 +596,7 @@ class Tokenizer(metaclass=_Tokenizer): "SEMI": TokenType.SEMI, "SET": TokenType.SET, "SHOW": TokenType.SHOW, + "SIMILAR TO": TokenType.SIMILAR_TO, "SOME": TokenType.SOME, "SORTKEY": TokenType.SORTKEY, "SORT BY": TokenType.SORT_BY, @@ -614,6 +624,7 @@ class Tokenizer(metaclass=_Tokenizer): "VOLATILE": TokenType.VOLATILE, "WHEN": TokenType.WHEN, "WHERE": TokenType.WHERE, + "WINDOW": TokenType.WINDOW, "WITH": TokenType.WITH, "WITH TIME ZONE": TokenType.WITH_TIME_ZONE, "WITH LOCAL TIME ZONE": TokenType.WITH_LOCAL_TIME_ZONE, @@ -652,6 +663,7 @@ class Tokenizer(metaclass=_Tokenizer): "VARCHAR2": TokenType.VARCHAR, "NVARCHAR": TokenType.NVARCHAR, "NVARCHAR2": TokenType.NVARCHAR, + "STR": TokenType.TEXT, "STRING": TokenType.TEXT, "TEXT": TokenType.TEXT, "CLOB": TokenType.TEXT, @@ -667,7 +679,16 @@ class Tokenizer(metaclass=_Tokenizer): "UNIQUE": TokenType.UNIQUE, "STRUCT": TokenType.STRUCT, "VARIANT": TokenType.VARIANT, - "ALTER": TokenType.COMMAND, + "ALTER": TokenType.ALTER, + "ALTER AGGREGATE": TokenType.COMMAND, + "ALTER DEFAULT": TokenType.COMMAND, + "ALTER DOMAIN": TokenType.COMMAND, + "ALTER ROLE": TokenType.COMMAND, + "ALTER RULE": TokenType.COMMAND, + "ALTER SEQUENCE": TokenType.COMMAND, + "ALTER TYPE": TokenType.COMMAND, + "ALTER USER": TokenType.COMMAND, + "ALTER VIEW": TokenType.COMMAND, "ANALYZE": TokenType.COMMAND, "CALL": TokenType.COMMAND, "EXPLAIN": TokenType.COMMAND, @@ -967,7 +988,7 @@ class Tokenizer(metaclass=_Tokenizer): text = self._extract_string(quote_end) text = text.encode(self.ENCODE).decode(self.ENCODE) if self.ENCODE else text # type: ignore text = text.replace("\\\\", "\\") if self._replace_backslash else text - self._add(TokenType.STRING, text) + self._add(TokenType.NATIONAL if quote[0].upper() == "N" else TokenType.STRING, text) return True # X'1234, b'0110', E'\\\\\' etc. diff --git a/tests/dataframe/unit/test_column.py b/tests/dataframe/unit/test_column.py index da18502..665cc91 100644 --- a/tests/dataframe/unit/test_column.py +++ b/tests/dataframe/unit/test_column.py @@ -150,8 +150,8 @@ class TestDataframeColumn(unittest.TestCase): F.col("cola").between(datetime.date(2022, 1, 1), datetime.date(2022, 3, 1)).sql(), ) self.assertEqual( - "cola BETWEEN CAST('2022-01-01 01:01:01.000000' AS TIMESTAMP) " - "AND CAST('2022-03-01 01:01:01.000000' AS TIMESTAMP)", + "cola BETWEEN CAST('2022-01-01T01:01:01+00:00' AS TIMESTAMP) " + "AND CAST('2022-03-01T01:01:01+00:00' AS TIMESTAMP)", F.col("cola") .between(datetime.datetime(2022, 1, 1, 1, 1, 1), datetime.datetime(2022, 3, 1, 1, 1, 1)) .sql(), diff --git a/tests/dataframe/unit/test_functions.py b/tests/dataframe/unit/test_functions.py index 99b140d..37ea2e1 100644 --- a/tests/dataframe/unit/test_functions.py +++ b/tests/dataframe/unit/test_functions.py @@ -30,7 +30,7 @@ class TestFunctions(unittest.TestCase): test_date = SF.lit(datetime.date(2022, 1, 1)) self.assertEqual("TO_DATE('2022-01-01')", test_date.sql()) test_datetime = SF.lit(datetime.datetime(2022, 1, 1, 1, 1, 1)) - self.assertEqual("CAST('2022-01-01 01:01:01.000000' AS TIMESTAMP)", test_datetime.sql()) + self.assertEqual("CAST('2022-01-01T01:01:01+00:00' AS TIMESTAMP)", test_datetime.sql()) test_dict = SF.lit({"cola": 1, "colb": "test"}) self.assertEqual("STRUCT(1 AS cola, 'test' AS colb)", test_dict.sql()) @@ -52,7 +52,7 @@ class TestFunctions(unittest.TestCase): test_date = SF.col(datetime.date(2022, 1, 1)) self.assertEqual("TO_DATE('2022-01-01')", test_date.sql()) test_datetime = SF.col(datetime.datetime(2022, 1, 1, 1, 1, 1)) - self.assertEqual("CAST('2022-01-01 01:01:01.000000' AS TIMESTAMP)", test_datetime.sql()) + self.assertEqual("CAST('2022-01-01T01:01:01+00:00' AS TIMESTAMP)", test_datetime.sql()) test_dict = SF.col({"cola": 1, "colb": "test"}) self.assertEqual("STRUCT(1 AS cola, 'test' AS colb)", test_dict.sql()) diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 1d60ec6..258e47f 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -318,3 +318,9 @@ class TestBigQuery(Validator): self.validate_identity( "CREATE TABLE FUNCTION a(x INT64) RETURNS TABLE <q STRING, r INT64> AS SELECT s, t" ) + + def test_group_concat(self): + self.validate_all( + "SELECT a, GROUP_CONCAT(b) FROM table GROUP BY a", + write={"bigquery": "SELECT a, STRING_AGG(b) FROM table GROUP BY a"}, + ) diff --git a/tests/dialects/test_databricks.py b/tests/dialects/test_databricks.py index 2168f55..7560d61 100644 --- a/tests/dialects/test_databricks.py +++ b/tests/dialects/test_databricks.py @@ -12,6 +12,76 @@ class TestDatabricks(Validator): "databricks": "SELECT DATEDIFF(year, 'start', 'end')", }, ) + self.validate_all( + "SELECT DATEDIFF(microsecond, 'start', 'end')", + write={ + "databricks": "SELECT DATEDIFF(microsecond, 'start', 'end')", + "postgres": "SELECT CAST(EXTRACT(epoch FROM CAST('end' AS TIMESTAMP) - CAST('start' AS TIMESTAMP)) * 1000000 AS BIGINT)", + }, + ) + self.validate_all( + "SELECT DATEDIFF(millisecond, 'start', 'end')", + write={ + "databricks": "SELECT DATEDIFF(millisecond, 'start', 'end')", + "postgres": "SELECT CAST(EXTRACT(epoch FROM CAST('end' AS TIMESTAMP) - CAST('start' AS TIMESTAMP)) * 1000 AS BIGINT)", + }, + ) + self.validate_all( + "SELECT DATEDIFF(second, 'start', 'end')", + write={ + "databricks": "SELECT DATEDIFF(second, 'start', 'end')", + "postgres": "SELECT CAST(EXTRACT(epoch FROM CAST('end' AS TIMESTAMP) - CAST('start' AS TIMESTAMP)) AS BIGINT)", + }, + ) + self.validate_all( + "SELECT DATEDIFF(minute, 'start', 'end')", + write={ + "databricks": "SELECT DATEDIFF(minute, 'start', 'end')", + "postgres": "SELECT CAST(EXTRACT(epoch FROM CAST('end' AS TIMESTAMP) - CAST('start' AS TIMESTAMP)) / 60 AS BIGINT)", + }, + ) + self.validate_all( + "SELECT DATEDIFF(hour, 'start', 'end')", + write={ + "databricks": "SELECT DATEDIFF(hour, 'start', 'end')", + "postgres": "SELECT CAST(EXTRACT(epoch FROM CAST('end' AS TIMESTAMP) - CAST('start' AS TIMESTAMP)) / 3600 AS BIGINT)", + }, + ) + self.validate_all( + "SELECT DATEDIFF(day, 'start', 'end')", + write={ + "databricks": "SELECT DATEDIFF(day, 'start', 'end')", + "postgres": "SELECT CAST(EXTRACT(epoch FROM CAST('end' AS TIMESTAMP) - CAST('start' AS TIMESTAMP)) / 86400 AS BIGINT)", + }, + ) + self.validate_all( + "SELECT DATEDIFF(week, 'start', 'end')", + write={ + "databricks": "SELECT DATEDIFF(week, 'start', 'end')", + "postgres": "SELECT CAST(EXTRACT(year FROM AGE(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP))) * 48 + EXTRACT(month FROM AGE(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP))) * 4 + EXTRACT(day FROM AGE(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP))) / 7 AS BIGINT)", + }, + ) + self.validate_all( + "SELECT DATEDIFF(month, 'start', 'end')", + write={ + "databricks": "SELECT DATEDIFF(month, 'start', 'end')", + "postgres": "SELECT CAST(EXTRACT(year FROM AGE(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP))) * 12 + EXTRACT(month FROM AGE(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP))) AS BIGINT)", + }, + ) + self.validate_all( + "SELECT DATEDIFF(quarter, 'start', 'end')", + write={ + "databricks": "SELECT DATEDIFF(quarter, 'start', 'end')", + "postgres": "SELECT CAST(EXTRACT(year FROM AGE(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP))) * 4 + EXTRACT(month FROM AGE(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP))) / 3 AS BIGINT)", + }, + ) + self.validate_all( + "SELECT DATEDIFF(year, 'start', 'end')", + write={ + "databricks": "SELECT DATEDIFF(year, 'start', 'end')", + "postgres": "SELECT CAST(EXTRACT(year FROM AGE(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP))) AS BIGINT)", + }, + ) def test_add_date(self): self.validate_all( diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index ee67bf1..ced7102 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -333,7 +333,7 @@ class TestDialect(Validator): "drill": "CAST('2020-01-01' AS DATE)", "duckdb": "CAST('2020-01-01' AS DATE)", "hive": "TO_DATE('2020-01-01')", - "presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d %H:%i:%s')", + "presto": "CAST('2020-01-01' AS TIMESTAMP)", "starrocks": "TO_DATE('2020-01-01')", }, ) @@ -343,7 +343,7 @@ class TestDialect(Validator): "drill": "CAST('2020-01-01' AS TIMESTAMP)", "duckdb": "CAST('2020-01-01' AS TIMESTAMP)", "hive": "CAST('2020-01-01' AS TIMESTAMP)", - "presto": "DATE_PARSE('2020-01-01', '%Y-%m-%d %H:%i:%s')", + "presto": "CAST('2020-01-01' AS TIMESTAMP)", }, ) self.validate_all( @@ -723,23 +723,23 @@ class TestDialect(Validator): read={ "postgres": "x->'y'", "presto": "JSON_EXTRACT(x, 'y')", - "starrocks": "x->'y'", + "starrocks": "x -> 'y'", }, write={ "oracle": "JSON_EXTRACT(x, 'y')", - "postgres": "x->'y'", + "postgres": "x -> 'y'", "presto": "JSON_EXTRACT(x, 'y')", - "starrocks": "x->'y'", + "starrocks": "x -> 'y'", }, ) self.validate_all( "JSON_EXTRACT_SCALAR(x, 'y')", read={ - "postgres": "x->>'y'", + "postgres": "x ->> 'y'", "presto": "JSON_EXTRACT_SCALAR(x, 'y')", }, write={ - "postgres": "x->>'y'", + "postgres": "x ->> 'y'", "presto": "JSON_EXTRACT_SCALAR(x, 'y')", }, ) @@ -749,7 +749,7 @@ class TestDialect(Validator): "postgres": "x#>'y'", }, write={ - "postgres": "x#>'y'", + "postgres": "x #> 'y'", }, ) self.validate_all( @@ -758,7 +758,7 @@ class TestDialect(Validator): "postgres": "x#>>'y'", }, write={ - "postgres": "x#>>'y'", + "postgres": "x #>> 'y'", }, ) diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index 99b0493..a37062c 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -59,7 +59,7 @@ class TestDuckDB(Validator): "TO_TIMESTAMP(x)", write={ "duckdb": "CAST(x AS TIMESTAMP)", - "presto": "DATE_PARSE(x, '%Y-%m-%d %H:%i:%s')", + "presto": "CAST(x AS TIMESTAMP)", "hive": "CAST(x AS TIMESTAMP)", }, ) @@ -302,3 +302,20 @@ class TestDuckDB(Validator): read="duckdb", unsupported_level=ErrorLevel.IMMEDIATE, ) + + def test_array(self): + self.validate_identity("ARRAY(SELECT id FROM t)") + + def test_cast(self): + self.validate_all( + "123::CHARACTER VARYING", + write={ + "duckdb": "CAST(123 AS TEXT)", + }, + ) + + def test_bool_or(self): + self.validate_all( + "SELECT a, LOGICAL_OR(b) FROM table GROUP BY a", + write={"duckdb": "SELECT a, BOOL_OR(b) FROM table GROUP BY a"}, + ) diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py index 5ac8714..a7f3b8f 100644 --- a/tests/dialects/test_hive.py +++ b/tests/dialects/test_hive.py @@ -268,10 +268,10 @@ class TestHive(Validator): self.validate_all( "DATE_FORMAT('2020-01-01', 'yyyy-MM-dd HH:mm:ss')", write={ - "duckdb": "STRFTIME('2020-01-01', '%Y-%m-%d %H:%M:%S')", - "presto": "DATE_FORMAT('2020-01-01', '%Y-%m-%d %H:%i:%S')", - "hive": "DATE_FORMAT('2020-01-01', 'yyyy-MM-dd HH:mm:ss')", - "spark": "DATE_FORMAT('2020-01-01', 'yyyy-MM-dd HH:mm:ss')", + "duckdb": "STRFTIME(CAST('2020-01-01' AS TIMESTAMP), '%Y-%m-%d %H:%M:%S')", + "presto": "DATE_FORMAT(CAST('2020-01-01' AS TIMESTAMP), '%Y-%m-%d %H:%i:%S')", + "hive": "DATE_FORMAT(CAST('2020-01-01' AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss')", + "spark": "DATE_FORMAT(CAST('2020-01-01' AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss')", }, ) self.validate_all( diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index 5064dbe..7cd686d 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -91,12 +91,12 @@ class TestMySQL(Validator): }, ) self.validate_all( - "N 'some text'", + "N'some text'", read={ - "mysql": "N'some text'", + "mysql": "n'some text'", }, write={ - "mysql": "N 'some text'", + "mysql": "N'some text'", }, ) self.validate_all( diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index 962b28b..1e048d5 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -3,6 +3,7 @@ from tests.dialects.test_dialect import Validator class TestPostgres(Validator): + maxDiff = None dialect = "postgres" def test_ddl(self): @@ -94,6 +95,7 @@ class TestPostgres(Validator): self.validate_identity("COMMENT ON TABLE mytable IS 'this'") self.validate_identity("SELECT e'\\xDEADBEEF'") self.validate_identity("SELECT CAST(e'\\176' AS BYTEA)") + self.validate_identity("""SELECT * FROM JSON_TO_RECORDSET(z) AS y("rank" INT)""") self.validate_all( "END WORK AND NO CHAIN", @@ -112,6 +114,14 @@ class TestPostgres(Validator): "spark": "CREATE TABLE x (a UUID, b BINARY)", }, ) + self.validate_all( + "123::CHARACTER VARYING", + write={"postgres": "CAST(123 AS VARCHAR)"}, + ) + self.validate_all( + "TO_TIMESTAMP(123::DOUBLE PRECISION)", + write={"postgres": "TO_TIMESTAMP(CAST(123 AS DOUBLE PRECISION))"}, + ) self.validate_identity( "CREATE TABLE A (LIKE B INCLUDING CONSTRAINT INCLUDING COMPRESSION EXCLUDING COMMENTS)" @@ -193,15 +203,21 @@ class TestPostgres(Validator): }, ) self.validate_all( - "SELECT m.name FROM manufacturers AS m LEFT JOIN LATERAL GET_PRODUCT_NAMES(m.id) AS pname ON TRUE WHERE pname IS NULL", - read={ + "SELECT m.name FROM manufacturers AS m LEFT JOIN LATERAL GET_PRODUCT_NAMES(m.id) pname ON TRUE WHERE pname IS NULL", + write={ "postgres": "SELECT m.name FROM manufacturers AS m LEFT JOIN LATERAL GET_PRODUCT_NAMES(m.id) AS pname ON TRUE WHERE pname IS NULL", }, ) self.validate_all( "SELECT p1.id, p2.id, v1, v2 FROM polygons AS p1, polygons AS p2, LATERAL VERTICES(p1.poly) v1, LATERAL VERTICES(p2.poly) v2 WHERE (v1 <-> v2) < 10 AND p1.id <> p2.id", - read={ - "postgres": "SELECT p1.id, p2.id, v1, v2 FROM polygons p1, polygons p2, LATERAL VERTICES(p1.poly) v1, LATERAL VERTICES(p2.poly) v2 WHERE (v1 <-> v2) < 10 AND p1.id != p2.id", + write={ + "postgres": "SELECT p1.id, p2.id, v1, v2 FROM polygons AS p1, polygons AS p2, LATERAL VERTICES(p1.poly) AS v1, LATERAL VERTICES(p2.poly) AS v2 WHERE (v1 <-> v2) < 10 AND p1.id <> p2.id", + }, + ) + self.validate_all( + "SELECT * FROM r CROSS JOIN LATERAL unnest(array(1)) AS s(location)", + write={ + "postgres": "SELECT * FROM r CROSS JOIN LATERAL UNNEST(ARRAY[1]) AS s(location)", }, ) self.validate_all( @@ -218,35 +234,46 @@ class TestPostgres(Validator): ) self.validate_all( "'[1,2,3]'::json->2", - write={"postgres": "CAST('[1,2,3]' AS JSON)->'2'"}, + write={"postgres": "CAST('[1,2,3]' AS JSON) -> '2'"}, ) self.validate_all( """'{"a":1,"b":2}'::json->'b'""", - write={"postgres": """CAST('{"a":1,"b":2}' AS JSON)->'b'"""}, + write={"postgres": """CAST('{"a":1,"b":2}' AS JSON) -> 'b'"""}, ) self.validate_all( """'{"x": {"y": 1}}'::json->'x'->'y'""", - write={"postgres": """CAST('{"x": {"y": 1}}' AS JSON)->'x'->'y'"""}, + write={"postgres": """CAST('{"x": {"y": 1}}' AS JSON) -> 'x' -> 'y'"""}, ) self.validate_all( """'{"x": {"y": 1}}'::json->'x'::json->'y'""", - write={"postgres": """CAST(CAST('{"x": {"y": 1}}' AS JSON)->'x' AS JSON)->'y'"""}, + write={"postgres": """CAST(CAST('{"x": {"y": 1}}' AS JSON) -> 'x' AS JSON) -> 'y'"""}, ) self.validate_all( """'[1,2,3]'::json->>2""", - write={"postgres": "CAST('[1,2,3]' AS JSON)->>'2'"}, + write={"postgres": "CAST('[1,2,3]' AS JSON) ->> '2'"}, ) self.validate_all( """'{"a":1,"b":2}'::json->>'b'""", - write={"postgres": """CAST('{"a":1,"b":2}' AS JSON)->>'b'"""}, + write={"postgres": """CAST('{"a":1,"b":2}' AS JSON) ->> 'b'"""}, ) self.validate_all( """'{"a":[1,2,3],"b":[4,5,6]}'::json#>'{a,2}'""", - write={"postgres": """CAST('{"a":[1,2,3],"b":[4,5,6]}' AS JSON)#>'{a,2}'"""}, + write={"postgres": """CAST('{"a":[1,2,3],"b":[4,5,6]}' AS JSON) #> '{a,2}'"""}, ) self.validate_all( """'{"a":[1,2,3],"b":[4,5,6]}'::json#>>'{a,2}'""", - write={"postgres": """CAST('{"a":[1,2,3],"b":[4,5,6]}' AS JSON)#>>'{a,2}'"""}, + write={"postgres": """CAST('{"a":[1,2,3],"b":[4,5,6]}' AS JSON) #>> '{a,2}'"""}, + ) + self.validate_all( + """SELECT JSON_ARRAY_ELEMENTS((foo->'sections')::JSON) AS sections""", + write={ + "postgres": """SELECT JSON_ARRAY_ELEMENTS(CAST((foo -> 'sections') AS JSON)) AS sections""", + "presto": """SELECT JSON_ARRAY_ELEMENTS(CAST((JSON_EXTRACT(foo, 'sections')) AS JSON)) AS sections""", + }, + ) + self.validate_all( + """x ? 'x'""", + write={"postgres": "x ? 'x'"}, ) self.validate_all( "SELECT $$a$$", @@ -260,3 +287,49 @@ class TestPostgres(Validator): "UPDATE MYTABLE T1 SET T1.COL = 13", write={"postgres": "UPDATE MYTABLE AS T1 SET T1.COL = 13"}, ) + + self.validate_identity("x ~ 'y'") + self.validate_identity("x ~* 'y'") + self.validate_all( + "x !~ 'y'", + write={"postgres": "NOT x ~ 'y'"}, + ) + self.validate_all( + "x !~* 'y'", + write={"postgres": "NOT x ~* 'y'"}, + ) + + self.validate_all( + "x ~~ 'y'", + write={"postgres": "x LIKE 'y'"}, + ) + self.validate_all( + "x ~~* 'y'", + write={"postgres": "x ILIKE 'y'"}, + ) + self.validate_all( + "x !~~ 'y'", + write={"postgres": "NOT x LIKE 'y'"}, + ) + self.validate_all( + "x !~~* 'y'", + write={"postgres": "NOT x ILIKE 'y'"}, + ) + self.validate_all( + "'45 days'::interval day", + write={"postgres": "CAST('45 days' AS INTERVAL day)"}, + ) + self.validate_all( + "'x' 'y' 'z'", + write={"postgres": "CONCAT('x', 'y', 'z')"}, + ) + self.validate_identity("SELECT ARRAY(SELECT 1)") + + self.validate_all( + "x::cstring", + write={"postgres": "CAST(x AS CSTRING)"}, + ) + + self.validate_identity( + "SELECT SUM(x) OVER a, SUM(y) OVER b FROM c WINDOW a AS (PARTITION BY d), b AS (PARTITION BY e)" + ) diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index 3034df5..f650c98 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -53,7 +53,7 @@ class TestRedshift(Validator): self.validate_all( "SELECT DISTINCT ON (a) a, b FROM x ORDER BY c DESC", write={ - "redshift": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS "_row_number" FROM x) WHERE "_row_number" = 1', + "redshift": 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE "_row_number" = 1', }, ) self.validate_all( diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index e3d0cff..df62c6c 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -7,6 +7,12 @@ class TestSnowflake(Validator): def test_snowflake(self): self.validate_all( + "SELECT * FROM xxx WHERE col ilike '%Don''t%'", + write={ + "snowflake": "SELECT * FROM xxx WHERE col ILIKE '%Don\\'t%'", + }, + ) + self.validate_all( 'x:a:"b c"', write={ "duckdb": "x['a']['b c']", @@ -509,3 +515,11 @@ FROM persons AS p, LATERAL FLATTEN(input => p.c, path => 'contact') AS f, LATERA "snowflake": "SELECT 1 MINUS SELECT 1", }, ) + + def test_values(self): + self.validate_all( + 'SELECT c0, c1 FROM (VALUES (1, 2), (3, 4)) AS "t0"(c0, c1)', + read={ + "spark": "SELECT `c0`, `c1` FROM (VALUES (1, 2), (3, 4)) AS `t0`(`c0`, `c1`)", + }, + ) diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py index 3a9f918..7395e72 100644 --- a/tests/dialects/test_spark.py +++ b/tests/dialects/test_spark.py @@ -101,6 +101,18 @@ TBLPROPERTIES ( "spark": "CACHE TABLE testCache OPTIONS('storageLevel' = 'DISK_ONLY') AS SELECT * FROM testData" }, ) + self.validate_all( + "ALTER TABLE StudentInfo ADD COLUMNS (LastName STRING, DOB TIMESTAMP)", + write={ + "spark": "ALTER TABLE StudentInfo ADD COLUMNS (LastName STRING, DOB TIMESTAMP)", + }, + ) + self.validate_all( + "ALTER TABLE StudentInfo DROP COLUMNS (LastName, DOB)", + write={ + "spark": "ALTER TABLE StudentInfo DROP COLUMNS (LastName, DOB)", + }, + ) def test_to_date(self): self.validate_all( diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index e4c6e60..b4ac094 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -431,11 +431,11 @@ class TestTSQL(Validator): def test_string(self): self.validate_all( "SELECT N'test'", - write={"spark": "SELECT 'test'"}, + write={"spark": "SELECT N'test'"}, ) self.validate_all( "SELECT n'test'", - write={"spark": "SELECT 'test'"}, + write={"spark": "SELECT N'test'"}, ) self.validate_all( "SELECT '''test'''", diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql index e12b673..e6a6e6b 100644 --- a/tests/fixtures/identity.sql +++ b/tests/fixtures/identity.sql @@ -17,6 +17,7 @@ SUM(CASE WHEN x > 1 THEN 1 ELSE 0 END) / y '\x' "x" "" +N'abc' x x % 1 x < 1 @@ -33,6 +34,10 @@ x << 1 x >> 1 x >> 1 | 1 & 1 ^ 1 x || y +x[ : ] +x[1 : ] +x[1 : 2] +x[-4 : -1] 1 - -1 - -5 dec.x + y @@ -62,6 +67,8 @@ x BETWEEN 'a' || b AND 'c' || d NOT x IS NULL x IS TRUE x IS FALSE +x IS TRUE IS TRUE +x LIKE y IS TRUE time zone ARRAY<TEXT> @@ -93,10 +100,11 @@ x LIKE '%y%' ESCAPE '\' x ILIKE '%y%' ESCAPE '\' 1 AS escape INTERVAL '1' day -INTERVAL '1' month +INTERVAL '1' MONTH INTERVAL '1 day' INTERVAL 2 months -INTERVAL 1 + 3 days +INTERVAL 1 + 3 DAYS +CAST('45' AS INTERVAL DAYS) TIMESTAMP_DIFF(CURRENT_TIMESTAMP(), 1, DAY) DATETIME_DIFF(CURRENT_DATE, 1, DAY) QUANTILE(x, 0.5) @@ -144,6 +152,7 @@ SELECT 1 AS count FROM test SELECT 1 AS comment FROM test SELECT 1 AS numeric FROM test SELECT 1 AS number FROM test +SELECT COALESCE(offset, 1) SELECT t.count SELECT DISTINCT x FROM test SELECT DISTINCT x, y FROM test @@ -196,6 +205,7 @@ SELECT JSON_EXTRACT_SCALAR(x, '$.name') SELECT x LIKE '%x%' FROM test SELECT * FROM test LIMIT 100 SELECT * FROM test LIMIT 100 OFFSET 200 +SELECT * FROM test FETCH FIRST ROWS ONLY SELECT * FROM test FETCH FIRST 1 ROWS ONLY SELECT * FROM test FETCH NEXT 1 ROWS ONLY SELECT (1 > 2) AS x FROM test @@ -460,6 +470,7 @@ CREATE TABLE z (end INT) CREATE TABLE z (a ARRAY<TEXT>, b MAP<TEXT, DOUBLE>, c DECIMAL(5, 3)) CREATE TABLE z (a INT, b VARCHAR COMMENT 'z', c VARCHAR(100) COMMENT 'z', d DECIMAL(5, 3)) CREATE TABLE z (a INT(11) DEFAULT UUID()) +CREATE TABLE z (n INT DEFAULT 0 NOT NULL) CREATE TABLE z (a INT(11) DEFAULT NULL COMMENT '客户id') CREATE TABLE z (a INT(11) NOT NULL DEFAULT 1) CREATE TABLE z (a INT(11) NOT NULL DEFAULT -1) @@ -511,7 +522,13 @@ INSERT OVERWRITE TABLE a.b PARTITION(ds) SELECT x FROM y INSERT OVERWRITE TABLE a.b PARTITION(ds='YYYY-MM-DD') SELECT x FROM y INSERT OVERWRITE TABLE a.b PARTITION(ds, hour) SELECT x FROM y INSERT OVERWRITE TABLE a.b PARTITION(ds='YYYY-MM-DD', hour='hh') SELECT x FROM y +ALTER AGGREGATE bla(foo) OWNER TO CURRENT_USER +ALTER RULE foo ON bla RENAME TO baz +ALTER ROLE CURRENT_USER WITH REPLICATION +ALTER SEQUENCE IF EXISTS baz RESTART WITH boo ALTER TYPE electronic_mail RENAME TO email +ALTER VIEW foo ALTER COLUMN bla SET DEFAULT 'NOT SET' +ALTER DOMAIN foo VALIDATE CONSTRAINT bla ANALYZE a.y DELETE FROM x WHERE y > 1 DELETE FROM y @@ -596,3 +613,17 @@ SELECT x AS INTO FROM bla SELECT * INTO newevent FROM event SELECT * INTO TEMPORARY newevent FROM event SELECT * INTO UNLOGGED newevent FROM event +ALTER TABLE integers ADD COLUMN k INT +ALTER TABLE integers ADD COLUMN IF NOT EXISTS k INT +ALTER TABLE IF EXISTS integers ADD COLUMN k INT +ALTER TABLE integers ADD COLUMN l INT DEFAULT 10 +ALTER TABLE measurements ADD COLUMN mtime TIMESTAMPTZ DEFAULT NOW() +ALTER TABLE integers DROP COLUMN k +ALTER TABLE integers DROP COLUMN IF EXISTS k +ALTER TABLE integers DROP COLUMN k CASCADE +ALTER TABLE integers ALTER COLUMN i TYPE VARCHAR +ALTER TABLE integers ALTER COLUMN i TYPE VARCHAR USING CONCAT(i, '_', j) +ALTER TABLE integers ALTER COLUMN i SET DEFAULT 10 +ALTER TABLE integers ALTER COLUMN i DROP DEFAULT +ALTER TABLE mydataset.mytable DROP COLUMN A, DROP COLUMN IF EXISTS B +ALTER TABLE mydataset.mytable ADD COLUMN A TEXT, ADD COLUMN IF NOT EXISTS B INT diff --git a/tests/fixtures/optimizer/canonicalize.sql b/tests/fixtures/optimizer/canonicalize.sql index 8880881..8c7cd45 100644 --- a/tests/fixtures/optimizer/canonicalize.sql +++ b/tests/fixtures/optimizer/canonicalize.sql @@ -1,11 +1,11 @@ SELECT w.d + w.e AS c FROM w AS w; -SELECT CONCAT(w.d, w.e) AS c FROM w AS w; +SELECT CONCAT("w"."d", "w"."e") AS "c" FROM "w" AS "w"; SELECT CAST(w.d AS DATE) > w.e AS a FROM w AS w; -SELECT CAST(w.d AS DATE) > CAST(w.e AS DATE) AS a FROM w AS w; +SELECT CAST("w"."d" AS DATE) > CAST("w"."e" AS DATE) AS "a" FROM "w" AS "w"; SELECT CAST(1 AS VARCHAR) AS a FROM w AS w; -SELECT CAST(1 AS VARCHAR) AS a FROM w AS w; +SELECT CAST(1 AS VARCHAR) AS "a" FROM "w" AS "w"; SELECT CAST(1 + 3.2 AS DOUBLE) AS a FROM w AS w; -SELECT 1 + 3.2 AS a FROM w AS w; +SELECT 1 + 3.2 AS "a" FROM "w" AS "w"; diff --git a/tests/fixtures/optimizer/optimizer.sql b/tests/fixtures/optimizer/optimizer.sql index a692c7d..b502d81 100644 --- a/tests/fixtures/optimizer/optimizer.sql +++ b/tests/fixtures/optimizer/optimizer.sql @@ -291,3 +291,81 @@ SELECT a1 FROM cte1; SELECT "x"."a" AS "a1" FROM "x" AS "x"; + +# title: recursive cte +WITH RECURSIVE cte1 AS ( + SELECT * + FROM ( + SELECT 1 AS a, 2 AS b + ) base + CROSS JOIN (SELECT 3 c) y + UNION ALL + SELECT * + FROM cte1 + WHERE a < 1 +) +SELECT * +FROM cte1; +WITH RECURSIVE "base" AS ( + SELECT + 1 AS "a", + 2 AS "b" +), "y" AS ( + SELECT + 3 AS "c" +), "cte1" AS ( + SELECT + "base"."a" AS "a", + "base"."b" AS "b", + "y"."c" AS "c" + FROM "base" AS "base" + CROSS JOIN "y" AS "y" + UNION ALL + SELECT + "cte1"."a" AS "a", + "cte1"."b" AS "b", + "cte1"."c" AS "c" + FROM "cte1" + WHERE + "cte1"."a" < 1 +) +SELECT + "cte1"."a" AS "a", + "cte1"."b" AS "b", + "cte1"."c" AS "c" +FROM "cte1"; + +# title: right join should not push down to from +SELECT x.a, y.b +FROM x +RIGHT JOIN y +ON x.a = y.b +WHERE x.b = 1; +SELECT + "x"."a" AS "a", + "y"."b" AS "b" +FROM "x" AS "x" +RIGHT JOIN "y" AS "y" + ON "x"."a" = "y"."b" +WHERE + "x"."b" = 1; + +# title: right join can push down to itself +SELECT x.a, y.b +FROM x +RIGHT JOIN y +ON x.a = y.b +WHERE y.b = 1; +WITH "y_2" AS ( + SELECT + "y"."b" AS "b" + FROM "y" AS "y" + WHERE + "y"."b" = 1 +) +SELECT + "x"."a" AS "a", + "y"."b" AS "b" +FROM "x" AS "x" +RIGHT JOIN "y_2" AS "y" + ON "x"."a" = "y"."b"; diff --git a/tests/fixtures/optimizer/pushdown_projections.sql b/tests/fixtures/optimizer/pushdown_projections.sql index ba4bf45..2a21f65 100644 --- a/tests/fixtures/optimizer/pushdown_projections.sql +++ b/tests/fixtures/optimizer/pushdown_projections.sql @@ -1,32 +1,32 @@ SELECT a FROM (SELECT * FROM x); -SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0"; +SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x) AS _q_0; SELECT 1 FROM (SELECT * FROM x) WHERE b = 2; -SELECT 1 AS "_col_0" FROM (SELECT x.b AS b FROM x AS x) AS "_q_0" WHERE "_q_0".b = 2; +SELECT 1 AS _col_0 FROM (SELECT x.b AS b FROM x AS x) AS _q_0 WHERE _q_0.b = 2; SELECT (SELECT c FROM y WHERE q.b = y.b) FROM (SELECT * FROM x) AS q; -SELECT (SELECT y.c AS c FROM y AS y WHERE q.b = y.b) AS "_col_0" FROM (SELECT x.b AS b FROM x AS x) AS q; +SELECT (SELECT y.c AS c FROM y AS y WHERE q.b = y.b) AS _col_0 FROM (SELECT x.b AS b FROM x AS x) AS q; SELECT a FROM x JOIN (SELECT b, c FROM y) AS z ON x.b = z.b; SELECT x.a AS a FROM x AS x JOIN (SELECT y.b AS b FROM y AS y) AS z ON x.b = z.b; SELECT x1.a FROM (SELECT * FROM x) AS x1, (SELECT * FROM x) AS x2; -SELECT x1.a AS a FROM (SELECT x.a AS a FROM x AS x) AS x1, (SELECT 1 AS "_" FROM x AS x) AS x2; +SELECT x1.a AS a FROM (SELECT x.a AS a FROM x AS x) AS x1, (SELECT 1 AS _ FROM x AS x) AS x2; SELECT x1.a FROM (SELECT * FROM x) AS x1, (SELECT * FROM x) AS x2; -SELECT x1.a AS a FROM (SELECT x.a AS a FROM x AS x) AS x1, (SELECT 1 AS "_" FROM x AS x) AS x2; +SELECT x1.a AS a FROM (SELECT x.a AS a FROM x AS x) AS x1, (SELECT 1 AS _ FROM x AS x) AS x2; SELECT a FROM (SELECT DISTINCT a, b FROM x); -SELECT "_q_0".a AS a FROM (SELECT DISTINCT x.a AS a, x.b AS b FROM x AS x) AS "_q_0"; +SELECT _q_0.a AS a FROM (SELECT DISTINCT x.a AS a, x.b AS b FROM x AS x) AS _q_0; SELECT a FROM (SELECT a, b FROM x UNION ALL SELECT a, b FROM x); -SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x UNION ALL SELECT x.a AS a FROM x AS x) AS "_q_0"; +SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x UNION ALL SELECT x.a AS a FROM x AS x) AS _q_0; WITH t1 AS (SELECT x.a AS a, x.b AS b FROM x UNION ALL SELECT z.b AS b, z.c AS c FROM z) SELECT a, b FROM t1; WITH t1 AS (SELECT x.a AS a, x.b AS b FROM x AS x UNION ALL SELECT z.b AS b, z.c AS c FROM z AS z) SELECT t1.a AS a, t1.b AS b FROM t1; SELECT a FROM (SELECT a, b FROM x UNION SELECT a, b FROM x); -SELECT "_q_0".a AS a FROM (SELECT x.a AS a, x.b AS b FROM x AS x UNION SELECT x.a AS a, x.b AS b FROM x AS x) AS "_q_0"; +SELECT _q_0.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x AS x UNION SELECT x.a AS a, x.b AS b FROM x AS x) AS _q_0; WITH y AS (SELECT * FROM x) SELECT a FROM y; WITH y AS (SELECT x.a AS a FROM x AS x) SELECT y.a AS a FROM y; @@ -38,10 +38,10 @@ WITH z AS (SELECT * FROM x) SELECT a FROM z UNION SELECT a FROM z; WITH z AS (SELECT x.a AS a FROM x AS x) SELECT z.a AS a FROM z UNION SELECT z.a AS a FROM z; SELECT b FROM (SELECT a, SUM(b) AS b FROM x GROUP BY a); -SELECT "_q_0".b AS b FROM (SELECT SUM(x.b) AS b FROM x AS x GROUP BY x.a) AS "_q_0"; +SELECT _q_0.b AS b FROM (SELECT SUM(x.b) AS b FROM x AS x GROUP BY x.a) AS _q_0; SELECT b FROM (SELECT a, SUM(b) AS b FROM x ORDER BY a); -SELECT "_q_0".b AS b FROM (SELECT x.a AS a, SUM(x.b) AS b FROM x AS x ORDER BY a) AS "_q_0"; +SELECT _q_0.b AS b FROM (SELECT x.a AS a, SUM(x.b) AS b FROM x AS x ORDER BY a) AS _q_0; SELECT x FROM (VALUES(1, 2)) AS q(x, y); SELECT q.x AS x FROM (VALUES (1, 2)) AS q(x, y); diff --git a/tests/fixtures/optimizer/qualify_columns.sql b/tests/fixtures/optimizer/qualify_columns.sql index 1176078..9c5a0be 100644 --- a/tests/fixtures/optimizer/qualify_columns.sql +++ b/tests/fixtures/optimizer/qualify_columns.sql @@ -21,15 +21,15 @@ SELECT x.a AS b FROM x AS x; # execute: false SELECT 1, 2 FROM x; -SELECT 1 AS "_col_0", 2 AS "_col_1" FROM x AS x; +SELECT 1 AS _col_0, 2 AS _col_1 FROM x AS x; # execute: false SELECT a + b FROM x; -SELECT x.a + x.b AS "_col_0" FROM x AS x; +SELECT x.a + x.b AS _col_0 FROM x AS x; # execute: false SELECT a, SUM(b) FROM x WHERE a > 1 AND b > 1 GROUP BY a; -SELECT x.a AS a, SUM(x.b) AS "_col_1" FROM x AS x WHERE x.a > 1 AND x.b > 1 GROUP BY x.a; +SELECT x.a AS a, SUM(x.b) AS _col_1 FROM x AS x WHERE x.a > 1 AND x.b > 1 GROUP BY x.a; SELECT SUM(a) AS c FROM x HAVING SUM(a) > 3; SELECT SUM(x.a) AS c FROM x AS x HAVING SUM(x.a) > 3; @@ -59,7 +59,7 @@ SELECT SUM(x.a) AS c, SUM(x.b) AS d FROM x AS x ORDER BY SUM(x.a), SUM(x.b); # execute: false SELECT SUM(a), SUM(b) AS c FROM x ORDER BY 1, 2; -SELECT SUM(x.a) AS "_col_0", SUM(x.b) AS c FROM x AS x ORDER BY SUM(x.a), SUM(x.b); +SELECT SUM(x.a) AS _col_0, SUM(x.b) AS c FROM x AS x ORDER BY SUM(x.a), SUM(x.b); SELECT a AS j, b FROM x GROUP BY j, b; SELECT x.a AS j, x.b AS b FROM x AS x GROUP BY x.a, x.b; @@ -72,7 +72,7 @@ SELECT x.a AS a, x.b AS b FROM x AS x ORDER BY x.a, x.b; # execute: false SELECT DATE(a), DATE(b) AS c FROM x GROUP BY 1, 2; -SELECT DATE(x.a) AS "_col_0", DATE(x.b) AS c FROM x AS x GROUP BY DATE(x.a), DATE(x.b); +SELECT DATE(x.a) AS _col_0, DATE(x.b) AS c FROM x AS x GROUP BY DATE(x.a), DATE(x.b); SELECT SUM(x.a) AS c FROM x JOIN y ON x.b = y.b GROUP BY c; SELECT SUM(x.a) AS c FROM x AS x JOIN y AS y ON x.b = y.b GROUP BY y.c; @@ -130,10 +130,10 @@ SELECT a FROM (SELECT a FROM x AS x) y; SELECT y.a AS a FROM (SELECT x.a AS a FROM x AS x) AS y; SELECT a FROM (SELECT a AS a FROM x); -SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0"; +SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x) AS _q_0; SELECT a FROM (SELECT a FROM (SELECT a FROM x)); -SELECT "_q_1".a AS a FROM (SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0") AS "_q_1"; +SELECT _q_1.a AS a FROM (SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x) AS _q_0) AS _q_1; SELECT x.a FROM x AS x JOIN (SELECT * FROM x) AS y ON x.a = y.a; SELECT x.a AS a FROM x AS x JOIN (SELECT x.a AS a, x.b AS b FROM x AS x) AS y ON x.a = y.a; @@ -157,7 +157,7 @@ SELECT a FROM x UNION SELECT a FROM x UNION SELECT a FROM x; SELECT x.a AS a FROM x AS x UNION SELECT x.a AS a FROM x AS x UNION SELECT x.a AS a FROM x AS x; SELECT a FROM (SELECT a FROM x UNION SELECT a FROM x); -SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x UNION SELECT x.a AS a FROM x AS x) AS "_q_0"; +SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x UNION SELECT x.a AS a FROM x AS x) AS _q_0; -------------------------------------- -- Subqueries @@ -167,10 +167,10 @@ SELECT x.a AS a FROM x AS x WHERE x.b IN (SELECT y.c AS c FROM y AS y); # execute: false SELECT (SELECT c FROM y) FROM x; -SELECT (SELECT y.c AS c FROM y AS y) AS "_col_0" FROM x AS x; +SELECT (SELECT y.c AS c FROM y AS y) AS _col_0 FROM x AS x; SELECT a FROM (SELECT a FROM x) WHERE a IN (SELECT b FROM (SELECT b FROM y)); -SELECT "_q_1".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_1" WHERE "_q_1".a IN (SELECT "_q_0".b AS b FROM (SELECT y.b AS b FROM y AS y) AS "_q_0"); +SELECT _q_1.a AS a FROM (SELECT x.a AS a FROM x AS x) AS _q_1 WHERE _q_1.a IN (SELECT _q_0.b AS b FROM (SELECT y.b AS b FROM y AS y) AS _q_0); -------------------------------------- -- Correlated subqueries @@ -215,10 +215,10 @@ SELECT x.*, y.* FROM x JOIN y ON x.b = y.b; SELECT x.a AS a, x.b AS b, y.b AS b, y.c AS c FROM x AS x JOIN y AS y ON x.b = y.b; SELECT a FROM (SELECT * FROM x); -SELECT "_q_0".a AS a FROM (SELECT x.a AS a, x.b AS b FROM x AS x) AS "_q_0"; +SELECT _q_0.a AS a FROM (SELECT x.a AS a, x.b AS b FROM x AS x) AS _q_0; SELECT * FROM (SELECT a FROM x); -SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0"; +SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x) AS _q_0; -------------------------------------- -- CTEs diff --git a/tests/fixtures/optimizer/qualify_columns__with_invisible.sql b/tests/fixtures/optimizer/qualify_columns__with_invisible.sql index ee46c23..05253f3 100644 --- a/tests/fixtures/optimizer/qualify_columns__with_invisible.sql +++ b/tests/fixtures/optimizer/qualify_columns__with_invisible.sql @@ -11,10 +11,10 @@ SELECT x.b AS b FROM x AS x; -- Derived tables -------------------------------------- SELECT x.a FROM x AS x JOIN (SELECT * FROM x); -SELECT x.a AS a FROM x AS x JOIN (SELECT x.a AS a FROM x AS x) AS "_q_0"; +SELECT x.a AS a FROM x AS x JOIN (SELECT x.a AS a FROM x AS x) AS _q_0; SELECT x.b FROM x AS x JOIN (SELECT b FROM x); -SELECT x.b AS b FROM x AS x JOIN (SELECT x.b AS b FROM x AS x) AS "_q_0"; +SELECT x.b AS b FROM x AS x JOIN (SELECT x.b AS b FROM x AS x) AS _q_0; -------------------------------------- -- Expand * @@ -29,7 +29,7 @@ SELECT * FROM y JOIN z ON y.c = z.c; SELECT y.b AS b, z.b AS b FROM y AS y JOIN z AS z ON y.c = z.c; SELECT a FROM (SELECT * FROM x); -SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0"; +SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x) AS _q_0; SELECT * FROM (SELECT a FROM x); -SELECT "_q_0".a AS a FROM (SELECT x.a AS a FROM x AS x) AS "_q_0"; +SELECT _q_0.a AS a FROM (SELECT x.a AS a FROM x AS x) AS _q_0; diff --git a/tests/fixtures/optimizer/unnest_subqueries.sql b/tests/fixtures/optimizer/unnest_subqueries.sql index dc373a0..a444945 100644 --- a/tests/fixtures/optimizer/unnest_subqueries.sql +++ b/tests/fixtures/optimizer/unnest_subqueries.sql @@ -30,14 +30,14 @@ CROSS JOIN ( SELECT SUM(y.a) AS a FROM y -) AS "_u_0" +) AS _u_0 LEFT JOIN ( SELECT y.a AS a FROM y GROUP BY y.a -) AS "_u_1" +) AS _u_1 ON x.a = "_u_1"."a" LEFT JOIN ( SELECT @@ -45,7 +45,7 @@ LEFT JOIN ( FROM y GROUP BY y.b -) AS "_u_2" +) AS _u_2 ON x.a = "_u_2"."b" LEFT JOIN ( SELECT @@ -53,7 +53,7 @@ LEFT JOIN ( FROM y GROUP BY y.a -) AS "_u_3" +) AS _u_3 ON x.a = "_u_3"."a" LEFT JOIN ( SELECT @@ -64,8 +64,8 @@ LEFT JOIN ( TRUE GROUP BY y.a -) AS "_u_4" - ON x.a = "_u_4"."_u_5" +) AS _u_4 + ON x.a = _u_4._u_5 LEFT JOIN ( SELECT SUM(y.b) AS b, @@ -75,8 +75,8 @@ LEFT JOIN ( TRUE GROUP BY y.a -) AS "_u_6" - ON x.a = "_u_6"."_u_7" +) AS _u_6 + ON x.a = _u_6._u_7 LEFT JOIN ( SELECT y.a AS a @@ -85,8 +85,8 @@ LEFT JOIN ( TRUE GROUP BY y.a -) AS "_u_8" - ON "_u_8".a = x.a +) AS _u_8 + ON _u_8.a = x.a LEFT JOIN ( SELECT y.a AS a @@ -95,8 +95,8 @@ LEFT JOIN ( TRUE GROUP BY y.a -) AS "_u_9" - ON "_u_9".a = x.a +) AS _u_9 + ON _u_9.a = x.a LEFT JOIN ( SELECT ARRAY_AGG(y.a) AS a, @@ -106,8 +106,8 @@ LEFT JOIN ( TRUE GROUP BY y.b -) AS "_u_10" - ON "_u_10"."_u_11" = x.a +) AS _u_10 + ON _u_10._u_11 = x.a LEFT JOIN ( SELECT SUM(y.a) AS a, @@ -118,8 +118,8 @@ LEFT JOIN ( TRUE AND TRUE AND TRUE GROUP BY y.a -) AS "_u_12" - ON "_u_12"."_u_13" = x.a AND "_u_12"."_u_13" = x.b +) AS _u_12 + ON _u_12._u_13 = x.a AND _u_12._u_13 = x.b LEFT JOIN ( SELECT y.a AS a @@ -128,38 +128,38 @@ LEFT JOIN ( TRUE GROUP BY y.a -) AS "_u_15" - ON x.a = "_u_15".a +) AS _u_15 + ON x.a = _u_15.a WHERE - x.a = "_u_0".a + x.a = _u_0.a AND NOT "_u_1"."a" IS NULL AND NOT "_u_2"."b" IS NULL AND NOT "_u_3"."a" IS NULL AND ( - x.a = "_u_4".b AND NOT "_u_4"."_u_5" IS NULL + x.a = _u_4.b AND NOT _u_4._u_5 IS NULL ) AND ( - x.a > "_u_6".b AND NOT "_u_6"."_u_7" IS NULL + x.a > _u_6.b AND NOT _u_6._u_7 IS NULL ) AND ( - None = "_u_8".a AND NOT "_u_8".a IS NULL + None = _u_8.a AND NOT _u_8.a IS NULL ) AND NOT ( - x.a = "_u_9".a AND NOT "_u_9".a IS NULL + x.a = _u_9.a AND NOT _u_9.a IS NULL ) AND ( - ARRAY_ANY("_u_10".a, _x -> _x = x.a) AND NOT "_u_10"."_u_11" IS NULL + ARRAY_ANY(_u_10.a, _x -> _x = x.a) AND NOT _u_10._u_11 IS NULL ) AND ( ( ( - x.a < "_u_12".a AND NOT "_u_12"."_u_13" IS NULL - ) AND NOT "_u_12"."_u_13" IS NULL + x.a < _u_12.a AND NOT _u_12._u_13 IS NULL + ) AND NOT _u_12._u_13 IS NULL ) - AND ARRAY_ANY("_u_12"."_u_14", "_x" -> "_x" <> x.d) + AND ARRAY_ANY(_u_12._u_14, "_x" -> _x <> x.d) ) AND ( - NOT "_u_15".a IS NULL AND NOT "_u_15".a IS NULL + NOT _u_15.a IS NULL AND NOT _u_15.a IS NULL ) AND x.a IN ( SELECT diff --git a/tests/test_build.py b/tests/test_build.py index b014a3a..a1a268d 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -481,6 +481,19 @@ class TestBuild(unittest.TestCase): ), (lambda: exp.delete("y", where="x > 1"), "DELETE FROM y WHERE x > 1"), (lambda: exp.delete("y", where=exp.and_("x > 1")), "DELETE FROM y WHERE x > 1"), + ( + lambda: select("AVG(a) OVER b") + .from_("table") + .window("b AS (PARTITION BY c ORDER BY d)"), + "SELECT AVG(a) OVER b FROM table WINDOW b AS (PARTITION BY c ORDER BY d)", + ), + ( + lambda: select("AVG(a) OVER b", "MIN(c) OVER d") + .from_("table") + .window("b AS (PARTITION BY e ORDER BY f)") + .window("d AS (PARTITION BY g ORDER BY h)"), + "SELECT AVG(a) OVER b, MIN(c) OVER d FROM table WINDOW b AS (PARTITION BY e ORDER BY f), d AS (PARTITION BY g ORDER BY h)", + ), ]: with self.subTest(sql): self.assertEqual(expression().sql(dialect[0] if dialect else None), sql) diff --git a/tests/test_executor.py b/tests/test_executor.py index 4fe6399..b705551 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -74,7 +74,7 @@ class TestExecutor(unittest.TestCase): ) return expression - for i, (sql, _) in enumerate(self.sqls[0:18]): + for i, (sql, _) in enumerate(self.sqls): with self.subTest(f"tpch-h {i + 1}"): a = self.cached_execute(sql) sql = parse_one(sql).transform(to_csv).sql(pretty=True) diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 0e13ade..1e23983 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -1,4 +1,5 @@ import datetime +import math import unittest from sqlglot import alias, exp, parse_one @@ -491,7 +492,7 @@ class TestExpressions(unittest.TestCase): self.assertEqual(alias("foo", "bar-1").sql(), 'foo AS "bar-1"') self.assertEqual(alias("foo", "bar_1").sql(), "foo AS bar_1") self.assertEqual(alias("foo * 2", "2bar").sql(), 'foo * 2 AS "2bar"') - self.assertEqual(alias('"foo"', "_bar").sql(), '"foo" AS "_bar"') + self.assertEqual(alias('"foo"', "_bar").sql(), '"foo" AS _bar') self.assertEqual(alias("foo", "bar", quoted=True).sql(), 'foo AS "bar"') def test_unit(self): @@ -503,6 +504,8 @@ class TestExpressions(unittest.TestCase): def test_identifier(self): self.assertTrue(exp.to_identifier('"x"').quoted) self.assertFalse(exp.to_identifier("x").quoted) + self.assertTrue(exp.to_identifier("foo ").quoted) + self.assertFalse(exp.to_identifier("_x").quoted) def test_function_normalizer(self): self.assertEqual(parse_one("HELLO()").sql(normalize_functions="lower"), "hello()") @@ -549,14 +552,15 @@ class TestExpressions(unittest.TestCase): ([1, "2", None], "ARRAY(1, '2', NULL)"), ({"x": None}, "MAP('x', NULL)"), ( - datetime.datetime(2022, 10, 1, 1, 1, 1), - "TIME_STR_TO_TIME('2022-10-01 01:01:01.000000')", + datetime.datetime(2022, 10, 1, 1, 1, 1, 1), + "TIME_STR_TO_TIME('2022-10-01T01:01:01.000001+00:00')", ), ( datetime.datetime(2022, 10, 1, 1, 1, 1, tzinfo=datetime.timezone.utc), - "TIME_STR_TO_TIME('2022-10-01 01:01:01.000000+0000')", + "TIME_STR_TO_TIME('2022-10-01T01:01:01+00:00')", ), (datetime.date(2022, 10, 1), "DATE_STR_TO_DATE('2022-10-01')"), + (math.nan, "NULL"), ]: with self.subTest(value): self.assertEqual(exp.convert(value).sql(), expected) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 0c5f6cd..1c97be7 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -164,9 +164,6 @@ class TestOptimizer(unittest.TestCase): with self.assertRaises(OptimizeError): optimizer.qualify_columns.qualify_columns(parse_one(sql), schema=self.schema) - def test_quote_identities(self): - self.check_file("quote_identities", optimizer.quote_identities.quote_identities) - def test_lower_identities(self): self.check_file("lower_identities", optimizer.lower_identities.lower_identities) @@ -555,3 +552,29 @@ FROM READ_CSV('tests/fixtures/optimizer/tpc-h/nation.csv.gz', 'delimiter', '|') parse_one(f"SELECT {func}(x.{col}) AS _col_0 FROM x AS x"), schema=schema ) self.assertEqual(expression.expressions[0].type.this, target_type) + + def test_recursive_cte(self): + query = parse_one( + """ + with recursive t(n) AS + ( + select 1 + union all + select n + 1 + FROM t + where n < 3 + ), y AS ( + select n + FROM t + union all + select n + 1 + FROM y + where n < 2 + ) + select * from y + """ + ) + + scope_t, scope_y = build_scope(query).cte_scopes + self.assertEqual(set(scope_t.cte_sources), {"t"}) + self.assertEqual(set(scope_y.cte_sources), {"t", "y"}) diff --git a/tests/test_parser.py b/tests/test_parser.py index 0be15e4..ae2e4cd 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -76,6 +76,9 @@ class TestParser(unittest.TestCase): tables = [t.sql() for t in parse_one("select * from a, b.c, .d").find_all(exp.Table)] self.assertEqual(tables, ["a", "b.c", "d"]) + def test_union_order(self): + self.assertIsInstance(parse_one("SELECT * FROM (SELECT 1) UNION SELECT 2"), exp.Union) + def test_select(self): self.assertIsNotNone(parse_one("select 1 natural")) self.assertIsNotNone(parse_one("select * from (select 1) x order by x.y").args["order"]) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 0bcd2ca..cfb8d2b 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -40,17 +40,17 @@ class TestTime(unittest.TestCase): self.validate( eliminate_distinct_on, "SELECT DISTINCT ON (a) a, b FROM x ORDER BY c DESC", - 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS "_row_number" FROM x) WHERE "_row_number" = 1', + 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM x) WHERE "_row_number" = 1', ) self.validate( eliminate_distinct_on, "SELECT DISTINCT ON (a) a, b FROM x", - 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a) AS "_row_number" FROM x) WHERE "_row_number" = 1', + 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a) AS _row_number FROM x) WHERE "_row_number" = 1', ) self.validate( eliminate_distinct_on, "SELECT DISTINCT ON (a, b) a, b FROM x ORDER BY c DESC", - 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a, b ORDER BY c DESC) AS "_row_number" FROM x) WHERE "_row_number" = 1', + 'SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a, b ORDER BY c DESC) AS _row_number FROM x) WHERE "_row_number" = 1', ) self.validate( eliminate_distinct_on, @@ -60,5 +60,5 @@ class TestTime(unittest.TestCase): self.validate( eliminate_distinct_on, "SELECT DISTINCT ON (_row_number) _row_number FROM x ORDER BY c DESC", - 'SELECT _row_number FROM (SELECT _row_number, ROW_NUMBER() OVER (PARTITION BY _row_number ORDER BY c DESC) AS "_row_number_2" FROM x) WHERE "_row_number_2" = 1', + 'SELECT _row_number FROM (SELECT _row_number, ROW_NUMBER() OVER (PARTITION BY _row_number ORDER BY c DESC) AS _row_number_2 FROM x) WHERE "_row_number_2" = 1', ) diff --git a/tests/test_transpile.py b/tests/test_transpile.py index 7bf53e5..9253ded 100644 --- a/tests/test_transpile.py +++ b/tests/test_transpile.py @@ -28,7 +28,7 @@ class TestTranspile(unittest.TestCase): self.assertEqual(transpile("SELECT 1 current_datetime")[0], "SELECT 1 AS current_datetime") self.assertEqual(transpile("SELECT 1 row")[0], "SELECT 1 AS row") - for key in ("union", "filter", "over", "from", "join"): + for key in ("union", "over", "from", "join"): with self.subTest(f"alias {key}"): self.validate(f"SELECT x AS {key}", f"SELECT x AS {key}") self.validate(f'SELECT x "{key}"', f'SELECT x AS "{key}"') @@ -263,6 +263,25 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""", "WITH a AS (SELECT 1), WITH b AS (SELECT 2) SELECT *", "WITH a AS (SELECT 1), b AS (SELECT 2) SELECT *", ) + self.validate( + "WITH A(filter) AS (VALUES 1, 2, 3) SELECT * FROM A WHERE filter >= 2", + "WITH A(filter) AS (VALUES (1), (2), (3)) SELECT * FROM A WHERE filter >= 2", + ) + + def test_alter(self): + self.validate( + "ALTER TABLE integers ADD k INTEGER", + "ALTER TABLE integers ADD COLUMN k INT", + ) + self.validate("ALTER TABLE integers DROP k", "ALTER TABLE integers DROP COLUMN k") + self.validate( + "ALTER TABLE integers ALTER i SET DATA TYPE VARCHAR", + "ALTER TABLE integers ALTER COLUMN i TYPE VARCHAR", + ) + self.validate( + "ALTER TABLE integers ALTER i TYPE VARCHAR COLLATE foo USING bar", + "ALTER TABLE integers ALTER COLUMN i TYPE VARCHAR COLLATE foo USING bar", + ) def test_time(self): self.validate("TIMESTAMP '2020-01-01'", "CAST('2020-01-01' AS TIMESTAMP)") @@ -403,6 +422,14 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""", with self.subTest(sql): self.assertEqual(transpile(sql)[0], sql.strip()) + def test_normalize_name(self): + self.assertEqual( + transpile("cardinality(x)", read="presto", write="presto", normalize_functions="lower")[ + 0 + ], + "cardinality(x)", + ) + def test_partial(self): for sql in load_sql_fixtures("partial.sql"): with self.subTest(sql): |