diff options
Diffstat (limited to 'sqlglot')
-rw-r--r-- | sqlglot/dataframe/sql/functions.py | 4 | ||||
-rw-r--r-- | sqlglot/dialects/bigquery.py | 3 | ||||
-rw-r--r-- | sqlglot/dialects/clickhouse.py | 7 | ||||
-rw-r--r-- | sqlglot/dialects/duckdb.py | 1 | ||||
-rw-r--r-- | sqlglot/dialects/postgres.py | 1 | ||||
-rw-r--r-- | sqlglot/dialects/snowflake.py | 1 | ||||
-rw-r--r-- | sqlglot/dialects/sqlite.py | 1 | ||||
-rw-r--r-- | sqlglot/expressions.py | 11 | ||||
-rw-r--r-- | sqlglot/generator.py | 17 | ||||
-rw-r--r-- | sqlglot/optimizer/simplify.py | 55 | ||||
-rw-r--r-- | sqlglot/parser.py | 1 |
11 files changed, 69 insertions, 33 deletions
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py index 6671c5b..6658287 100644 --- a/sqlglot/dataframe/sql/functions.py +++ b/sqlglot/dataframe/sql/functions.py @@ -418,11 +418,11 @@ def percentile_approx( def rand(seed: t.Optional[ColumnOrLiteral] = None) -> Column: - return Column.invoke_anonymous_function(seed, "RAND") + return Column.invoke_expression_over_column(seed, expression.Rand) def randn(seed: t.Optional[ColumnOrLiteral] = None) -> Column: - return Column.invoke_anonymous_function(seed, "RANDN") + return Column.invoke_expression_over_column(seed, expression.Randn) def round(col: ColumnOrName, scale: t.Optional[int] = None) -> Column: diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 1b06cbf..7a573e7 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -727,7 +727,8 @@ class BigQuery(Dialect): def eq_sql(self, expression: exp.EQ) -> str: # Operands of = cannot be NULL in BigQuery if isinstance(expression.left, exp.Null) or isinstance(expression.right, exp.Null): - return "NULL" + if not isinstance(expression.parent, exp.Update): + return "NULL" return self.binary(expression, "=") diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index 7a3f897..870f402 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -105,6 +105,7 @@ class ClickHouse(Dialect): ), "MAP": parse_var_map, "MATCH": exp.RegexpLike.from_arg_list, + "RANDCANONICAL": exp.Rand.from_arg_list, "UNIQ": exp.ApproxDistinct.from_arg_list, "XOR": lambda args: exp.Xor(expressions=args), } @@ -142,9 +143,10 @@ class ClickHouse(Dialect): TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - { TokenType.ANY, - TokenType.SETTINGS, - TokenType.FORMAT, TokenType.ARRAY, + TokenType.FINAL, + TokenType.FORMAT, + TokenType.SETTINGS, } LOG_DEFAULTS_TO_LN = True @@ -397,6 +399,7 @@ class ClickHouse(Dialect): exp.Pivot: no_pivot_sql, exp.Quantile: _quantile_sql, exp.RegexpLike: lambda self, e: f"match({self.format_args(e.this, e.expression)})", + exp.Rand: rename_func("randCanonical"), exp.StartsWith: rename_func("startsWith"), exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})", exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)), diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index 41afad8..cd9d529 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -352,6 +352,7 @@ class DuckDB(Dialect): ), exp.RegexpLike: rename_func("REGEXP_MATCHES"), exp.RegexpSplit: rename_func("STR_SPLIT_REGEX"), + exp.Rand: rename_func("RANDOM"), exp.SafeDivide: no_safe_divide_sql, exp.Split: rename_func("STR_SPLIT"), exp.SortArray: _sort_array_sql, diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index bf65edf..e274877 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -445,6 +445,7 @@ class Postgres(Dialect): ), exp.Pivot: no_pivot_sql, exp.Pow: lambda self, e: self.binary(e, "^"), + exp.Rand: rename_func("RANDOM"), exp.RegexpLike: lambda self, e: self.binary(e, "~"), exp.RegexpILike: lambda self, e: self.binary(e, "~*"), exp.Select: transforms.preprocess( diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index f09a990..8925181 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -558,6 +558,7 @@ class Snowflake(Dialect): [transforms.add_within_group_for_percentiles] ), exp.RegexpILike: _regexpilike_sql, + exp.Rand: rename_func("RANDOM"), exp.Select: transforms.preprocess( [ transforms.eliminate_distinct_on, diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py index e55a3b8..9bac51c 100644 --- a/sqlglot/dialects/sqlite.py +++ b/sqlglot/dialects/sqlite.py @@ -127,6 +127,7 @@ class SQLite(Dialect): exp.LogicalOr: rename_func("MAX"), exp.LogicalAnd: rename_func("MIN"), exp.Pivot: no_pivot_sql, + exp.Rand: rename_func("RANDOM"), exp.Select: transforms.preprocess( [ transforms.eliminate_distinct_on, diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 8246769..ea2255d 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -4988,6 +4988,15 @@ class ApproxQuantile(Quantile): arg_types = {"this": True, "quantile": True, "accuracy": False, "weight": False} +class Rand(Func): + _sql_names = ["RAND", "RANDOM"] + arg_types = {"this": False} + + +class Randn(Func): + arg_types = {"this": False} + + class RangeN(Func): arg_types = {"this": True, "expressions": True, "each": False} @@ -6475,7 +6484,7 @@ def table_name(table: Table | str, dialect: DialectType = None, identify: bool = raise ValueError(f"Cannot parse {table}") return ".".join( - part.sql(dialect=dialect, identify=True) + part.sql(dialect=dialect, identify=True, copy=False) if identify or not SAFE_IDENTIFIER_RE.match(part.name) else part.name for part in table.parts diff --git a/sqlglot/generator.py b/sqlglot/generator.py index c571e8f..b0e83d2 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import re import typing as t from collections import defaultdict from functools import reduce @@ -17,6 +18,8 @@ if t.TYPE_CHECKING: logger = logging.getLogger("sqlglot") +ESCAPED_UNICODE_RE = re.compile(r"\\(\d+)") + class Generator: """ @@ -917,11 +920,19 @@ class Generator: def unicodestring_sql(self, expression: exp.UnicodeString) -> str: this = self.sql(expression, "this") + escape = expression.args.get("escape") + if self.dialect.UNICODE_START: - escape = self.sql(expression, "escape") - escape = f" UESCAPE {escape}" if escape else "" + escape = f" UESCAPE {self.sql(escape)}" if escape else "" return f"{self.dialect.UNICODE_START}{this}{self.dialect.UNICODE_END}{escape}" - return this + + if escape: + pattern = re.compile(rf"{escape.name}(\d+)") + else: + pattern = ESCAPED_UNICODE_RE + + this = pattern.sub(r"\\u\1", this) + return f"{self.dialect.QUOTE_START}{this}{self.dialect.QUOTE_END}" def rawstring_sql(self, expression: exp.RawString) -> str: string = self.escape_str(expression.this.replace("\\", "\\\\")) diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py index 6ae08d0..f53023c 100644 --- a/sqlglot/optimizer/simplify.py +++ b/sqlglot/optimizer/simplify.py @@ -49,32 +49,32 @@ def simplify( dialect = Dialect.get_or_raise(dialect) - # group by expressions cannot be simplified, for example - # select x + 1 + 1 FROM y GROUP BY x + 1 + 1 - # the projection must exactly match the group by key - for group in expression.find_all(exp.Group): - select = group.parent - assert select - groups = set(group.expressions) - group.meta[FINAL] = True - - for e in select.expressions: - for node, *_ in e.walk(): - if node in groups: - e.meta[FINAL] = True - break - - having = select.args.get("having") - if having: - for node, *_ in having.walk(): - if node in groups: - having.meta[FINAL] = True - break - def _simplify(expression, root=True): if expression.meta.get(FINAL): return expression + # group by expressions cannot be simplified, for example + # select x + 1 + 1 FROM y GROUP BY x + 1 + 1 + # the projection must exactly match the group by key + group = expression.args.get("group") + + if group and hasattr(expression, "selects"): + groups = set(group.expressions) + group.meta[FINAL] = True + + for e in expression.selects: + for node, *_ in e.walk(): + if node in groups: + e.meta[FINAL] = True + break + + having = expression.args.get("having") + if having: + for node, *_ in having.walk(): + if node in groups: + having.meta[FINAL] = True + break + # Pre-order transformations node = expression node = rewrite_between(node) @@ -266,6 +266,8 @@ INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = { exp.GTE: exp.LTE, } +NONDETERMINISTIC = (exp.Rand, exp.Randn) + def _simplify_comparison(expression, left, right, or_=False): if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS): @@ -276,7 +278,7 @@ def _simplify_comparison(expression, left, right, or_=False): rargs = {rl, rr} matching = largs & rargs - columns = {m for m in matching if isinstance(m, exp.Column)} + columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)} if matching and columns: try: @@ -292,7 +294,12 @@ def _simplify_comparison(expression, left, right, or_=False): l = l.name r = r.name else: - return None + l = extract_date(l) + if not l: + return None + r = extract_date(r) + if not r: + return None for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))): if isinstance(a, LT_LTE) and isinstance(b, LT_LTE): diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 3d01a84..311c43d 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -305,6 +305,7 @@ class Parser(metaclass=_Parser): TokenType.FALSE, TokenType.FIRST, TokenType.FILTER, + TokenType.FINAL, TokenType.FORMAT, TokenType.FULL, TokenType.IS, |