diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-12-24 07:49:56 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2023-12-24 07:49:56 +0000 |
commit | 97144fb2271b9bd749deac2eb3537103d9513182 (patch) | |
tree | ebd8aa8769becb97f7477cf9c14889146e6b8614 | |
parent | Adding upstream version 20.3.0. (diff) | |
download | sqlglot-upstream/20.4.0.tar.xz sqlglot-upstream/20.4.0.zip |
Adding upstream version 20.4.0.upstream/20.4.0
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
-rw-r--r-- | .github/workflows/python-publish.yml | 19 | ||||
-rw-r--r-- | README.md | 6 | ||||
-rw-r--r-- | posts/ast_primer.md | 405 | ||||
-rw-r--r-- | sqlglot/dataframe/sql/functions.py | 4 | ||||
-rw-r--r-- | sqlglot/dialects/bigquery.py | 3 | ||||
-rw-r--r-- | sqlglot/dialects/clickhouse.py | 7 | ||||
-rw-r--r-- | sqlglot/dialects/duckdb.py | 1 | ||||
-rw-r--r-- | sqlglot/dialects/postgres.py | 1 | ||||
-rw-r--r-- | sqlglot/dialects/snowflake.py | 1 | ||||
-rw-r--r-- | sqlglot/dialects/sqlite.py | 1 | ||||
-rw-r--r-- | sqlglot/expressions.py | 11 | ||||
-rw-r--r-- | sqlglot/generator.py | 17 | ||||
-rw-r--r-- | sqlglot/optimizer/simplify.py | 55 | ||||
-rw-r--r-- | sqlglot/parser.py | 1 | ||||
-rw-r--r-- | sqlglotrs/src/lib.rs | 13 | ||||
-rw-r--r-- | tests/dialects/test_bigquery.py | 1 | ||||
-rw-r--r-- | tests/dialects/test_clickhouse.py | 1 | ||||
-rw-r--r-- | tests/dialects/test_dialect.py | 37 | ||||
-rw-r--r-- | tests/dialects/test_presto.py | 16 | ||||
-rw-r--r-- | tests/fixtures/optimizer/simplify.sql | 14 |
20 files changed, 562 insertions, 52 deletions
diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index 54d79f4..fd418b6 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -8,9 +8,6 @@ on: permissions: contents: read -env: - python_interpreters: 3.7 3.8 3.9 3.10 3.11 - jobs: build-rs: strategy: @@ -29,6 +26,9 @@ jobs: - os: windows target: i686 python-architecture: x86 + exclude: + - os: windows + target: aarch64 runs-on: ${{ (matrix.os == 'linux' && 'ubuntu') || matrix.os }}-latest steps: - uses: actions/checkout@v3 @@ -37,36 +37,34 @@ jobs: python-version: '3.10' architecture: ${{ matrix.python-architecture || 'x64' }} - name: Build wheels - working-directory: ./sqlglotrs uses: PyO3/maturin-action@v1 with: target: ${{ matrix.target }} - args: --release --out dist --interpreter $python_interpreters + args: --release --out dist --interpreter 3.7 3.8 3.9 3.10 3.11 3.12 sccache: 'true' manylinux: auto + working-directory: ./sqlglotrs - name: Upload wheels - working-directory: ./sqlglotrs uses: actions/upload-artifact@v3 with: name: wheels - path: dist + path: sqlglotrs/dist sdist-rs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - name: Build sdist - working-directory: ./sqlglotrs uses: PyO3/maturin-action@v1 with: command: sdist args: --out dist + working-directory: ./sqlglotrs - name: Upload sdist - working-directory: ./sqlglotrs uses: actions/upload-artifact@v3 with: name: wheels - path: dist + path: sqlglotrs/dist deploy-rs: runs-on: ubuntu-latest @@ -76,7 +74,6 @@ jobs: with: name: wheels - name: Publish to PyPI - working-directory: ./sqlglotrs uses: PyO3/maturin-action@v1 env: MATURIN_PYPI_TOKEN: ${{ secrets.PYPI_API_TOKEN }} @@ -4,11 +4,11 @@ SQLGlot is a no-dependency SQL parser, transpiler, optimizer, and engine. It can It is a very comprehensive generic SQL parser with a robust [test suite](https://github.com/tobymao/sqlglot/blob/main/tests/). It is also quite [performant](#benchmarks), while being written purely in Python. -You can easily [customize](#custom-dialects) the parser, [analyze](#metadata) queries, traverse expression trees, and programmatically [build](#build-and-modify-sql) SQL. +You can easily [customize](#custom-dialects) the parser, [analyze](#metadata) queries, traverse expression trees, and programmatically [build](#build-and-modify-sql) SQL. Syntax [errors](#parser-errors) are highlighted and dialect incompatibilities can warn or raise depending on configurations. However, it should be noted that SQL validation is not SQLGlot’s goal, so some syntax errors may go unnoticed. -Learn more about the SQLGlot API in the [documentation](https://sqlglot.com/). +Learn more about SQLGlot in the API [documentation](https://sqlglot.com/) and the expression tree [primer](https://github.com/tobymao/sqlglot/blob/main/posts/ast_primer.md). Contributions are very welcome in SQLGlot; read the [contribution guide](https://github.com/tobymao/sqlglot/blob/main/CONTRIBUTING.md) to get started! @@ -176,6 +176,8 @@ for table in parse_one("SELECT * FROM x JOIN y JOIN z").find_all(exp.Table): print(table.name) ``` +Read the [ast primer](https://github.com/tobymao/sqlglot/blob/main/posts/ast_primer.md) to learn more about SQLGlot's internals. + ### Parser Errors When the parser detects an error in the syntax, it raises a ParseError: diff --git a/posts/ast_primer.md b/posts/ast_primer.md new file mode 100644 index 0000000..a02c11f --- /dev/null +++ b/posts/ast_primer.md @@ -0,0 +1,405 @@ +# A Primer on SQLGlot's Abstract Syntax Tree + +SQLGlot is a powerful tool for analyzing and transforming SQL, but the learning curve can be intimidating. + +This post is intended to familiarize newbies with SQLGlot's abstract syntax trees, how to traverse them, and how to mutate them. + +## The tree + +SQLGlot parses SQL into an abstract syntax tree (AST). +```python +from sqlglot import parse_one + +ast = parse_one("SELECT a FROM (SELECT a FROM x) AS x") +``` + +An AST is a data structure that represents a SQL statement. The best way to glean the structure of a particular AST is python's builtin `repr` function: +```python +repr(ast) + +# (SELECT expressions: +# (COLUMN this: +# (IDENTIFIER this: a, quoted: False)), from: +# (FROM this: +# (SUBQUERY this: +# (SELECT expressions: +# (COLUMN this: +# (IDENTIFIER this: a, quoted: False)), from: +# (FROM this: +# (TABLE this: +# (IDENTIFIER this: x, quoted: False)))), alias: +# (TABLEALIAS this: +# (IDENTIFIER this: x, quoted: False))))) +``` + +This is a textual representation of the internal data structure. Here's a breakdown of some of its components: +``` +Expression type child key + | / +(SELECT expressions: + (COLUMN this: ---------------------------------- COLUMN is a child node of SELECT + (IDENTIFIER this: a, quoted: False)), from: -- "from:" is another child key of SELECT + (FROM this: ------------------------------------- FROM is also a child node of SELECT + ... +``` + +## Nodes of the tree + +The nodes in this tree are instances of `sqlglot.Expression`. Nodes reference their children in `args` and their parent in `parent`: +```python +ast.args +# { +# "expressions": [(COLUMN this: ...)], +# "from": (FROM this: ...), +# ... +# } + +ast.args["expressions"][0] +# (COLUMN this: ...) + +ast.args["expressions"][0].args["this"] +# (IDENTIFIER this: ...) + +ast.args["from"] +# (FROM this: ...) + +assert ast.args["expressions"][0].args["this"].parent.parent is ast +``` + +Children can either be: +1. An Expression instance +2. A list of Expression instances +3. Another Python object, such as str or bool. This will always be a leaf node in the tree. + +Navigating this tree requires an understanding of the different Expression types. The best way to browse Expression types is directly in the code at [expressions.py](../sqlglot/expressions.py). Let's look at a simplified version of one Expression type: +```python +class Column(Expression): + arg_types = { + "this": True, + "table": False, + ... + } +``` + +`Column` subclasses `Expression`. + +`arg_types` is a class attribute that specifies the possible children. The `args` keys of an Expression instance correspond to the `arg_types` keys of its class. The values of the `arg_types` dict are `True` if the key is required. + +There are some common `arg_types` keys: +- "this": This is typically used for the primary child. In `Column`, "this" is the identifier for the column's name. +- "expression": This is typically used for the secondary child +- "expressions": This is typically used for a primary list of children + +There aren't strict rules for when these keys are used, but they help with some of the convenience methods available on all Expression types: +- `Expression.this`: shorthand for `self.args.get("this")` +- `Expression.expression`: similarly, shorthand for the expression arg +- `Expression.expressions`: similarly, shorthand for the expressions list arg +- `Expression.name`: text name for whatever `this` is + +`arg_types` don't specify the possible Expression types of children. This can be a challenge when you are writing code to traverse a particular AST and you don't know what to expect. A common trick is to parse an example query and print out the `repr`. + +You can traverse an AST using just args, but there are some higher-order functions for programmatic traversal. + +> [!NOTE] +> SQLGlot can parse and generate SQL for many different dialects. However, there is only a single set of Expression types for all dialects. We like to say that the AST can represent the _superset_ of all dialects. +> +> Sometimes, SQLGlot will parse SQL from a dialect into Expression types you didn't expect: +> +> ```python +> ast = parse_one("SELECT NOW()", dialect="postgres") +> +> repr(ast) +> # (SELECT expressions: +> # (CURRENTTIMESTAMP )) -- CURRENTTIMESTAMP, not NOW() +> ``` +> +> This is because SQLGlot tries to converge dialects on a standard AST. This means you can often write one piece of code that handles multiple dialects. + +## Traversing the AST + +Analyzing a SQL statement requires traversing this data structure. There are a few ways to do this: + +### Args + +If you know the structure of an AST, you can use `Expression.args` just like above. However, this can be very limited if you're dealing with arbitrary SQL. + +### Walk methods + +The walk methods of `Expression` (`find`, `find_all`, and `walk`) are the simplest way to analyze an AST. + +`find` and `find_all` search an AST for specific Expression types: +```python +from sqlglot import exp + +ast.find(exp.Select) +# (SELECT expressions: +# (COLUMN this: +# (IDENTIFIER this: a, quoted: False)), from: +# ... + +list(ast.find_all(exp.Select)) +# [(SELECT expressions: +# (COLUMN this: +# (IDENTIFIER this: a, quoted: False)), from: +# ... +``` + +Both `find` and `find_all` are built on `walk`, which gives finer grained control: +```python +for ( + node, # the current AST node + parent, # parent of the current AST node (this will be None for the root node) + key # The 'key' of this node in its parent's args +) in ast.walk(): + ... +``` + +> [!WARNING] +> Here's a common pitfall of the walk methods: +> ```python +> ast.find_all(exp.Table) +> ``` +> At first glance, this seems like a great way to find all tables in a query. However, `Table` instances are not always tables in your database. Here's an example where this fails: +> ```python +> ast = parse_one(""" +> WITH x AS ( +> SELECT a FROM y +> ) +> SELECT a FROM x +> """) +> +> # This is NOT a good way to find all tables in the query! +> for table in ast.find_all(exp.Table): +> print(table) +> +> # x -- this is a common table expression, NOT an actual table +> # y +> ``` +> +> For programmatic traversal of ASTs that requires deeper semantic understanding of a query, you need "scope". + +### Scope + +Scope is a traversal module that handles more semantic context of SQL queries. It's harder to use than the `walk` methods but is more powerful: +```python +from sqlglot.optimizer.scope import build_scope + +ast = parse_one(""" +WITH x AS ( + SELECT a FROM y +) +SELECT a FROM x +""") + +root = build_scope(ast) +for scope in root.traverse(): + print(scope) + +# Scope<SELECT a FROM y> +# Scope<WITH x AS (SELECT a FROM y) SELECT a FROM x> +``` + +Let's use this for a better way to find all tables in a query: +```python +tables = [ + source + + # Traverse the Scope tree, not the AST + for scope in root.traverse() + + # `selected_sources` contains sources that have been selected in this scope, e.g. in a FROM or JOIN clause. + # `alias` is the name of this source in this particular scope. + # `node` is the AST node instance + # if the selected source is a subquery (including common table expressions), + # then `source` will be the Scope instance for that subquery. + # if the selected source is a table, + # then `source` will be a Table instance. + for alias, (node, source) in scope.selected_sources.items() + if isinstance(source, exp.Table) +] + +for table in tables: + print(table) + +# y -- Success! +``` + +`build_scope` returns an instance of the `Scope` class. `Scope` has numerous methods for inspecting a query. The best way to browse these methods is directly in the code at [scope.py](../sqlglot/optimizer/scope.py). You can also look for examples of how Scope is used throughout SQLGlot's [optimizer](../sqlglot/optimizer) module. + +Many methods of Scope depend on a fully qualified SQL expression. For example, let's say we want to trace the lineage of columns in this query: +```python +ast = parse_one(""" +SELECT + a, + c +FROM ( + SELECT + a, + b + FROM x +) AS x +JOIN ( + SELECT + b, + c + FROM y +) AS y + ON x.b = y.b +""") +``` + +Just looking at the outer query, it's not obvious that column `a` comes from table `x` without looking at the columns of the subqueries. + +We can use the [qualify](../sqlglot/optimizer/qualify.py) function to prefix all columns in an AST with their table name like so: +```python +from sqlglot.optimizer.qualify import qualify + +qualify(ast) +# SELECT +# x.a AS a, +# y.c AS c +# FROM ( +# SELECT +# x.a AS a, +# x.b AS b +# FROM x AS x +# ) AS x +# JOIN ( +# SELECT +# y.b AS b, +# y.c AS c +# FROM y AS y +# ) AS y +# ON x.b = y.b +``` + +Now we can trace a column to its source. Here's how we might find the table or subquery for all columns in a qualified AST: +```python +from sqlglot.optimizer.scope import find_all_in_scope + +root = build_scope(ast) + +# `find_all_in_scope` is similar to `Expression.find_all`, except it doesn't traverse into subqueries +for column in find_all_in_scope(root.expression, exp.Column): + print(f"{column} => {root.sources[column.table]}") + +# x.a => Scope<SELECT x.a AS a, x.b AS b FROM x AS x> +# y.c => Scope<SELECT y.b AS b, y.c AS c FROM y AS y> +# x.b => Scope<SELECT x.a AS a, x.b AS b FROM x AS x> +# y.b => Scope<SELECT y.b AS b, y.c AS c FROM y AS y> +``` + +For a complete example of tracing column lineage, check out the [lineage](../sqlglot/lineage.py) module. + +> [!NOTE] +> Some queries require the database schema for disambiguation. For example: +> +> ```sql +> SELECT a FROM x CROSS JOIN y +> ``` +> +> Column `a` might come from table `x` or `y`. In these cases, you must pass the `schema` into `qualify`. + +## Mutating the tree + +You can also modify an AST or build one from scratch. There are a few ways to do this. + +### High-level builder methods + +SQLGlot has methods for programmatically building up expressions similar to how you might in an ORM: +```python +ast = ( + exp + .select("a", "b") + .from_("x") + .where("b < 4") + .limit(10) +) +``` + +> [!WARNING] +> High-level builder methods will attempt to parse string arguments into Expressions. This can be very convenient, but make sure to keep in mind the dialect of the string. If its written in a specific dialect, you need to set the `dialect` argument. +> +> You can avoid parsing by passing Expressions as arguments, e.g. `.where(exp.column("b") < 4)` instead of `.where("b < 4")` + +These methods can be used on any AST, including ones you've parsed: +```python +ast = parse_one(""" +SELECT * FROM (SELECT a, b FROM x) +""") + +# To modify the AST in-place, set `copy=False` +ast.args["from"].this.this.select("c", copy=False) + +print(ast) +# SELECT * FROM (SELECT a, b, c FROM x) +``` + +The best place to browse all the available high-level builder methods and their parameters is, as always, directly in the code at [expressions.py](../sqlglot/expressions.py). + +### Low-level builder methods + +High-level builder methods don't account for all possible expressions you might want to build. In the case where a particular high-level method is missing, use the low-level methods. Here are some examples: +```python +node = ast.args["from"].this.this + +# These all do the same thing: + +# high-level +node.select("c", copy=False) +# low-level +node.set("expressions", node.expressions + [exp.column("c")]) +node.append("expressions", exp.column("c")) +node.replace(node.copy().select("c")) +``` +> [!NOTE] +> In general, you should use `Expression.set` and `Expression.append` instead of mutating `Expression.args` directly. `set` and `append` take care to update node references like `parent`. + +You can also instantiate AST nodes directly: + +```python +col = exp.Column( + this=exp.to_identifier("c") +) +node.append("expressions", col) +``` + +> [!WARNING] +> Because SQLGlot doesn't verify the types of args, it's easy to instantiate an invalid AST Node that won't generate to SQL properly. Take extra care to inspect the expected types of a node using the methods described above. + +### Transform + +The `Expression.transform` method applies a function to all nodes in an AST in depth-first, pre-order. + +```python +def transformer(node): + if isinstance(node, exp.Column) and node.name == "a": + # Return a new node to replace `node` + return exp.func("FUN", node) + # Or return `node` to do nothing and continue traversing the tree + return node + +print(parse_one("SELECT a, b FROM x").transform(transformer)) +# SELECT FUN(a), b FROM x +``` + +> [!WARNING] +> As with the walk methods, `transform` doesn't manage scope. For safely transforming the columns and tables in complex expressions, you should probably use Scope. + +## Summed up + +SQLGlot parses SQL statements into an abstract syntax tree (AST) where nodes are instances of `sqlglot.Expression`. + +There are 3 ways to traverse an AST: +1. **args** - use this when you know the exact structure of the AST you're dealing with. +2. **walk methods** - this is the easiest way. Use this for simple cases. +3. **scope** - this is the hardest way. Use this for more complex cases that must handle the semantic context of a query. + +There are 3 ways to mutate an AST +1. **high-level builder methods** - use this when you know the exact structure of the AST you're dealing with. +2. **low-level builder methods** - use this only when high-level builder methods don't exist for what you're trying to build. +3. **transform** - use this for simple transformations on arbitrary statements. + +And, of course, these mechanisms can be mixed and matched. For example, maybe you need to use scope to traverse an arbitrary AST and the high-level builder methods to mutate it in-place. + +Still need help? [Get in touch!](../README.md#get-in-touch) diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py index 6671c5b..6658287 100644 --- a/sqlglot/dataframe/sql/functions.py +++ b/sqlglot/dataframe/sql/functions.py @@ -418,11 +418,11 @@ def percentile_approx( def rand(seed: t.Optional[ColumnOrLiteral] = None) -> Column: - return Column.invoke_anonymous_function(seed, "RAND") + return Column.invoke_expression_over_column(seed, expression.Rand) def randn(seed: t.Optional[ColumnOrLiteral] = None) -> Column: - return Column.invoke_anonymous_function(seed, "RANDN") + return Column.invoke_expression_over_column(seed, expression.Randn) def round(col: ColumnOrName, scale: t.Optional[int] = None) -> Column: diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 1b06cbf..7a573e7 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -727,7 +727,8 @@ class BigQuery(Dialect): def eq_sql(self, expression: exp.EQ) -> str: # Operands of = cannot be NULL in BigQuery if isinstance(expression.left, exp.Null) or isinstance(expression.right, exp.Null): - return "NULL" + if not isinstance(expression.parent, exp.Update): + return "NULL" return self.binary(expression, "=") diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index 7a3f897..870f402 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -105,6 +105,7 @@ class ClickHouse(Dialect): ), "MAP": parse_var_map, "MATCH": exp.RegexpLike.from_arg_list, + "RANDCANONICAL": exp.Rand.from_arg_list, "UNIQ": exp.ApproxDistinct.from_arg_list, "XOR": lambda args: exp.Xor(expressions=args), } @@ -142,9 +143,10 @@ class ClickHouse(Dialect): TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - { TokenType.ANY, - TokenType.SETTINGS, - TokenType.FORMAT, TokenType.ARRAY, + TokenType.FINAL, + TokenType.FORMAT, + TokenType.SETTINGS, } LOG_DEFAULTS_TO_LN = True @@ -397,6 +399,7 @@ class ClickHouse(Dialect): exp.Pivot: no_pivot_sql, exp.Quantile: _quantile_sql, exp.RegexpLike: lambda self, e: f"match({self.format_args(e.this, e.expression)})", + exp.Rand: rename_func("randCanonical"), exp.StartsWith: rename_func("startsWith"), exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})", exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)), diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index 41afad8..cd9d529 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -352,6 +352,7 @@ class DuckDB(Dialect): ), exp.RegexpLike: rename_func("REGEXP_MATCHES"), exp.RegexpSplit: rename_func("STR_SPLIT_REGEX"), + exp.Rand: rename_func("RANDOM"), exp.SafeDivide: no_safe_divide_sql, exp.Split: rename_func("STR_SPLIT"), exp.SortArray: _sort_array_sql, diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index bf65edf..e274877 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -445,6 +445,7 @@ class Postgres(Dialect): ), exp.Pivot: no_pivot_sql, exp.Pow: lambda self, e: self.binary(e, "^"), + exp.Rand: rename_func("RANDOM"), exp.RegexpLike: lambda self, e: self.binary(e, "~"), exp.RegexpILike: lambda self, e: self.binary(e, "~*"), exp.Select: transforms.preprocess( diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index f09a990..8925181 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -558,6 +558,7 @@ class Snowflake(Dialect): [transforms.add_within_group_for_percentiles] ), exp.RegexpILike: _regexpilike_sql, + exp.Rand: rename_func("RANDOM"), exp.Select: transforms.preprocess( [ transforms.eliminate_distinct_on, diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index e55a3b8..9bac51c 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -127,6 +127,7 @@ class SQLite(Dialect): exp.LogicalOr: rename_func("MAX"), exp.LogicalAnd: rename_func("MIN"), exp.Pivot: no_pivot_sql, + exp.Rand: rename_func("RANDOM"), exp.Select: transforms.preprocess( [ transforms.eliminate_distinct_on, diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 8246769..ea2255d 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -4988,6 +4988,15 @@ class ApproxQuantile(Quantile): arg_types = {"this": True, "quantile": True, "accuracy": False, "weight": False} +class Rand(Func): + _sql_names = ["RAND", "RANDOM"] + arg_types = {"this": False} + + +class Randn(Func): + arg_types = {"this": False} + + class RangeN(Func): arg_types = {"this": True, "expressions": True, "each": False} @@ -6475,7 +6484,7 @@ def table_name(table: Table | str, dialect: DialectType = None, identify: bool = raise ValueError(f"Cannot parse {table}") return ".".join( - part.sql(dialect=dialect, identify=True) + part.sql(dialect=dialect, identify=True, copy=False) if identify or not SAFE_IDENTIFIER_RE.match(part.name) else part.name for part in table.parts diff --git a/sqlglot/generator.py b/sqlglot/generator.py index c571e8f..b0e83d2 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import re import typing as t from collections import defaultdict from functools import reduce @@ -17,6 +18,8 @@ if t.TYPE_CHECKING: logger = logging.getLogger("sqlglot") +ESCAPED_UNICODE_RE = re.compile(r"\\(\d+)") + class Generator: """ @@ -917,11 +920,19 @@ class Generator: def unicodestring_sql(self, expression: exp.UnicodeString) -> str: this = self.sql(expression, "this") + escape = expression.args.get("escape") + if self.dialect.UNICODE_START: - escape = self.sql(expression, "escape") - escape = f" UESCAPE {escape}" if escape else "" + escape = f" UESCAPE {self.sql(escape)}" if escape else "" return f"{self.dialect.UNICODE_START}{this}{self.dialect.UNICODE_END}{escape}" - return this + + if escape: + pattern = re.compile(rf"{escape.name}(\d+)") + else: + pattern = ESCAPED_UNICODE_RE + + this = pattern.sub(r"\\u\1", this) + return f"{self.dialect.QUOTE_START}{this}{self.dialect.QUOTE_END}" def rawstring_sql(self, expression: exp.RawString) -> str: string = self.escape_str(expression.this.replace("\\", "\\\\")) diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index 6ae08d0..f53023c 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -49,32 +49,32 @@ def simplify( dialect = Dialect.get_or_raise(dialect) - # group by expressions cannot be simplified, for example - # select x + 1 + 1 FROM y GROUP BY x + 1 + 1 - # the projection must exactly match the group by key - for group in expression.find_all(exp.Group): - select = group.parent - assert select - groups = set(group.expressions) - group.meta[FINAL] = True - - for e in select.expressions: - for node, *_ in e.walk(): - if node in groups: - e.meta[FINAL] = True - break - - having = select.args.get("having") - if having: - for node, *_ in having.walk(): - if node in groups: - having.meta[FINAL] = True - break - def _simplify(expression, root=True): if expression.meta.get(FINAL): return expression + # group by expressions cannot be simplified, for example + # select x + 1 + 1 FROM y GROUP BY x + 1 + 1 + # the projection must exactly match the group by key + group = expression.args.get("group") + + if group and hasattr(expression, "selects"): + groups = set(group.expressions) + group.meta[FINAL] = True + + for e in expression.selects: + for node, *_ in e.walk(): + if node in groups: + e.meta[FINAL] = True + break + + having = expression.args.get("having") + if having: + for node, *_ in having.walk(): + if node in groups: + having.meta[FINAL] = True + break + # Pre-order transformations node = expression node = rewrite_between(node) @@ -266,6 +266,8 @@ INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { exp.GTE: exp.LTE, } +NONDETERMINISTIC = (exp.Rand, exp.Randn) + def _simplify_comparison(expression, left, right, or_=False): if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS): @@ -276,7 +278,7 @@ def _simplify_comparison(expression, left, right, or_=False): rargs = {rl, rr} matching = largs & rargs - columns = {m for m in matching if isinstance(m, exp.Column)} + columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)} if matching and columns: try: @@ -292,7 +294,12 @@ def _simplify_comparison(expression, left, right, or_=False): l = l.name r = r.name else: - return None + l = extract_date(l) + if not l: + return None + r = extract_date(r) + if not r: + return None for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))): if isinstance(a, LT_LTE) and isinstance(b, LT_LTE): diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 3d01a84..311c43d 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -305,6 +305,7 @@ class Parser(metaclass=_Parser): TokenType.FALSE, TokenType.FIRST, TokenType.FILTER, + TokenType.FINAL, TokenType.FORMAT, TokenType.FULL, TokenType.IS, diff --git a/sqlglotrs/src/lib.rs b/sqlglotrs/src/lib.rs index c962887..43e90dc 100644 --- a/sqlglotrs/src/lib.rs +++ b/sqlglotrs/src/lib.rs @@ -71,7 +71,18 @@ impl Token { impl Token { #[pyo3(name = "__repr__")] fn python_repr(&self) -> PyResult<String> { - Ok(format!("{:?}", self)) + Python::with_gil(|py| { + Ok(format!( + "<Token token_type: {}, text: {}, line: {}, col: {}, start: {}, end: {}, comments: {}>", + self.token_type_py.as_ref(py).repr()?, + self.text.as_ref(py).repr()?, + self.line, + self.col, + self.start, + self.end, + self.comments.as_ref(py).repr()?, + )) + }) } } diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 420803a..f263581 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -206,6 +206,7 @@ class TestBigQuery(Validator): "duckdb": "SELECT * FROM t WHERE EXISTS(SELECT * FROM UNNEST(nums) AS _t(x) WHERE x > 1)", }, ) + self.validate_identity("UPDATE x SET y = NULL") self.validate_all( "NULL", read={ diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py index 1f528b6..bb993b5 100644 --- a/tests/dialects/test_clickhouse.py +++ b/tests/dialects/test_clickhouse.py @@ -47,6 +47,7 @@ class TestClickhouse(Validator): self.validate_identity("SELECT INTERVAL t.days day") self.validate_identity("SELECT match('abc', '([a-z]+)')") self.validate_identity("dictGet(x, 'y')") + self.validate_identity("SELECT * FROM final") self.validate_identity("SELECT * FROM x FINAL") self.validate_identity("SELECT * FROM x AS y FINAL") self.validate_identity("'a' IN mapKeys(map('a', 1, 'b', 2))") diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py index 49afc62..a49d067 100644 --- a/tests/dialects/test_dialect.py +++ b/tests/dialects/test_dialect.py @@ -2056,3 +2056,40 @@ SELECT self.assertEqual(expression.sql(dialect="mysql"), expected_sql) self.assertEqual(expression.sql(dialect="tsql"), expected_sql) + + def test_random(self): + self.validate_all( + "RAND()", + write={ + "bigquery": "RAND()", + "clickhouse": "randCanonical()", + "databricks": "RAND()", + "doris": "RAND()", + "drill": "RAND()", + "duckdb": "RANDOM()", + "hive": "RAND()", + "mysql": "RAND()", + "oracle": "RAND()", + "postgres": "RANDOM()", + "presto": "RAND()", + "spark": "RAND()", + "sqlite": "RANDOM()", + "tsql": "RAND()", + }, + read={ + "bigquery": "RAND()", + "clickhouse": "randCanonical()", + "databricks": "RAND()", + "doris": "RAND()", + "drill": "RAND()", + "duckdb": "RANDOM()", + "hive": "RAND()", + "mysql": "RAND()", + "oracle": "RAND()", + "postgres": "RANDOM()", + "presto": "RAND()", + "spark": "RAND()", + "sqlite": "RANDOM()", + "tsql": "RAND()", + }, + ) diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py index 97a387c..8b5080c 100644 --- a/tests/dialects/test_presto.py +++ b/tests/dialects/test_presto.py @@ -546,13 +546,21 @@ class TestPresto(Validator): def test_unicode_string(self): for prefix in ("u&", "U&"): - self.validate_identity( + self.validate_all( f"{prefix}'Hello winter \\2603 !'", - "U&'Hello winter \\2603 !'", + write={ + "presto": "U&'Hello winter \\2603 !'", + "snowflake": "'Hello winter \\u2603 !'", + "spark": "'Hello winter \\u2603 !'", + }, ) - self.validate_identity( + self.validate_all( f"{prefix}'Hello winter #2603 !' UESCAPE '#'", - "U&'Hello winter #2603 !' UESCAPE '#'", + write={ + "presto": "U&'Hello winter #2603 !' UESCAPE '#'", + "snowflake": "'Hello winter \\u2603 !'", + "spark": "'Hello winter \\u2603 !'", + }, ) def test_presto(self): diff --git a/tests/fixtures/optimizer/simplify.sql b/tests/fixtures/optimizer/simplify.sql index fbf5d2c..d3b03fb 100644 --- a/tests/fixtures/optimizer/simplify.sql +++ b/tests/fixtures/optimizer/simplify.sql @@ -696,6 +696,18 @@ x <> 1; NOT 1 <> x; x = 1; +x > CAST('2024-01-01' AS DATE) OR x > CAST('2023-12-31' AS DATE); +x > CAST('2023-12-31' AS DATE); + +CAST(x AS DATE) > CAST('2024-01-01' AS DATE) OR CAST(x AS DATE) > CAST('2023-12-31' AS DATE); +CAST(x AS DATE) > CAST('2023-12-31' AS DATE); + +FUN() > 0 OR FUN() > 1; +FUN() > 0; + +RAND() > 0 OR RAND() > 1; +RAND() > 0 OR RAND() > 1; + -------------------------------------- -- COALESCE -------------------------------------- @@ -835,7 +847,7 @@ DATE_TRUNC('quarter', x) = CAST('2021-01-02' AS DATE); DATE_TRUNC('quarter', x) = CAST('2021-01-02' AS DATE); DATE_TRUNC('year', x) <> CAST('2021-01-01' AS DATE); -x < CAST('2021-01-01' AS DATE) AND x >= CAST('2022-01-01' AS DATE); +FALSE; -- Always true, except for nulls DATE_TRUNC('year', x) <> CAST('2021-01-02' AS DATE); |