From 7ff5bab54e3298dd89132706f6adee17f5164f6d Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sat, 5 Nov 2022 19:41:12 +0100 Subject: Merging upstream version 9.0.6. Signed-off-by: Daniel Baumann --- .github/workflows/python-package.yml | 2 +- .github/workflows/python-publish.yml | 2 +- CONTRIBUTING.md | 45 +++++ README.md | 205 ++++++++++++--------- dev-requirements.txt | 1 + sqlglot/__init__.py | 4 +- sqlglot/dialects/__init__.py | 1 + sqlglot/dialects/databricks.py | 21 +++ sqlglot/dialects/dialect.py | 13 ++ sqlglot/dialects/hive.py | 2 + sqlglot/dialects/presto.py | 1 + sqlglot/dialects/snowflake.py | 2 + sqlglot/dialects/sqlite.py | 1 + sqlglot/dialects/tsql.py | 78 ++++++-- sqlglot/expressions.py | 151 ++++++++++----- sqlglot/generator.py | 20 +- sqlglot/optimizer/qualify_columns.py | 21 ++- sqlglot/optimizer/scope.py | 2 +- sqlglot/parser.py | 83 ++++++--- sqlglot/time.py | 1 - sqlglot/tokens.py | 4 + tests/dataframe/integration/dataframe_validator.py | 5 +- tests/dataframe/unit/test_column.py | 3 +- tests/dataframe/unit/test_functions.py | 4 +- tests/dialects/test_databricks.py | 33 ++++ tests/dialects/test_dialect.py | 22 +++ tests/dialects/test_snowflake.py | 12 ++ tests/dialects/test_tsql.py | 94 +++++++++- tests/fixtures/optimizer/qualify_columns.sql | 17 ++ tests/test_build.py | 36 ++++ tests/test_expressions.py | 6 +- 31 files changed, 695 insertions(+), 197 deletions(-) create mode 100644 CONTRIBUTING.md create mode 100644 sqlglot/dialects/databricks.py create mode 100644 tests/dialects/test_databricks.py diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index a3f151b..3b6fdc6 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -10,7 +10,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.7", "3.8", "3.9", "3.10"] + python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index 061d863..2d112b9 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -13,7 +13,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v2 with: - python-version: '3.9' + python-version: "3.10" - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..1d3b822 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,45 @@ +# Contributing to [SQLGlot](https://github.com/tobymao/sqlglot/blob/main/README.md) + +SQLGLot is open source software. We value feedback and we want to make contributing to this project as +easy and transparent as possible, whether it's: + +- Reporting a bug +- Discussing the current state of the code +- Submitting a fix +- Proposing new features + +## We develop with Github +We use github to host code, to track issues and feature requests, as well as accept pull requests. + +## Submitting code changes +Pull requests are the best way to propose changes to the codebase. We actively welcome your pull requests: + +1. Fork the repo and create your branch from `main`. +2. If you've added code that should be tested, add tests. +3. If you've changed APIs, update the documentation. +4. Ensure the test suite & linter [checks](https://github.com/tobymao/sqlglot/blob/main/README.md#run-tests-and-lint) pass. +5. Issue that pull request and wait for it to be reviewed by a maintainer or contributor! + +## Report bugs using Github's [issues](https://github.com/tobymao/sqlglot/issues) +We use GitHub issues to track public bugs. Report a bug by [opening a new issue](). + +**Great Bug Reports** tend to have: + +- A quick summary and/or background +- Steps to reproduce + - Be specific! + - Give sample code if you can +- What you expected would happen +- What actually happens +- Notes (possibly including why you think this might be happening, or stuff you tried that didn't work) + +## Start a discussion using Github's [discussions](https://github.com/tobymao/sqlglot/discussions) +[We use GitHub discussions](https://github.com/tobymao/sqlglot/discussions/190) to discuss about the current state +of the code. If you want to propose a new feature, this is the right place to do it! Just start a discussion, and +let us know why you think this feature would be a good addition to SQLGlot (by possibly including some usage examples). + +## [License](https://github.com/tobymao/sqlglot/blob/main/LICENSE) +By contributing, you agree that your contributions will be licensed under its MIT License. + +## References +This document was adapted from [briandk's template](https://gist.github.com/briandk/3d2e8b3ec8daf5a27a62). diff --git a/README.md b/README.md index b13ac2e..0d7e429 100644 --- a/README.md +++ b/README.md @@ -8,51 +8,96 @@ You can easily [customize](#custom-dialects) the parser, [analyze](#metadata) qu Syntax [errors](#parser-errors) are highlighted and dialect incompatibilities can warn or raise depending on configurations. +Contributions are very welcome in SQLGlot; read the [contribution guide](https://github.com/tobymao/sqlglot/blob/main/CONTRIBUTING.md) to get started! + +## Table of Contents + +* [Install](#install) +* [Documentation](#documentation) +* [Run Tests & Lint](#run-tests-and-lint) +* [Examples](#examples) + * [Formatting and Transpiling](#formatting-and-transpiling) + * [Metadata](#metadata) + * [Parser Errors](#parser-errors) + * [Unsupported Errors](#unsupported-errors) + * [Build and Modify SQL](#build-and-modify-sql) + * [SQL Optimizer](#sql-optimizer) + * [SQL Annotations](#sql-annotations) + * [AST Introspection](#ast-introspection) + * [AST Diff](#ast-diff) + * [Custom Dialects](#custom-dialects) +* [Benchmarks](#benchmarks) +* [Optional Dependencies](#optional-dependencies) + ## Install -From PyPI + +From PyPI: ``` pip3 install sqlglot ``` -Or with a local checkout +Or with a local checkout: ``` pip3 install -e . ``` +Requirements for development (optional): + +``` +pip3 install -r dev-requirements.txt +``` + +## Documentation + +SQLGlot's uses [pdocs](https://pdoc.dev/) to serve its API documentation: + +``` +pdoc sqlglot --docformat google +``` + +## Run Tests and Lint + +``` +# set `SKIP_INTEGRATION=1` to skip integration tests +./run_checks.sh +``` + + ## Examples -Easily translate from one dialect to another. For example, date/time functions vary from dialects and can be hard to deal with. + +### Formatting and Transpiling + +Easily translate from one dialect to another. For example, date/time functions vary from dialects and can be hard to deal with: ```python import sqlglot -sqlglot.transpile("SELECT EPOCH_MS(1618088028295)", read='duckdb', write='hive') +sqlglot.transpile("SELECT EPOCH_MS(1618088028295)", read="duckdb", write="hive")[0] ``` ```sql -SELECT FROM_UNIXTIME(1618088028295 / 1000) +'SELECT FROM_UNIXTIME(1618088028295 / 1000)' ``` -SQLGlot can even translate custom time formats. +SQLGlot can even translate custom time formats: + ```python import sqlglot -sqlglot.transpile("SELECT STRFTIME(x, '%y-%-m-%S')", read='duckdb', write='hive') +sqlglot.transpile("SELECT STRFTIME(x, '%y-%-m-%S')", read="duckdb", write="hive")[0] ``` ```sql -SELECT DATE_FORMAT(x, 'yy-M-ss')" +"SELECT DATE_FORMAT(x, 'yy-M-ss')" ``` -## Formatting and Transpiling -Read in a SQL statement with a CTE and CASTING to a REAL and then transpiling to Spark. - -Spark uses backticks as identifiers and the REAL type is transpiled to FLOAT. +As another example, let's suppose that we want to read in a SQL query that contains a CTE and a cast to `REAL`, and then transpile it to Spark, which uses backticks as identifiers and `FLOAT` instead of `REAL`: ```python import sqlglot sql = """WITH baz AS (SELECT a, c FROM foo WHERE a = 1) SELECT f.a, b.b, baz.c, CAST("b"."a" AS REAL) d FROM foo f JOIN bar b ON f.a = b.a LEFT JOIN baz ON f.a = baz.a""" -sqlglot.transpile(sql, write='spark', identify=True, pretty=True)[0] +print(sqlglot.transpile(sql, write="spark", identify=True, pretty=True)[0]) ``` ```sql @@ -76,9 +121,9 @@ LEFT JOIN `baz` ON `f`.`a` = `baz`.`a` ``` -## Metadata +### Metadata -You can explore SQL with expression helpers to do things like find columns and tables. +You can explore SQL with expression helpers to do things like find columns and tables: ```python from sqlglot import parse_one, exp @@ -97,34 +142,38 @@ for table in parse_one("SELECT * FROM x JOIN y JOIN z").find_all(exp.Table): print(table.name) ``` -## Parser Errors -A syntax error will result in a parser error. +### Parser Errors + +A syntax error will result in a parser error: + ```python -transpile("SELECT foo( FROM bar") +import sqlglot +sqlglot.transpile("SELECT foo( FROM bar") ``` +``` sqlglot.errors.ParseError: Expecting ). Line 1, Col: 13. - select foo( __FROM__ bar + select foo( FROM bar + ~~~~ +``` -## Unsupported Errors -Presto APPROX_DISTINCT supports the accuracy argument which is not supported in Spark. +### Unsupported Errors + +Presto `APPROX_DISTINCT` supports the accuracy argument which is not supported in Hive: ```python -transpile( - 'SELECT APPROX_DISTINCT(a, 0.1) FROM foo', - read='presto', - write='spark', -) +import sqlglot +sqlglot.transpile("SELECT APPROX_DISTINCT(a, 0.1) FROM foo", read="presto", write="hive") ``` ```sql -WARNING:root:APPROX_COUNT_DISTINCT does not support accuracy - -SELECT APPROX_COUNT_DISTINCT(a) FROM foo +APPROX_COUNT_DISTINCT does not support accuracy +'SELECT APPROX_COUNT_DISTINCT(a) FROM foo' ``` -## Build and Modify SQL -SQLGlot supports incrementally building sql expressions. +### Build and Modify SQL + +SQLGlot supports incrementally building sql expressions: ```python from sqlglot import select, condition @@ -132,21 +181,20 @@ from sqlglot import select, condition where = condition("x=1").and_("y=1") select("*").from_("y").where(where).sql() ``` -Which outputs: + ```sql -SELECT * FROM y WHERE x = 1 AND y = 1 +'SELECT * FROM y WHERE x = 1 AND y = 1' ``` You can also modify a parsed tree: ```python from sqlglot import parse_one - parse_one("SELECT x FROM y").from_("z").sql() ``` -Which outputs: + ```sql -SELECT x FROM y, z +'SELECT x FROM y, z' ``` There is also a way to recursively transform the parsed tree by applying a mapping function to each tree node: @@ -164,68 +212,64 @@ def transformer(node): transformed_tree = expression_tree.transform(transformer) transformed_tree.sql() ``` -Which outputs: + ```sql -SELECT FUN(a) FROM x +'SELECT FUN(a) FROM x' ``` -## SQL Optimizer +### SQL Optimizer -SQLGlot can rewrite queries into an "optimized" form. It performs a variety of [techniques](sqlglot/optimizer/optimizer.py) to create a new canonical AST. This AST can be used to standardize queries or provide the foundations for implementing an actual engine. +SQLGlot can rewrite queries into an "optimized" form. It performs a variety of [techniques](sqlglot/optimizer/optimizer.py) to create a new canonical AST. This AST can be used to standardize queries or provide the foundations for implementing an actual engine. For example: ```python import sqlglot from sqlglot.optimizer import optimize ->>> -optimize( - sqlglot.parse_one(""" - SELECT A OR (B OR (C AND D)) - FROM x - WHERE Z = date '2021-01-01' + INTERVAL '1' month OR 1 = 0 - """), - schema={"x": {"A": "INT", "B": "INT", "C": "INT", "D": "INT", "Z": "STRING"}} -).sql(pretty=True) - -""" +print( + optimize( + sqlglot.parse_one(""" + SELECT A OR (B OR (C AND D)) + FROM x + WHERE Z = date '2021-01-01' + INTERVAL '1' month OR 1 = 0 + """), + schema={"x": {"A": "INT", "B": "INT", "C": "INT", "D": "INT", "Z": "STRING"}} + ).sql(pretty=True) +) +``` + +``` SELECT ( - "x"."A" - OR "x"."B" - OR "x"."C" - ) - AND ( - "x"."A" - OR "x"."B" - OR "x"."D" + "x"."A" OR "x"."B" OR "x"."C" + ) AND ( + "x"."A" OR "x"."B" OR "x"."D" ) AS "_col_0" FROM "x" AS "x" WHERE "x"."Z" = CAST('2021-02-01' AS DATE) -""" ``` -## SQL Annotations +### SQL Annotations SQLGlot supports annotations in the sql expression. This is an experimental feature that is not part of any of the SQL standards but it can be useful when needing to annotate what a selected field is supposed to be. Below is an example: ```sql SELECT - user #primary_key, + user # primary_key, country FROM users ``` -SQL annotations are currently incompatible with MySQL, which uses the `#` character to introduce comments. - -## AST Introspection +### AST Introspection -You can see the AST version of the sql by calling repr. +You can see the AST version of the sql by calling `repr`: ```python from sqlglot import parse_one -repr(parse_one("SELECT a + 1 AS z")) +print(repr(parse_one("SELECT a + 1 AS z"))) +``` +```python (SELECT expressions: (ALIAS this: (ADD this: @@ -235,14 +279,16 @@ repr(parse_one("SELECT a + 1 AS z")) (IDENTIFIER this: z, quoted: False))) ``` -## AST Diff +### AST Diff -SQLGlot can calculate the difference between two expressions and output changes in a form of a sequence of actions needed to transform a source expression into a target one. +SQLGlot can calculate the difference between two expressions and output changes in a form of a sequence of actions needed to transform a source expression into a target one: ```python from sqlglot import diff, parse_one diff(parse_one("SELECT a + b, c, d"), parse_one("SELECT c, a - b, d")) +``` +```python [ Remove(expression=(ADD this: (COLUMN this: @@ -261,9 +307,9 @@ diff(parse_one("SELECT a + b, c, d"), parse_one("SELECT c, a - b, d")) ] ``` -## Custom Dialects +### Custom Dialects -[Dialects](sqlglot/dialects) can be added by subclassing Dialect. +[Dialects](sqlglot/dialects) can be added by subclassing `Dialect`: ```python from sqlglot import exp @@ -298,8 +344,11 @@ class Custom(Dialect): exp.DataType.Type.TEXT: "STRING", } +print(Dialect["custom"]) +``` -Dialect["custom"] +```python + ``` ## Benchmarks @@ -314,18 +363,10 @@ Dialect["custom"] | crazy | 0.03751 (1.0) | 0.03471 (0.925) | 11.0796 (295.3) | 1.03355 (27.55) | 0.00529 (0.141) | -## Run Tests and Lint -``` -pip install -r dev-requirements.txt -# set `SKIP_INTEGRATION=1` to skip integration tests -./run_checks.sh -``` - ## Optional Dependencies -SQLGlot uses [dateutil](https://github.com/dateutil/dateutil) to simplify literal timedelta expressions. The optimizer will not simplify expressions like + +SQLGlot uses [dateutil](https://github.com/dateutil/dateutil) to simplify literal timedelta expressions. The optimizer will not simplify expressions like the following if the module cannot be found: ```sql x + interval '1' month ``` - -if the module cannot be found. diff --git a/dev-requirements.txt b/dev-requirements.txt index 336ecf4..aa7d31f 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -6,3 +6,4 @@ mypy pandas pyspark python-dateutil +pdoc diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py index a780f96..d6e18fd 100644 --- a/sqlglot/__init__.py +++ b/sqlglot/__init__.py @@ -1,3 +1,5 @@ +"""## Python SQL parser, transpiler and optimizer.""" + from sqlglot import expressions as exp from sqlglot.dialects import Dialect, Dialects from sqlglot.diff import diff @@ -24,7 +26,7 @@ from sqlglot.parser import Parser from sqlglot.schema import MappingSchema from sqlglot.tokens import Tokenizer, TokenType -__version__ = "9.0.3" +__version__ = "9.0.6" pretty = False diff --git a/sqlglot/dialects/__init__.py b/sqlglot/dialects/__init__.py index 0f80723..0816831 100644 --- a/sqlglot/dialects/__init__.py +++ b/sqlglot/dialects/__init__.py @@ -1,5 +1,6 @@ from sqlglot.dialects.bigquery import BigQuery from sqlglot.dialects.clickhouse import ClickHouse +from sqlglot.dialects.databricks import Databricks from sqlglot.dialects.dialect import Dialect, Dialects from sqlglot.dialects.duckdb import DuckDB from sqlglot.dialects.hive import Hive diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py new file mode 100644 index 0000000..9dc3c38 --- /dev/null +++ b/sqlglot/dialects/databricks.py @@ -0,0 +1,21 @@ +from sqlglot import exp +from sqlglot.dialects.dialect import parse_date_delta +from sqlglot.dialects.spark import Spark +from sqlglot.dialects.tsql import generate_date_delta_with_unit_sql + + +class Databricks(Spark): + class Parser(Spark.Parser): + FUNCTIONS = { + **Spark.Parser.FUNCTIONS, + "DATEADD": parse_date_delta(exp.DateAdd), + "DATE_ADD": parse_date_delta(exp.DateAdd), + "DATEDIFF": parse_date_delta(exp.DateDiff), + } + + class Generator(Spark.Generator): + TRANSFORMS = { + **Spark.Generator.TRANSFORMS, + exp.DateAdd: generate_date_delta_with_unit_sql, + exp.DateDiff: generate_date_delta_with_unit_sql, + } diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 46661cf..33985a7 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -28,6 +28,7 @@ class Dialects(str, Enum): TABLEAU = "tableau" TRINO = "trino" TSQL = "tsql" + DATABRICKS = "databricks" class _Dialect(type): @@ -331,3 +332,15 @@ def create_with_partitions_sql(self, expression): expression.set("this", schema) return self.create_sql(expression) + + +def parse_date_delta(exp_class, unit_mapping=None): + def inner_func(args): + unit_based = len(args) == 3 + this = list_get(args, 2) if unit_based else list_get(args, 0) + expression = list_get(args, 1) if unit_based else list_get(args, 1) + unit = list_get(args, 0) if unit_based else exp.Literal.string("DAY") + unit = unit_mapping.get(unit.name.lower(), unit) if unit_mapping else unit + return exp_class(this=this, expression=expression, unit=unit) + + return inner_func diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 63fdb85..03049ff 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -111,6 +111,7 @@ def _unnest_to_explode_sql(self, expression): self.sql( exp.Lateral( this=udtf(this=expression), + view=True, alias=exp.TableAlias(this=alias.this, columns=[column]), ) ) @@ -283,6 +284,7 @@ class Hive(Dialect): exp.UnixToTime: rename_func("FROM_UNIXTIME"), exp.UnixToTimeStr: rename_func("FROM_UNIXTIME"), exp.PartitionedByProperty: lambda self, e: f"PARTITIONED BY {self.sql(e, 'value')}", + exp.NumberToStr: rename_func("FORMAT_NUMBER"), } WITH_PROPERTIES = {exp.AnonymousProperty} diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 41c0db1..a2d392c 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -115,6 +115,7 @@ class Presto(Dialect): class Tokenizer(Tokenizer): KEYWORDS = { **Tokenizer.KEYWORDS, + "VARBINARY": TokenType.BINARY, "ROW": TokenType.STRUCT, } diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 627258f..3b97e6d 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -188,6 +188,8 @@ class Snowflake(Dialect): } class Generator(Generator): + CREATE_TRANSIENT = True + TRANSFORMS = { **Generator.TRANSFORMS, exp.ArrayConcat: rename_func("ARRAY_CAT"), diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index cfdbe1b..62b7617 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -20,6 +20,7 @@ class SQLite(Dialect): KEYWORDS = { **Tokenizer.KEYWORDS, + "VARBINARY": TokenType.BINARY, "AUTOINCREMENT": TokenType.AUTO_INCREMENT, } diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 107ace7..0f93c75 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -1,5 +1,7 @@ +import re + from sqlglot import exp -from sqlglot.dialects.dialect import Dialect, rename_func +from sqlglot.dialects.dialect import Dialect, parse_date_delta, rename_func from sqlglot.expressions import DataType from sqlglot.generator import Generator from sqlglot.helper import list_get @@ -27,6 +29,11 @@ DATE_DELTA_INTERVAL = { } +DATE_FMT_RE = re.compile("([dD]{1,2})|([mM]{1,2})|([yY]{1,4})|([hH]{1,2})|([sS]{1,2})") +# N = Numeric, C=Currency +TRANSPILE_SAFE_NUMBER_FMT = {"N", "C"} + + def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None): def _format_time(args): return exp_class( @@ -42,26 +49,40 @@ def tsql_format_time_lambda(exp_class, full_format_mapping=None, default=None): return _format_time -def parse_date_delta(exp_class): - def inner_func(args): - unit = DATE_DELTA_INTERVAL.get(list_get(args, 0).name.lower(), "day") - return exp_class(this=list_get(args, 2), expression=list_get(args, 1), unit=unit) - - return inner_func +def parse_format(args): + fmt = list_get(args, 1) + number_fmt = fmt.name in TRANSPILE_SAFE_NUMBER_FMT or not DATE_FMT_RE.search(fmt.this) + if number_fmt: + return exp.NumberToStr(this=list_get(args, 0), format=fmt) + return exp.TimeToStr( + this=list_get(args, 0), + format=exp.Literal.string( + format_time(fmt.name, TSQL.format_time_mapping) + if len(fmt.name) == 1 + else format_time(fmt.name, TSQL.time_mapping) + ), + ) -def generate_date_delta(self, e): +def generate_date_delta_with_unit_sql(self, e): func = "DATEADD" if isinstance(e, exp.DateAdd) else "DATEDIFF" return f"{func}({self.format_args(e.text('unit'), e.expression, e.this)})" +def generate_format_sql(self, e): + fmt = ( + e.args["format"] + if isinstance(e, exp.NumberToStr) + else exp.Literal.string(format_time(e.text("format"), TSQL.inverse_time_mapping)) + ) + return f"FORMAT({self.format_args(e.this, fmt)})" + + class TSQL(Dialect): null_ordering = "nulls_are_small" time_format = "'yyyy-mm-dd hh:mm:ss'" time_mapping = { - "yyyy": "%Y", - "yy": "%y", "year": "%Y", "qq": "%q", "q": "%q", @@ -101,6 +122,8 @@ class TSQL(Dialect): "H": "%-H", "h": "%-I", "S": "%f", + "yyyy": "%Y", + "yy": "%y", } convert_format_mapping = { @@ -143,6 +166,27 @@ class TSQL(Dialect): "120": "%Y-%m-%d %H:%M:%S", "121": "%Y-%m-%d %H:%M:%S.%f", } + # not sure if complete + format_time_mapping = { + "y": "%B %Y", + "d": "%m/%d/%Y", + "H": "%-H", + "h": "%-I", + "s": "%Y-%m-%d %H:%M:%S", + "D": "%A,%B,%Y", + "f": "%A,%B,%Y %-I:%M %p", + "F": "%A,%B,%Y %-I:%M:%S %p", + "g": "%m/%d/%Y %-I:%M %p", + "G": "%m/%d/%Y %-I:%M:%S %p", + "M": "%B %-d", + "m": "%B %-d", + "O": "%Y-%m-%dT%H:%M:%S", + "u": "%Y-%M-%D %H:%M:%S%z", + "U": "%A, %B %D, %Y %H:%M:%S%z", + "T": "%-I:%M:%S %p", + "t": "%-I:%M", + "Y": "%a %Y", + } class Tokenizer(Tokenizer): IDENTIFIERS = ['"', ("[", "]")] @@ -166,6 +210,7 @@ class TSQL(Dialect): "SQL_VARIANT": TokenType.VARIANT, "NVARCHAR(MAX)": TokenType.TEXT, "VARCHAR(MAX)": TokenType.TEXT, + "TOP": TokenType.TOP, } class Parser(Parser): @@ -173,8 +218,8 @@ class TSQL(Dialect): **Parser.FUNCTIONS, "CHARINDEX": exp.StrPosition.from_arg_list, "ISNULL": exp.Coalesce.from_arg_list, - "DATEADD": parse_date_delta(exp.DateAdd), - "DATEDIFF": parse_date_delta(exp.DateDiff), + "DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL), + "DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL), "DATENAME": tsql_format_time_lambda(exp.TimeToStr, full_format_mapping=True), "DATEPART": tsql_format_time_lambda(exp.TimeToStr), "GETDATE": exp.CurrentDate.from_arg_list, @@ -182,6 +227,7 @@ class TSQL(Dialect): "LEN": exp.Length.from_arg_list, "REPLICATE": exp.Repeat.from_arg_list, "JSON_VALUE": exp.JSONExtractScalar.from_arg_list, + "FORMAT": parse_format, } VAR_LENGTH_DATATYPES = { @@ -194,7 +240,7 @@ class TSQL(Dialect): def _parse_convert(self, strict): to = self._parse_types() self._match(TokenType.COMMA) - this = self._parse_field() + this = self._parse_column() # Retrieve length of datatype and override to default if not specified if list_get(to.expressions, 0) is None and to.this in self.VAR_LENGTH_DATATYPES: @@ -238,8 +284,10 @@ class TSQL(Dialect): TRANSFORMS = { **Generator.TRANSFORMS, - exp.DateAdd: lambda self, e: generate_date_delta(self, e), - exp.DateDiff: lambda self, e: generate_date_delta(self, e), + exp.DateAdd: generate_date_delta_with_unit_sql, + exp.DateDiff: generate_date_delta_with_unit_sql, exp.CurrentDate: rename_func("GETDATE"), exp.If: rename_func("IIF"), + exp.NumberToStr: generate_format_sql, + exp.TimeToStr: generate_format_sql, } diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index eb7854a..1691d85 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -443,7 +443,7 @@ class Condition(Expression): 'x = 1 AND y = 1' Args: - *expressions (str or Expression): the SQL code strings to parse. + *expressions (str | Expression): the SQL code strings to parse. If an `Expression` instance is passed, it will be used as-is. dialect (str): the dialect used to parse the input expression. opts (kwargs): other options to use to parse the input expressions. @@ -462,7 +462,7 @@ class Condition(Expression): 'x = 1 OR y = 1' Args: - *expressions (str or Expression): the SQL code strings to parse. + *expressions (str | Expression): the SQL code strings to parse. If an `Expression` instance is passed, it will be used as-is. dialect (str): the dialect used to parse the input expression. opts (kwargs): other options to use to parse the input expressions. @@ -523,7 +523,7 @@ class Unionable(Expression): 'SELECT * FROM foo UNION SELECT * FROM bla' Args: - expression (str or Expression): the SQL code string. + expression (str | Expression): the SQL code string. If an `Expression` instance is passed, it will be used as-is. distinct (bool): set the DISTINCT flag if and only if this is true. dialect (str): the dialect used to parse the input expression. @@ -543,7 +543,7 @@ class Unionable(Expression): 'SELECT * FROM foo INTERSECT SELECT * FROM bla' Args: - expression (str or Expression): the SQL code string. + expression (str | Expression): the SQL code string. If an `Expression` instance is passed, it will be used as-is. distinct (bool): set the DISTINCT flag if and only if this is true. dialect (str): the dialect used to parse the input expression. @@ -563,7 +563,7 @@ class Unionable(Expression): 'SELECT * FROM foo EXCEPT SELECT * FROM bla' Args: - expression (str or Expression): the SQL code string. + expression (str | Expression): the SQL code string. If an `Expression` instance is passed, it will be used as-is. distinct (bool): set the DISTINCT flag if and only if this is true. dialect (str): the dialect used to parse the input expression. @@ -612,6 +612,7 @@ class Create(Expression): "exists": False, "properties": False, "temporary": False, + "transient": False, "replace": False, "unique": False, "materialized": False, @@ -910,7 +911,7 @@ class Join(Expression): 'JOIN x ON y = 1' Args: - *expressions (str or Expression): the SQL code strings to parse. + *expressions (str | Expression): the SQL code strings to parse. If an `Expression` instance is passed, it will be used as-is. Multiple expressions are combined with an AND operator. append (bool): if `True`, AND the new expressions to any existing expression. @@ -937,9 +938,45 @@ class Join(Expression): return join + def using(self, *expressions, append=True, dialect=None, copy=True, **opts): + """ + Append to or set the USING expressions. + + Example: + >>> import sqlglot + >>> sqlglot.parse_one("JOIN x", into=Join).using("foo", "bla").sql() + 'JOIN x USING (foo, bla)' + + Args: + *expressions (str | Expression): the SQL code strings to parse. + If an `Expression` instance is passed, it will be used as-is. + append (bool): if `True`, concatenate the new expressions to the existing "using" list. + Otherwise, this resets the expression. + dialect (str): the dialect used to parse the input expressions. + copy (bool): if `False`, modify this expression instance in-place. + opts (kwargs): other options to use to parse the input expressions. + + Returns: + Join: the modified join expression. + """ + join = _apply_list_builder( + *expressions, + instance=self, + arg="using", + append=append, + dialect=dialect, + copy=copy, + **opts, + ) + + if join.kind == "CROSS": + join.set("kind", None) + + return join + class Lateral(UDTF): - arg_types = {"this": True, "outer": False, "alias": False} + arg_types = {"this": True, "view": False, "outer": False, "alias": False} # Clickhouse FROM FINAL modifier @@ -1093,7 +1130,7 @@ class Subqueryable(Unionable): 'SELECT x FROM (SELECT x FROM tbl)' Args: - alias (str or Identifier): an optional alias for the subquery + alias (str | Identifier): an optional alias for the subquery copy (bool): if `False`, modify this expression instance in-place. Returns: @@ -1138,9 +1175,9 @@ class Subqueryable(Unionable): 'WITH tbl2 AS (SELECT * FROM tbl) SELECT x FROM tbl2' Args: - alias (str or Expression): the SQL code string to parse as the table name. + alias (str | Expression): the SQL code string to parse as the table name. If an `Expression` instance is passed, this is used as-is. - as_ (str or Expression): the SQL code string to parse as the table expression. + as_ (str | Expression): the SQL code string to parse as the table expression. If an `Expression` instance is passed, it will be used as-is. recursive (bool): set the RECURSIVE part of the expression. Defaults to `False`. append (bool): if `True`, add to any existing expressions. @@ -1295,7 +1332,7 @@ class Select(Subqueryable): 'SELECT x FROM tbl' Args: - *expressions (str or Expression): the SQL code strings to parse. + *expressions (str | Expression): the SQL code strings to parse. If a `From` instance is passed, this is used as-is. If another `Expression` instance is passed, it will be wrapped in a `From`. append (bool): if `True`, add to any existing expressions. @@ -1328,7 +1365,7 @@ class Select(Subqueryable): 'SELECT x, COUNT(1) FROM tbl GROUP BY x' Args: - *expressions (str or Expression): the SQL code strings to parse. + *expressions (str | Expression): the SQL code strings to parse. If a `Group` instance is passed, this is used as-is. If another `Expression` instance is passed, it will be wrapped in a `Group`. If nothing is passed in then a group by is not applied to the expression @@ -1364,7 +1401,7 @@ class Select(Subqueryable): 'SELECT x FROM tbl ORDER BY x DESC' Args: - *expressions (str or Expression): the SQL code strings to parse. + *expressions (str | Expression): the SQL code strings to parse. If a `Group` instance is passed, this is used as-is. If another `Expression` instance is passed, it will be wrapped in a `Order`. append (bool): if `True`, add to any existing expressions. @@ -1397,7 +1434,7 @@ class Select(Subqueryable): 'SELECT x FROM tbl SORT BY x DESC' Args: - *expressions (str or Expression): the SQL code strings to parse. + *expressions (str | Expression): the SQL code strings to parse. If a `Group` instance is passed, this is used as-is. If another `Expression` instance is passed, it will be wrapped in a `SORT`. append (bool): if `True`, add to any existing expressions. @@ -1430,7 +1467,7 @@ class Select(Subqueryable): 'SELECT x FROM tbl CLUSTER BY x DESC' Args: - *expressions (str or Expression): the SQL code strings to parse. + *expressions (str | Expression): the SQL code strings to parse. If a `Group` instance is passed, this is used as-is. If another `Expression` instance is passed, it will be wrapped in a `Cluster`. append (bool): if `True`, add to any existing expressions. @@ -1463,7 +1500,7 @@ class Select(Subqueryable): 'SELECT x FROM tbl LIMIT 10' Args: - expression (str or int or Expression): the SQL code string to parse. + expression (str | int | Expression): the SQL code string to parse. This can also be an integer. If a `Limit` instance is passed, this is used as-is. If another `Expression` instance is passed, it will be wrapped in a `Limit`. @@ -1494,7 +1531,7 @@ class Select(Subqueryable): 'SELECT x FROM tbl OFFSET 10' Args: - expression (str or int or Expression): the SQL code string to parse. + expression (str | int | Expression): the SQL code string to parse. This can also be an integer. If a `Offset` instance is passed, this is used as-is. If another `Expression` instance is passed, it will be wrapped in a `Offset`. @@ -1525,7 +1562,7 @@ class Select(Subqueryable): 'SELECT x, y' Args: - *expressions (str or Expression): the SQL code strings to parse. + *expressions (str | Expression): the SQL code strings to parse. If an `Expression` instance is passed, it will be used as-is. append (bool): if `True`, add to any existing expressions. Otherwise, this resets the expressions. @@ -1555,7 +1592,7 @@ class Select(Subqueryable): 'SELECT x FROM tbl LATERAL VIEW OUTER EXPLODE(y) tbl2 AS z' Args: - *expressions (str or Expression): the SQL code strings to parse. + *expressions (str | Expression): the SQL code strings to parse. If an `Expression` instance is passed, it will be used as-is. append (bool): if `True`, add to any existing expressions. Otherwise, this resets the expressions. @@ -1582,6 +1619,7 @@ class Select(Subqueryable): self, expression, on=None, + using=None, append=True, join_type=None, join_alias=None, @@ -1596,15 +1634,20 @@ class Select(Subqueryable): >>> Select().select("*").from_("tbl").join("tbl2", on="tbl1.y = tbl2.y").sql() 'SELECT * FROM tbl JOIN tbl2 ON tbl1.y = tbl2.y' + >>> Select().select("1").from_("a").join("b", using=["x", "y", "z"]).sql() + 'SELECT 1 FROM a JOIN b USING (x, y, z)' + Use `join_type` to change the type of join: >>> Select().select("*").from_("tbl").join("tbl2", on="tbl1.y = tbl2.y", join_type="left outer").sql() 'SELECT * FROM tbl LEFT OUTER JOIN tbl2 ON tbl1.y = tbl2.y' Args: - expression (str or Expression): the SQL code string to parse. + expression (str | Expression): the SQL code string to parse. If an `Expression` instance is passed, it will be used as-is. - on (str or Expression): optionally specify the join criteria as a SQL string. + on (str | Expression): optionally specify the join "on" criteria as a SQL string. + If an `Expression` instance is passed, it will be used as-is. + using (str | Expression): optionally specify the join "using" criteria as a SQL string. If an `Expression` instance is passed, it will be used as-is. append (bool): if `True`, add to any existing expressions. Otherwise, this resets the expressions. @@ -1641,6 +1684,16 @@ class Select(Subqueryable): on = and_(*ensure_list(on), dialect=dialect, **opts) join.set("on", on) + if using: + join = _apply_list_builder( + *ensure_list(using), + instance=join, + arg="using", + append=append, + copy=copy, + **opts, + ) + if join_alias: join.set("this", alias_(join.args["this"], join_alias, table=True)) return _apply_list_builder( @@ -1661,7 +1714,7 @@ class Select(Subqueryable): "SELECT x FROM tbl WHERE x = 'a' OR x < 'b'" Args: - *expressions (str or Expression): the SQL code strings to parse. + *expressions (str | Expression): the SQL code strings to parse. If an `Expression` instance is passed, it will be used as-is. Multiple expressions are combined with an AND operator. append (bool): if `True`, AND the new expressions to any existing expression. @@ -1693,7 +1746,7 @@ class Select(Subqueryable): 'SELECT x, COUNT(y) FROM tbl GROUP BY x HAVING COUNT(y) > 3' Args: - *expressions (str or Expression): the SQL code strings to parse. + *expressions (str | Expression): the SQL code strings to parse. If an `Expression` instance is passed, it will be used as-is. Multiple expressions are combined with an AND operator. append (bool): if `True`, AND the new expressions to any existing expression. @@ -1744,7 +1797,7 @@ class Select(Subqueryable): 'CREATE TABLE x AS SELECT * FROM tbl' Args: - table (str or Expression): the SQL code string to parse as the table name. + table (str | Expression): the SQL code string to parse as the table name. If another `Expression` instance is passed, it will be used as-is. properties (dict): an optional mapping of table properties dialect (str): the dialect used to parse the input table. @@ -2620,6 +2673,10 @@ class StrToUnix(Func): arg_types = {"this": True, "format": True} +class NumberToStr(Func): + arg_types = {"this": True, "format": True} + + class Struct(Func): arg_types = {"expressions": True} is_var_len_args = True @@ -2775,7 +2832,7 @@ def maybe_parse( (IDENTIFIER this: x, quoted: False) Args: - sql_or_expression (str or Expression): the SQL code string or an expression + sql_or_expression (str | Expression): the SQL code string or an expression into (Expression): the SQLGlot Expression to parse into dialect (str): the dialect used to parse the input expressions (in the case that an input expression is a SQL string). @@ -2950,9 +3007,9 @@ def union(left, right, distinct=True, dialect=None, **opts): 'SELECT * FROM foo UNION SELECT * FROM bla' Args: - left (str or Expression): the SQL code string corresponding to the left-hand side. + left (str | Expression): the SQL code string corresponding to the left-hand side. If an `Expression` instance is passed, it will be used as-is. - right (str or Expression): the SQL code string corresponding to the right-hand side. + right (str | Expression): the SQL code string corresponding to the right-hand side. If an `Expression` instance is passed, it will be used as-is. distinct (bool): set the DISTINCT flag if and only if this is true. dialect (str): the dialect used to parse the input expression. @@ -2975,9 +3032,9 @@ def intersect(left, right, distinct=True, dialect=None, **opts): 'SELECT * FROM foo INTERSECT SELECT * FROM bla' Args: - left (str or Expression): the SQL code string corresponding to the left-hand side. + left (str | Expression): the SQL code string corresponding to the left-hand side. If an `Expression` instance is passed, it will be used as-is. - right (str or Expression): the SQL code string corresponding to the right-hand side. + right (str | Expression): the SQL code string corresponding to the right-hand side. If an `Expression` instance is passed, it will be used as-is. distinct (bool): set the DISTINCT flag if and only if this is true. dialect (str): the dialect used to parse the input expression. @@ -3000,9 +3057,9 @@ def except_(left, right, distinct=True, dialect=None, **opts): 'SELECT * FROM foo EXCEPT SELECT * FROM bla' Args: - left (str or Expression): the SQL code string corresponding to the left-hand side. + left (str | Expression): the SQL code string corresponding to the left-hand side. If an `Expression` instance is passed, it will be used as-is. - right (str or Expression): the SQL code string corresponding to the right-hand side. + right (str | Expression): the SQL code string corresponding to the right-hand side. If an `Expression` instance is passed, it will be used as-is. distinct (bool): set the DISTINCT flag if and only if this is true. dialect (str): the dialect used to parse the input expression. @@ -3025,7 +3082,7 @@ def select(*expressions, dialect=None, **opts): 'SELECT col1, col2 FROM tbl' Args: - *expressions (str or Expression): the SQL code string to parse as the expressions of a + *expressions (str | Expression): the SQL code string to parse as the expressions of a SELECT statement. If an Expression instance is passed, this is used as-is. dialect (str): the dialect used to parse the input expressions (in the case that an input expression is a SQL string). @@ -3047,7 +3104,7 @@ def from_(*expressions, dialect=None, **opts): 'SELECT col1, col2 FROM tbl' Args: - *expressions (str or Expression): the SQL code string to parse as the FROM expressions of a + *expressions (str | Expression): the SQL code string to parse as the FROM expressions of a SELECT statement. If an Expression instance is passed, this is used as-is. dialect (str): the dialect used to parse the input expression (in the case that the input expression is a SQL string). @@ -3132,7 +3189,7 @@ def condition(expression, dialect=None, **opts): 'SELECT * FROM tbl WHERE x = 1 AND y = 1' Args: - *expression (str or Expression): the SQL code string to parse. + *expression (str | Expression): the SQL code string to parse. If an Expression instance is passed, this is used as-is. dialect (str): the dialect used to parse the input expression (in the case that the input expression is a SQL string). @@ -3159,7 +3216,7 @@ def and_(*expressions, dialect=None, **opts): 'x = 1 AND (y = 1 AND z = 1)' Args: - *expressions (str or Expression): the SQL code strings to parse. + *expressions (str | Expression): the SQL code strings to parse. If an Expression instance is passed, this is used as-is. dialect (str): the dialect used to parse the input expression. **opts: other options to use to parse the input expressions. @@ -3179,7 +3236,7 @@ def or_(*expressions, dialect=None, **opts): 'x = 1 OR (y = 1 OR z = 1)' Args: - *expressions (str or Expression): the SQL code strings to parse. + *expressions (str | Expression): the SQL code strings to parse. If an Expression instance is passed, this is used as-is. dialect (str): the dialect used to parse the input expression. **opts: other options to use to parse the input expressions. @@ -3199,7 +3256,7 @@ def not_(expression, dialect=None, **opts): "NOT this_suit = 'black'" Args: - expression (str or Expression): the SQL code strings to parse. + expression (str | Expression): the SQL code strings to parse. If an Expression instance is passed, this is used as-is. dialect (str): the dialect used to parse the input expression. **opts: other options to use to parse the input expressions. @@ -3283,9 +3340,9 @@ def alias_(expression, alias, table=False, dialect=None, quoted=None, **opts): 'foo AS bar' Args: - expression (str or Expression): the SQL code strings to parse. + expression (str | Expression): the SQL code strings to parse. If an Expression instance is passed, this is used as-is. - alias (str or Identifier): the alias name to use. If the name has + alias (str | Identifier): the alias name to use. If the name has special characters it is quoted. table (bool): create a table alias, default false dialect (str): the dialect used to parse the input expression. @@ -3322,9 +3379,9 @@ def subquery(expression, alias=None, dialect=None, **opts): 'SELECT x FROM (SELECT x FROM tbl) AS bar' Args: - expression (str or Expression): the SQL code strings to parse. + expression (str | Expression): the SQL code strings to parse. If an Expression instance is passed, this is used as-is. - alias (str or Expression): the alias name to use. + alias (str | Expression): the alias name to use. dialect (str): the dialect used to parse the input expression. **opts: other options to use to parse the input expressions. @@ -3340,8 +3397,8 @@ def column(col, table=None, quoted=None): """ Build a Column. Args: - col (str or Expression): column name - table (str or Expression): table name + col (str | Expression): column name + table (str | Expression): table name Returns: Column: column instance """ @@ -3355,9 +3412,9 @@ def table_(table, db=None, catalog=None, quoted=None, alias=None): """Build a Table. Args: - table (str or Expression): column name - db (str or Expression): db name - catalog (str or Expression): catalog name + table (str | Expression): column name + db (str | Expression): db name + catalog (str | Expression): catalog name Returns: Table: table instance @@ -3423,7 +3480,7 @@ def convert(value): values=[convert(v) for v in value.values()], ) if isinstance(value, datetime.datetime): - datetime_literal = Literal.string(value.strftime("%Y-%m-%d %H:%M:%S")) + datetime_literal = Literal.string(value.strftime("%Y-%m-%d %H:%M:%S.%f%z")) return TimeStrToTime(this=datetime_literal) if isinstance(value, datetime.date): date_literal = Literal.string(value.strftime("%Y-%m-%d")) diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 1784287..ca14425 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -65,6 +65,9 @@ class Generator: exp.VolatilityProperty: lambda self, e: self.sql(e.name), } + # whether 'CREATE ... TRANSIENT ... TABLE' is allowed + # can override in dialects + CREATE_TRANSIENT = False # whether or not null ordering is supported in order by NULL_ORDERING_SUPPORTED = True # always do union distinct or union all @@ -368,15 +371,14 @@ class Generator: expression_sql = self.sql(expression, "expression") expression_sql = f"AS{self.sep()}{expression_sql}" if expression_sql else "" temporary = " TEMPORARY" if expression.args.get("temporary") else "" + transient = " TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else "" replace = " OR REPLACE" if expression.args.get("replace") else "" exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else "" unique = " UNIQUE" if expression.args.get("unique") else "" materialized = " MATERIALIZED" if expression.args.get("materialized") else "" properties = self.sql(expression, "properties") - expression_sql = ( - f"CREATE{replace}{temporary}{unique}{materialized} {kind}{exists_sql} {this}{properties} {expression_sql}" - ) + expression_sql = f"CREATE{replace}{temporary}{transient}{unique}{materialized} {kind}{exists_sql} {this}{properties} {expression_sql}" return self.prepend_ctes(expression, expression_sql) def describe_sql(self, expression): @@ -716,15 +718,21 @@ class Generator: def lateral_sql(self, expression): this = self.sql(expression, "this") + if isinstance(expression.this, exp.Subquery): - return f"LATERAL{self.sep()}{this}" - op_sql = self.seg(f"LATERAL VIEW{' OUTER' if expression.args.get('outer') else ''}") + return f"LATERAL {this}" + alias = expression.args["alias"] table = alias.name table = f" {table}" if table else table columns = self.expressions(alias, key="columns", flat=True) columns = f" AS {columns}" if columns else "" - return f"{op_sql}{self.sep()}{this}{table}{columns}" + + if expression.args.get("view"): + op_sql = self.seg(f"LATERAL VIEW{' OUTER' if expression.args.get('outer') else ''}") + return f"{op_sql}{self.sep()}{this}{table}{columns}" + + return f"LATERAL {this}{table}{columns}" def limit_sql(self, expression): this = self.sql(expression, "this") diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py index 36ba028..ebee92a 100644 --- a/sqlglot/optimizer/qualify_columns.py +++ b/sqlglot/optimizer/qualify_columns.py @@ -211,21 +211,26 @@ def _qualify_columns(scope, resolver): if column_table: column.set("table", exp.to_identifier(column_table)) + columns_missing_from_scope = [] # Determine whether each reference in the order by clause is to a column or an alias. for ordered in scope.find_all(exp.Ordered): for column in ordered.find_all(exp.Column): - column_table = column.table - column_name = column.name + if not column.table and column.parent is not ordered and column.name in resolver.all_columns: + columns_missing_from_scope.append(column) - if column_table or column.parent is ordered or column_name not in resolver.all_columns: - continue + # Determine whether each reference in the having clause is to a column or an alias. + for having in scope.find_all(exp.Having): + for column in having.find_all(exp.Column): + if not column.table and column.find_ancestor(exp.AggFunc) and column.name in resolver.all_columns: + columns_missing_from_scope.append(column) - column_table = resolver.get_table(column_name) + for column in columns_missing_from_scope: + column_table = resolver.get_table(column.name) - if column_table is None: - raise OptimizeError(f"Ambiguous column: {column_name}") + if column_table is None: + raise OptimizeError(f"Ambiguous column: {column.name}") - column.set("table", exp.to_identifier(column_table)) + column.set("table", exp.to_identifier(column_table)) def _expand_stars(scope, resolver): diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index b7eb6c2..5a75ee2 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -232,7 +232,7 @@ class Scope: self._columns = [] for column in columns + external_columns: - ancestor = column.find_ancestor(exp.Qualify, exp.Order, exp.Hint) + ancestor = column.find_ancestor(exp.Qualify, exp.Order, exp.Having, exp.Hint) if ( not ancestor or column.table diff --git a/sqlglot/parser.py b/sqlglot/parser.py index b94313a..79a1d90 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -131,6 +131,7 @@ class Parser: TokenType.ALTER, TokenType.ALWAYS, TokenType.ANTI, + TokenType.APPLY, TokenType.BEGIN, TokenType.BOTH, TokenType.BUCKET, @@ -190,6 +191,7 @@ class Parser: TokenType.TABLE, TokenType.TABLE_FORMAT, TokenType.TEMPORARY, + TokenType.TRANSIENT, TokenType.TOP, TokenType.TRAILING, TokenType.TRUNCATE, @@ -204,7 +206,7 @@ class Parser: *TYPE_TOKENS, } - TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.NATURAL} + TABLE_ALIAS_TOKENS = ID_VAR_TOKENS - {TokenType.NATURAL, TokenType.APPLY} TRIM_TYPES = {TokenType.LEADING, TokenType.TRAILING, TokenType.BOTH} @@ -685,6 +687,7 @@ class Parser: def _parse_create(self): replace = self._match(TokenType.OR) and self._match(TokenType.REPLACE) temporary = self._match(TokenType.TEMPORARY) + transient = self._match(TokenType.TRANSIENT) unique = self._match(TokenType.UNIQUE) materialized = self._match(TokenType.MATERIALIZED) @@ -723,6 +726,7 @@ class Parser: exists=exists, properties=properties, temporary=temporary, + transient=transient, replace=replace, unique=unique, materialized=materialized, @@ -1057,8 +1061,8 @@ class Parser: return self._parse_set_operations(this) if this else None - def _parse_with(self): - if not self._match(TokenType.WITH): + def _parse_with(self, skip_with_token=False): + if not skip_with_token and not self._match(TokenType.WITH): return None recursive = self._match(TokenType.RECURSIVE) @@ -1167,28 +1171,53 @@ class Parser: return self.expression(exp.From, expressions=self._parse_csv(self._parse_table)) def _parse_lateral(self): - if not self._match(TokenType.LATERAL): + outer_apply = self._match_pair(TokenType.OUTER, TokenType.APPLY) + cross_apply = self._match_pair(TokenType.CROSS, TokenType.APPLY) + + if outer_apply or cross_apply: + this = self._parse_select(table=True) + view = None + outer = not cross_apply + elif self._match(TokenType.LATERAL): + this = self._parse_select(table=True) + view = self._match(TokenType.VIEW) + outer = self._match(TokenType.OUTER) + else: return None - subquery = self._parse_select(table=True) + if not this: + this = self._parse_function() - if subquery: - return self.expression(exp.Lateral, this=subquery) + table_alias = self._parse_id_var(any_token=False) - self._match(TokenType.VIEW) - outer = self._match(TokenType.OUTER) + columns = None + if self._match(TokenType.ALIAS): + columns = self._parse_csv(self._parse_id_var) + elif self._match(TokenType.L_PAREN): + columns = self._parse_csv(self._parse_id_var) + self._match(TokenType.R_PAREN) - return self.expression( + expression = self.expression( exp.Lateral, - this=self._parse_function(), + this=this, + view=view, outer=outer, alias=self.expression( exp.TableAlias, - this=self._parse_id_var(any_token=False), - columns=(self._parse_csv(self._parse_id_var) if self._match(TokenType.ALIAS) else None), + this=table_alias, + columns=columns, ), ) + if outer_apply or cross_apply: + return self.expression( + exp.Join, + this=expression, + side=None if cross_apply else "LEFT", + ) + + return expression + def _parse_join_side_and_kind(self): return ( self._match(TokenType.NATURAL) and self._prev, @@ -1196,10 +1225,10 @@ class Parser: self._match_set(self.JOIN_KINDS) and self._prev, ) - def _parse_join(self): + def _parse_join(self, skip_join_token=False): natural, side, kind = self._parse_join_side_and_kind() - if not self._match(TokenType.JOIN): + if not skip_join_token and not self._match(TokenType.JOIN): return None kwargs = {"this": self._parse_table()} @@ -1425,13 +1454,13 @@ class Parser: unpivot=unpivot, ) - def _parse_where(self): - if not self._match(TokenType.WHERE): + def _parse_where(self, skip_where_token=False): + if not skip_where_token and not self._match(TokenType.WHERE): return None return self.expression(exp.Where, this=self._parse_conjunction()) - def _parse_group(self): - if not self._match(TokenType.GROUP_BY): + def _parse_group(self, skip_group_by_token=False): + if not skip_group_by_token and not self._match(TokenType.GROUP_BY): return None return self.expression( exp.Group, @@ -1457,8 +1486,8 @@ class Parser: return self.expression(exp.Tuple, expressions=grouping_set) return self._parse_id_var() - def _parse_having(self): - if not self._match(TokenType.HAVING): + def _parse_having(self, skip_having_token=False): + if not skip_having_token and not self._match(TokenType.HAVING): return None return self.expression(exp.Having, this=self._parse_conjunction()) @@ -1467,8 +1496,8 @@ class Parser: return None return self.expression(exp.Qualify, this=self._parse_conjunction()) - def _parse_order(self, this=None): - if not self._match(TokenType.ORDER_BY): + def _parse_order(self, this=None, skip_order_token=False): + if not skip_order_token and not self._match(TokenType.ORDER_BY): return this return self.expression(exp.Order, this=this, expressions=self._parse_csv(self._parse_ordered)) @@ -1502,7 +1531,11 @@ class Parser: def _parse_limit(self, this=None, top=False): if self._match(TokenType.TOP if top else TokenType.LIMIT): - return self.expression(exp.Limit, this=this, expression=self._parse_number()) + limit_paren = self._match(TokenType.L_PAREN) + limit_exp = self.expression(exp.Limit, this=this, expression=self._parse_number()) + if limit_paren: + self._match(TokenType.R_PAREN) + return limit_exp if self._match(TokenType.FETCH): direction = self._match_set((TokenType.FIRST, TokenType.NEXT)) direction = self._prev.text if direction else "FIRST" @@ -2136,7 +2169,7 @@ class Parser: return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to) def _parse_convert(self, strict): - this = self._parse_field() + this = self._parse_column() if self._match(TokenType.USING): to = self.expression(exp.CharacterSet, this=self._parse_var()) elif self._match(TokenType.COMMA): diff --git a/sqlglot/time.py b/sqlglot/time.py index de28ac0..729b50d 100644 --- a/sqlglot/time.py +++ b/sqlglot/time.py @@ -43,5 +43,4 @@ def format_time(string, mapping, trie=None): if result and end > size: chunks.append(chars) - return "".join(mapping.get(chars, chars) for chars in chunks) diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index 1a9d72e..766c01a 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -107,6 +107,7 @@ class TokenType(AutoName): ANALYZE = auto() ANTI = auto() ANY = auto() + APPLY = auto() ARRAY = auto() ASC = auto() AT_TIME_ZONE = auto() @@ -256,6 +257,7 @@ class TokenType(AutoName): TABLE_FORMAT = auto() TABLE_SAMPLE = auto() TEMPORARY = auto() + TRANSIENT = auto() TOP = auto() THEN = auto() TRUE = auto() @@ -560,6 +562,7 @@ class Tokenizer(metaclass=_Tokenizer): "TABLESAMPLE": TokenType.TABLE_SAMPLE, "TEMP": TokenType.TEMPORARY, "TEMPORARY": TokenType.TEMPORARY, + "TRANSIENT": TokenType.TRANSIENT, "THEN": TokenType.THEN, "TRUE": TokenType.TRUE, "TRAILING": TokenType.TRAILING, @@ -582,6 +585,7 @@ class Tokenizer(metaclass=_Tokenizer): "WITH LOCAL TIME ZONE": TokenType.WITH_LOCAL_TIME_ZONE, "WITHIN GROUP": TokenType.WITHIN_GROUP, "WITHOUT TIME ZONE": TokenType.WITHOUT_TIME_ZONE, + "APPLY": TokenType.APPLY, "ARRAY": TokenType.ARRAY, "BOOL": TokenType.BOOLEAN, "BOOLEAN": TokenType.BOOLEAN, diff --git a/tests/dataframe/integration/dataframe_validator.py b/tests/dataframe/integration/dataframe_validator.py index 6c4642f..4a89c78 100644 --- a/tests/dataframe/integration/dataframe_validator.py +++ b/tests/dataframe/integration/dataframe_validator.py @@ -1,3 +1,4 @@ +import sys import typing as t import unittest import warnings @@ -9,7 +10,9 @@ if t.TYPE_CHECKING: from pyspark.sql import DataFrame as SparkDataFrame -@unittest.skipIf(SKIP_INTEGRATION, "Skipping Integration Tests since `SKIP_INTEGRATION` is set") +@unittest.skipIf( + SKIP_INTEGRATION or sys.version_info[:2] > (3, 10), "Skipping Integration Tests since `SKIP_INTEGRATION` is set" +) class DataFrameValidator(unittest.TestCase): spark = None sqlglot = None diff --git a/tests/dataframe/unit/test_column.py b/tests/dataframe/unit/test_column.py index df0ebff..977971e 100644 --- a/tests/dataframe/unit/test_column.py +++ b/tests/dataframe/unit/test_column.py @@ -146,7 +146,8 @@ class TestDataframeColumn(unittest.TestCase): F.col("cola").between(datetime.date(2022, 1, 1), datetime.date(2022, 3, 1)).sql(), ) self.assertEqual( - "cola BETWEEN CAST('2022-01-01 01:01:01' AS TIMESTAMP) " "AND CAST('2022-03-01 01:01:01' AS TIMESTAMP)", + "cola BETWEEN CAST('2022-01-01 01:01:01.000000' AS TIMESTAMP) " + "AND CAST('2022-03-01 01:01:01.000000' AS TIMESTAMP)", F.col("cola").between(datetime.datetime(2022, 1, 1, 1, 1, 1), datetime.datetime(2022, 3, 1, 1, 1, 1)).sql(), ) diff --git a/tests/dataframe/unit/test_functions.py b/tests/dataframe/unit/test_functions.py index 97753bd..eadbb93 100644 --- a/tests/dataframe/unit/test_functions.py +++ b/tests/dataframe/unit/test_functions.py @@ -30,7 +30,7 @@ class TestFunctions(unittest.TestCase): test_date = SF.lit(datetime.date(2022, 1, 1)) self.assertEqual("TO_DATE('2022-01-01')", test_date.sql()) test_datetime = SF.lit(datetime.datetime(2022, 1, 1, 1, 1, 1)) - self.assertEqual("CAST('2022-01-01 01:01:01' AS TIMESTAMP)", test_datetime.sql()) + self.assertEqual("CAST('2022-01-01 01:01:01.000000' AS TIMESTAMP)", test_datetime.sql()) test_dict = SF.lit({"cola": 1, "colb": "test"}) self.assertEqual("STRUCT(1 AS cola, 'test' AS colb)", test_dict.sql()) @@ -52,7 +52,7 @@ class TestFunctions(unittest.TestCase): test_date = SF.col(datetime.date(2022, 1, 1)) self.assertEqual("TO_DATE('2022-01-01')", test_date.sql()) test_datetime = SF.col(datetime.datetime(2022, 1, 1, 1, 1, 1)) - self.assertEqual("CAST('2022-01-01 01:01:01' AS TIMESTAMP)", test_datetime.sql()) + self.assertEqual("CAST('2022-01-01 01:01:01.000000' AS TIMESTAMP)", test_datetime.sql()) test_dict = SF.col({"cola": 1, "colb": "test"}) self.assertEqual("STRUCT(1 AS cola, 'test' AS colb)", test_dict.sql()) diff --git a/tests/dialects/test_databricks.py b/tests/dialects/test_databricks.py new file mode 100644 index 0000000..e242e73 --- /dev/null +++ b/tests/dialects/test_databricks.py @@ -0,0 +1,33 @@ +from tests.dialects.test_dialect import Validator + + +class TestDatabricks(Validator): + dialect = "databricks" + + def test_datediff(self): + self.validate_all( + "SELECT DATEDIFF(year, 'start', 'end')", + write={ + "tsql": "SELECT DATEDIFF(year, 'start', 'end')", + "databricks": "SELECT DATEDIFF(year, 'start', 'end')", + }, + ) + + def test_add_date(self): + self.validate_all( + "SELECT DATEADD(year, 1, '2020-01-01')", + write={ + "tsql": "SELECT DATEADD(year, 1, '2020-01-01')", + "databricks": "SELECT DATEADD(year, 1, '2020-01-01')", + }, + ) + self.validate_all( + "SELECT DATEDIFF('end', 'start')", write={"databricks": "SELECT DATEDIFF(DAY, 'start', 'end')"} + ) + self.validate_all( + "SELECT DATE_ADD('2020-01-01', 1)", + write={ + "tsql": "SELECT DATEADD(DAY, 1, '2020-01-01')", + "databricks": "SELECT DATEADD(DAY, 1, '2020-01-01')", + }, + ) diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 5d1cf13..3b837df 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -81,6 +81,28 @@ class TestDialect(Validator): "starrocks": "CAST(a AS STRING)", }, ) + self.validate_all( + "CAST(a AS BINARY(4))", + read={ + "presto": "CAST(a AS VARBINARY(4))", + "sqlite": "CAST(a AS VARBINARY(4))", + }, + write={ + "bigquery": "CAST(a AS BINARY(4))", + "clickhouse": "CAST(a AS BINARY(4))", + "duckdb": "CAST(a AS BINARY(4))", + "mysql": "CAST(a AS BINARY(4))", + "hive": "CAST(a AS BINARY(4))", + "oracle": "CAST(a AS BLOB(4))", + "postgres": "CAST(a AS BYTEA(4))", + "presto": "CAST(a AS VARBINARY(4))", + "redshift": "CAST(a AS VARBYTE(4))", + "snowflake": "CAST(a AS BINARY(4))", + "sqlite": "CAST(a AS BLOB(4))", + "spark": "CAST(a AS BINARY(4))", + "starrocks": "CAST(a AS BINARY(4))", + }, + ) self.validate_all( "CAST(MAP('a', '1') AS MAP(TEXT, TEXT))", write={ diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 159b643..fea2311 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -293,6 +293,18 @@ class TestSnowflake(Validator): "CREATE TABLE a (x DATE, y BIGINT) WITH (PARTITION BY (x), integration='q', auto_refresh=TRUE, file_format=(type = parquet))" ) self.validate_identity("CREATE MATERIALIZED VIEW a COMMENT='...' AS SELECT 1 FROM x") + self.validate_all( + "CREATE OR REPLACE TRANSIENT TABLE a (id INT)", + read={ + "postgres": "CREATE OR REPLACE TRANSIENT TABLE a (id INT)", + "snowflake": "CREATE OR REPLACE TRANSIENT TABLE a (id INT)", + }, + write={ + "postgres": "CREATE OR REPLACE TABLE a (id INT)", + "mysql": "CREATE OR REPLACE TABLE a (id INT)", + "snowflake": "CREATE OR REPLACE TRANSIENT TABLE a (id INT)", + }, + ) def test_user_defined_functions(self): self.validate_all( diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index 2a20163..d22a9c2 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -260,6 +260,20 @@ class TestTSQL(Validator): "spark": "CAST(x AS INT)", }, ) + self.validate_all( + "SELECT CONVERT(VARCHAR(10), testdb.dbo.test.x, 120) y FROM testdb.dbo.test", + write={ + "mysql": "SELECT CAST(TIME_TO_STR(testdb.dbo.test.x, '%Y-%m-%d %H:%M:%S') AS VARCHAR(10)) AS y FROM testdb.dbo.test", + "spark": "SELECT CAST(DATE_FORMAT(testdb.dbo.test.x, 'yyyy-MM-dd HH:mm:ss') AS VARCHAR(10)) AS y FROM testdb.dbo.test", + }, + ) + self.validate_all( + "SELECT CONVERT(VARCHAR(10), y.x) z FROM testdb.dbo.test y", + write={ + "mysql": "SELECT CAST(y.x AS VARCHAR(10)) AS z FROM testdb.dbo.test AS y", + "spark": "SELECT CAST(y.x AS VARCHAR(10)) AS z FROM testdb.dbo.test AS y", + }, + ) def test_add_date(self): self.validate_identity("SELECT DATEADD(year, 1, '2017/08/25')") @@ -267,7 +281,10 @@ class TestTSQL(Validator): "SELECT DATEADD(year, 1, '2017/08/25')", write={"spark": "SELECT ADD_MONTHS('2017/08/25', 12)"} ) self.validate_all("SELECT DATEADD(qq, 1, '2017/08/25')", write={"spark": "SELECT ADD_MONTHS('2017/08/25', 3)"}) - self.validate_all("SELECT DATEADD(wk, 1, '2017/08/25')", write={"spark": "SELECT DATE_ADD('2017/08/25', 7)"}) + self.validate_all( + "SELECT DATEADD(wk, 1, '2017/08/25')", + write={"spark": "SELECT DATE_ADD('2017/08/25', 7)", "databricks": "SELECT DATEADD(week, 1, '2017/08/25')"}, + ) def test_date_diff(self): self.validate_identity("SELECT DATEDIFF(year, '2020/01/01', '2021/01/01')") @@ -279,11 +296,19 @@ class TestTSQL(Validator): }, ) self.validate_all( - "SELECT DATEDIFF(month, 'start','end')", - write={"spark": "SELECT MONTHS_BETWEEN('end', 'start')", "tsql": "SELECT DATEDIFF(month, 'start', 'end')"}, + "SELECT DATEDIFF(mm, 'start','end')", + write={ + "spark": "SELECT MONTHS_BETWEEN('end', 'start')", + "tsql": "SELECT DATEDIFF(month, 'start', 'end')", + "databricks": "SELECT DATEDIFF(month, 'start', 'end')", + }, ) self.validate_all( - "SELECT DATEDIFF(quarter, 'start', 'end')", write={"spark": "SELECT MONTHS_BETWEEN('end', 'start') / 3"} + "SELECT DATEDIFF(quarter, 'start', 'end')", + write={ + "spark": "SELECT MONTHS_BETWEEN('end', 'start') / 3", + "databricks": "SELECT DATEDIFF(quarter, 'start', 'end')", + }, ) def test_iif(self): @@ -294,3 +319,64 @@ class TestTSQL(Validator): "spark": "SELECT IF(cond, 'True', 'False')", }, ) + + def test_lateral_subquery(self): + self.validate_all( + "SELECT x.a, x.b, t.v, t.y FROM x CROSS APPLY (SELECT v, y FROM t) t(v, y)", + write={ + "spark": "SELECT x.a, x.b, t.v, t.y FROM x JOIN LATERAL (SELECT v, y FROM t) AS t(v, y)", + }, + ) + self.validate_all( + "SELECT x.a, x.b, t.v, t.y FROM x OUTER APPLY (SELECT v, y FROM t) t(v, y)", + write={ + "spark": "SELECT x.a, x.b, t.v, t.y FROM x LEFT JOIN LATERAL (SELECT v, y FROM t) AS t(v, y)", + }, + ) + + def test_lateral_table_valued_function(self): + self.validate_all( + "SELECT t.x, y.z FROM x CROSS APPLY tvfTest(t.x)y(z)", + write={ + "spark": "SELECT t.x, y.z FROM x JOIN LATERAL TVFTEST(t.x) y AS z", + }, + ) + self.validate_all( + "SELECT t.x, y.z FROM x OUTER APPLY tvfTest(t.x)y(z)", + write={ + "spark": "SELECT t.x, y.z FROM x LEFT JOIN LATERAL TVFTEST(t.x) y AS z", + }, + ) + + def test_top(self): + self.validate_all( + "SELECT TOP 3 * FROM A", + write={ + "spark": "SELECT * FROM A LIMIT 3", + }, + ) + self.validate_all( + "SELECT TOP (3) * FROM A", + write={ + "spark": "SELECT * FROM A LIMIT 3", + }, + ) + + def test_format(self): + self.validate_identity("SELECT FORMAT('01-01-1991', 'd.mm.yyyy')") + self.validate_identity("SELECT FORMAT(12345, '###.###.###')") + self.validate_identity("SELECT FORMAT(1234567, 'f')") + self.validate_all( + "SELECT FORMAT(1000000.01,'###,###.###')", + write={"spark": "SELECT FORMAT_NUMBER(1000000.01, '###,###.###')"}, + ) + self.validate_all("SELECT FORMAT(1234567, 'f')", write={"spark": "SELECT FORMAT_NUMBER(1234567, 'f')"}) + self.validate_all( + "SELECT FORMAT('01-01-1991', 'dd.mm.yyyy')", + write={"spark": "SELECT DATE_FORMAT('01-01-1991', 'dd.mm.yyyy')"}, + ) + self.validate_all( + "SELECT FORMAT(date_col, 'dd.mm.yyyy')", write={"spark": "SELECT DATE_FORMAT(date_col, 'dd.mm.yyyy')"} + ) + self.validate_all("SELECT FORMAT(date_col, 'm')", write={"spark": "SELECT DATE_FORMAT(date_col, 'MMMM d')"}) + self.validate_all("SELECT FORMAT(num_col, 'c')", write={"spark": "SELECT FORMAT_NUMBER(num_col, 'c')"}) diff --git a/tests/fixtures/optimizer/qualify_columns.sql b/tests/fixtures/optimizer/qualify_columns.sql index 858f232..a958c08 100644 --- a/tests/fixtures/optimizer/qualify_columns.sql +++ b/tests/fixtures/optimizer/qualify_columns.sql @@ -31,6 +31,23 @@ SELECT x.a + x.b AS "_col_0" FROM x AS x; SELECT a, SUM(b) FROM x WHERE a > 1 AND b > 1 GROUP BY a; SELECT x.a AS a, SUM(x.b) AS "_col_1" FROM x AS x WHERE x.a > 1 AND x.b > 1 GROUP BY x.a; +SELECT SUM(a) AS c FROM x HAVING SUM(a) > 3; +SELECT SUM(x.a) AS c FROM x AS x HAVING SUM(x.a) > 3; + +SELECT SUM(a) AS a FROM x HAVING SUM(a) > 3; +SELECT SUM(x.a) AS a FROM x AS x HAVING SUM(x.a) > 3; + +SELECT SUM(a) AS c FROM x HAVING c > 3; +SELECT SUM(x.a) AS c FROM x AS x HAVING c > 3; + +# execute: false +SELECT SUM(a) AS a FROM x HAVING a > 3; +SELECT SUM(x.a) AS a FROM x AS x HAVING a > 3; + +# execute: false +SELECT SUM(a) AS c FROM x HAVING SUM(c) > 3; +SELECT SUM(x.a) AS c FROM x AS x HAVING SUM(c) > 3; + SELECT a AS j, b FROM x ORDER BY j; SELECT x.a AS j, x.b AS b FROM x AS x ORDER BY j; diff --git a/tests/test_build.py b/tests/test_build.py index f51996d..b7b6865 100644 --- a/tests/test_build.py +++ b/tests/test_build.py @@ -153,6 +153,42 @@ class TestBuild(unittest.TestCase): ), "SELECT x FROM tbl LEFT JOIN (SELECT b FROM tbl2) AS aliased ON a = b", ), + ( + lambda: select("x", "y", "z") + .from_("merged_df") + .join("vte_diagnosis_df", using=["patient_id", "encounter_id"]), + "SELECT x, y, z FROM merged_df JOIN vte_diagnosis_df USING (patient_id, encounter_id)", + ), + ( + lambda: select("x", "y", "z") + .from_("merged_df") + .join("vte_diagnosis_df", using=[exp.to_identifier("patient_id"), exp.to_identifier("encounter_id")]), + "SELECT x, y, z FROM merged_df JOIN vte_diagnosis_df USING (patient_id, encounter_id)", + ), + ( + lambda: parse_one("JOIN x", into=exp.Join).on("y = 1", "z = 1"), + "JOIN x ON y = 1 AND z = 1", + ), + ( + lambda: parse_one("JOIN x", into=exp.Join).on("y = 1"), + "JOIN x ON y = 1", + ), + ( + lambda: parse_one("JOIN x", into=exp.Join).using("bar", "bob"), + "JOIN x USING (bar, bob)", + ), + ( + lambda: parse_one("JOIN x", into=exp.Join).using("bar"), + "JOIN x USING (bar)", + ), + ( + lambda: select("x").from_("foo").join("bla", using="bob"), + "SELECT x FROM foo JOIN bla USING (bob)", + ), + ( + lambda: select("x").from_("foo").join("bla", using="bob"), + "SELECT x FROM foo JOIN bla USING (bob)", + ), ( lambda: select("x", "COUNT(y)").from_("tbl").group_by("x").having("COUNT(y) > 0"), "SELECT x, COUNT(y) FROM tbl GROUP BY x HAVING COUNT(y) > 0", diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 9af59d9..adfd329 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -538,7 +538,11 @@ class TestExpressions(unittest.TestCase): ((1, "2", None), "(1, '2', NULL)"), ([1, "2", None], "ARRAY(1, '2', NULL)"), ({"x": None}, "MAP('x', NULL)"), - (datetime.datetime(2022, 10, 1, 1, 1, 1), "TIME_STR_TO_TIME('2022-10-01 01:01:01')"), + (datetime.datetime(2022, 10, 1, 1, 1, 1), "TIME_STR_TO_TIME('2022-10-01 01:01:01.000000')"), + ( + datetime.datetime(2022, 10, 1, 1, 1, 1, tzinfo=datetime.timezone.utc), + "TIME_STR_TO_TIME('2022-10-01 01:01:01.000000+0000')", + ), (datetime.date(2022, 10, 1), "DATE_STR_TO_DATE('2022-10-01')"), ]: with self.subTest(value): -- cgit v1.2.3