summary refs log tree commit diff stats
path: root/sqlglot
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot')
-rw-r--r--sqlglot/dataframe/sql/functions.py4
-rw-r--r--sqlglot/dialects/bigquery.py3
-rw-r--r--sqlglot/dialects/clickhouse.py7
-rw-r--r--sqlglot/dialects/duckdb.py1
-rw-r--r--sqlglot/dialects/postgres.py1
-rw-r--r--sqlglot/dialects/snowflake.py1
-rw-r--r--sqlglot/dialects/sqlite.py1
-rw-r--r--sqlglot/expressions.py11
-rw-r--r--sqlglot/generator.py17
-rw-r--r--sqlglot/optimizer/simplify.py55
-rw-r--r--sqlglot/parser.py1
11 files changed, 69 insertions, 33 deletions
diff --git a/sqlglot/dataframe/sql/functions.py b/sqlglot/dataframe/sql/functions.py
index 6671c5b..6658287 100644
--- a/sqlglot/dataframe/sql/functions.py
+++ b/sqlglot/dataframe/sql/functions.py
@@ -418,11 +418,11 @@ def percentile_approx(
def rand(seed: t.Optional[ColumnOrLiteral] = None) -> Column:
- return Column.invoke_anonymous_function(seed, "RAND")
+ return Column.invoke_expression_over_column(seed, expression.Rand)
def randn(seed: t.Optional[ColumnOrLiteral] = None) -> Column:
- return Column.invoke_anonymous_function(seed, "RANDN")
+ return Column.invoke_expression_over_column(seed, expression.Randn)
def round(col: ColumnOrName, scale: t.Optional[int] = None) -> Column:
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index 1b06cbf..7a573e7 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -727,7 +727,8 @@ class BigQuery(Dialect):
def eq_sql(self, expression: exp.EQ) -> str:
# Operands of = cannot be NULL in BigQuery
if isinstance(expression.left, exp.Null) or isinstance(expression.right, exp.Null):
- return "NULL"
+ if not isinstance(expression.parent, exp.Update):
+ return "NULL"
return self.binary(expression, "=")
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index 7a3f897..870f402 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -105,6 +105,7 @@ class ClickHouse(Dialect):
),
"MAP": parse_var_map,
"MATCH": exp.RegexpLike.from_arg_list,
+ "RANDCANONICAL": exp.Rand.from_arg_list,
"UNIQ": exp.ApproxDistinct.from_arg_list,
"XOR": lambda args: exp.Xor(expressions=args),
}
@@ -142,9 +143,10 @@ class ClickHouse(Dialect):
TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - {
TokenType.ANY,
- TokenType.SETTINGS,
- TokenType.FORMAT,
TokenType.ARRAY,
+ TokenType.FINAL,
+ TokenType.FORMAT,
+ TokenType.SETTINGS,
}
LOG_DEFAULTS_TO_LN = True
@@ -397,6 +399,7 @@ class ClickHouse(Dialect):
exp.Pivot: no_pivot_sql,
exp.Quantile: _quantile_sql,
exp.RegexpLike: lambda self, e: f"match({self.format_args(e.this, e.expression)})",
+ exp.Rand: rename_func("randCanonical"),
exp.StartsWith: rename_func("startsWith"),
exp.StrPosition: lambda self, e: f"position({self.format_args(e.this, e.args.get('substr'), e.args.get('position'))})",
exp.VarMap: lambda self, e: _lower_func(var_map_sql(self, e)),
diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py
index 41afad8..cd9d529 100644
--- a/sqlglot/dialects/duckdb.py
+++ b/sqlglot/dialects/duckdb.py
@@ -352,6 +352,7 @@ class DuckDB(Dialect):
),
exp.RegexpLike: rename_func("REGEXP_MATCHES"),
exp.RegexpSplit: rename_func("STR_SPLIT_REGEX"),
+ exp.Rand: rename_func("RANDOM"),
exp.SafeDivide: no_safe_divide_sql,
exp.Split: rename_func("STR_SPLIT"),
exp.SortArray: _sort_array_sql,
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index bf65edf..e274877 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -445,6 +445,7 @@ class Postgres(Dialect):
),
exp.Pivot: no_pivot_sql,
exp.Pow: lambda self, e: self.binary(e, "^"),
+ exp.Rand: rename_func("RANDOM"),
exp.RegexpLike: lambda self, e: self.binary(e, "~"),
exp.RegexpILike: lambda self, e: self.binary(e, "~*"),
exp.Select: transforms.preprocess(
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index f09a990..8925181 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -558,6 +558,7 @@ class Snowflake(Dialect):
[transforms.add_within_group_for_percentiles]
),
exp.RegexpILike: _regexpilike_sql,
+ exp.Rand: rename_func("RANDOM"),
exp.Select: transforms.preprocess(
[
transforms.eliminate_distinct_on,
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index e55a3b8..9bac51c 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -127,6 +127,7 @@ class SQLite(Dialect):
exp.LogicalOr: rename_func("MAX"),
exp.LogicalAnd: rename_func("MIN"),
exp.Pivot: no_pivot_sql,
+ exp.Rand: rename_func("RANDOM"),
exp.Select: transforms.preprocess(
[
transforms.eliminate_distinct_on,
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 8246769..ea2255d 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -4988,6 +4988,15 @@ class ApproxQuantile(Quantile):
arg_types = {"this": True, "quantile": True, "accuracy": False, "weight": False}
+class Rand(Func):
+ _sql_names = ["RAND", "RANDOM"]
+ arg_types = {"this": False}
+
+
+class Randn(Func):
+ arg_types = {"this": False}
+
+
class RangeN(Func):
arg_types = {"this": True, "expressions": True, "each": False}
@@ -6475,7 +6484,7 @@ def table_name(table: Table | str, dialect: DialectType = None, identify: bool =
raise ValueError(f"Cannot parse {table}")
return ".".join(
- part.sql(dialect=dialect, identify=True)
+ part.sql(dialect=dialect, identify=True, copy=False)
if identify or not SAFE_IDENTIFIER_RE.match(part.name)
else part.name
for part in table.parts
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index c571e8f..b0e83d2 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -1,6 +1,7 @@
from __future__ import annotations
import logging
+import re
import typing as t
from collections import defaultdict
from functools import reduce
@@ -17,6 +18,8 @@ if t.TYPE_CHECKING:
logger = logging.getLogger("sqlglot")
+ESCAPED_UNICODE_RE = re.compile(r"\\(\d+)")
+
class Generator:
"""
@@ -917,11 +920,19 @@ class Generator:
def unicodestring_sql(self, expression: exp.UnicodeString) -> str:
this = self.sql(expression, "this")
+ escape = expression.args.get("escape")
+
if self.dialect.UNICODE_START:
- escape = self.sql(expression, "escape")
- escape = f" UESCAPE {escape}" if escape else ""
+ escape = f" UESCAPE {self.sql(escape)}" if escape else ""
return f"{self.dialect.UNICODE_START}{this}{self.dialect.UNICODE_END}{escape}"
- return this
+
+ if escape:
+ pattern = re.compile(rf"{escape.name}(\d+)")
+ else:
+ pattern = ESCAPED_UNICODE_RE
+
+ this = pattern.sub(r"\\u\1", this)
+ return f"{self.dialect.QUOTE_START}{this}{self.dialect.QUOTE_END}"
def rawstring_sql(self, expression: exp.RawString) -> str:
string = self.escape_str(expression.this.replace("\\", "\\\\"))
diff --git a/sqlglot/optimizer/simplify.py b/sqlglot/optimizer/simplify.py
index 6ae08d0..f53023c 100644
--- a/sqlglot/optimizer/simplify.py
+++ b/sqlglot/optimizer/simplify.py
@@ -49,32 +49,32 @@ def simplify(
dialect = Dialect.get_or_raise(dialect)
- # group by expressions cannot be simplified, for example
- # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
- # the projection must exactly match the group by key
- for group in expression.find_all(exp.Group):
- select = group.parent
- assert select
- groups = set(group.expressions)
- group.meta[FINAL] = True
-
- for e in select.expressions:
- for node, *_ in e.walk():
- if node in groups:
- e.meta[FINAL] = True
- break
-
- having = select.args.get("having")
- if having:
- for node, *_ in having.walk():
- if node in groups:
- having.meta[FINAL] = True
- break
-
def _simplify(expression, root=True):
if expression.meta.get(FINAL):
return expression
+ # group by expressions cannot be simplified, for example
+ # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
+ # the projection must exactly match the group by key
+ group = expression.args.get("group")
+
+ if group and hasattr(expression, "selects"):
+ groups = set(group.expressions)
+ group.meta[FINAL] = True
+
+ for e in expression.selects:
+ for node, *_ in e.walk():
+ if node in groups:
+ e.meta[FINAL] = True
+ break
+
+ having = expression.args.get("having")
+ if having:
+ for node, *_ in having.walk():
+ if node in groups:
+ having.meta[FINAL] = True
+ break
+
# Pre-order transformations
node = expression
node = rewrite_between(node)
@@ -266,6 +266,8 @@ INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
exp.GTE: exp.LTE,
}
+NONDETERMINISTIC = (exp.Rand, exp.Randn)
+
def _simplify_comparison(expression, left, right, or_=False):
if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
@@ -276,7 +278,7 @@ def _simplify_comparison(expression, left, right, or_=False):
rargs = {rl, rr}
matching = largs & rargs
- columns = {m for m in matching if isinstance(m, exp.Column)}
+ columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)}
if matching and columns:
try:
@@ -292,7 +294,12 @@ def _simplify_comparison(expression, left, right, or_=False):
l = l.name
r = r.name
else:
- return None
+ l = extract_date(l)
+ if not l:
+ return None
+ r = extract_date(r)
+ if not r:
+ return None
for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index 3d01a84..311c43d 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -305,6 +305,7 @@ class Parser(metaclass=_Parser):
TokenType.FALSE,
TokenType.FIRST,
TokenType.FILTER,
+ TokenType.FINAL,
TokenType.FORMAT,
TokenType.FULL,
TokenType.IS,