summaryrefslogtreecommitdiffstats
path: root/sqlglot/generator.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2023-11-01 05:12:42 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2023-11-01 05:12:42 +0000
commitc51a9844b869fd7cd69e5cc7658d34f61a865185 (patch)
tree55706c65ce7e19626aabf7ff4dde0e1a51b739db /sqlglot/generator.py
parentReleasing debian version 18.17.0-1. (diff)
downloadsqlglot-c51a9844b869fd7cd69e5cc7658d34f61a865185.tar.xz
sqlglot-c51a9844b869fd7cd69e5cc7658d34f61a865185.zip
Merging upstream version 19.0.1.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/generator.py')
-rw-r--r--sqlglot/generator.py125
1 files changed, 76 insertions, 49 deletions
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index 0d6778a..4916cf8 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -230,6 +230,12 @@ class Generator:
# Whether or not data types support additional specifiers like e.g. CHAR or BYTE (oracle)
DATA_TYPE_SPECIFIERS_ALLOWED = False
+ # Whether or not nested CTEs (e.g. defined inside of subqueries) are allowed
+ SUPPORTS_NESTED_CTES = True
+
+ # Whether or not the "RECURSIVE" keyword is required when defining recursive CTEs
+ CTE_RECURSIVE_KEYWORD_REQUIRED = True
+
TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
@@ -304,6 +310,7 @@ class Generator:
exp.Order: exp.Properties.Location.POST_SCHEMA,
exp.OutputModelProperty: exp.Properties.Location.POST_SCHEMA,
exp.PartitionedByProperty: exp.Properties.Location.POST_WITH,
+ exp.PartitionedOfProperty: exp.Properties.Location.POST_SCHEMA,
exp.PrimaryKey: exp.Properties.Location.POST_SCHEMA,
exp.Property: exp.Properties.Location.POST_WITH,
exp.RemoteWithConnectionModelProperty: exp.Properties.Location.POST_SCHEMA,
@@ -407,7 +414,6 @@ class Generator:
"unsupported_messages",
"_escaped_quote_end",
"_escaped_identifier_end",
- "_cache",
)
def __init__(
@@ -447,30 +453,38 @@ class Generator:
self._escaped_identifier_end: str = (
self.TOKENIZER_CLASS.IDENTIFIER_ESCAPES[0] + self.IDENTIFIER_END
)
- self._cache: t.Optional[t.Dict[int, str]] = None
- def generate(
- self,
- expression: t.Optional[exp.Expression],
- cache: t.Optional[t.Dict[int, str]] = None,
- ) -> str:
+ def generate(self, expression: exp.Expression, copy: bool = True) -> str:
"""
Generates the SQL string corresponding to the given syntax tree.
Args:
expression: The syntax tree.
- cache: An optional sql string cache. This leverages the hash of an Expression
- which can be slow to compute, so only use it if you set _hash on each node.
+ copy: Whether or not to copy the expression. The generator performs mutations so
+ it is safer to copy.
Returns:
The SQL string corresponding to `expression`.
"""
- if cache is not None:
- self._cache = cache
+ if copy:
+ expression = expression.copy()
+
+ # Some dialects only support CTEs at the top level expression, so we need to bubble up nested
+ # CTEs to that level in order to produce a syntactically valid expression. This transformation
+ # happens here to minimize code duplication, since many expressions support CTEs.
+ if (
+ not self.SUPPORTS_NESTED_CTES
+ and isinstance(expression, exp.Expression)
+ and not expression.parent
+ and "with" in expression.arg_types
+ and any(node.parent is not expression for node in expression.find_all(exp.With))
+ ):
+ from sqlglot.transforms import move_ctes_to_top_level
+
+ expression = move_ctes_to_top_level(expression)
self.unsupported_messages = []
sql = self.sql(expression).strip()
- self._cache = None
if self.unsupported_level == ErrorLevel.IGNORE:
return sql
@@ -595,12 +609,6 @@ class Generator:
return self.sql(value)
return ""
- if self._cache is not None:
- expression_id = hash(expression)
-
- if expression_id in self._cache:
- return self._cache[expression_id]
-
transform = self.TRANSFORMS.get(expression.__class__)
if callable(transform):
@@ -621,11 +629,7 @@ class Generator:
else:
raise ValueError(f"Expected an Expression. Received {type(expression)}: {expression}")
- sql = self.maybe_comment(sql, expression) if self.comments and comment else sql
-
- if self._cache is not None:
- self._cache[expression_id] = sql
- return sql
+ return self.maybe_comment(sql, expression) if self.comments and comment else sql
def uncache_sql(self, expression: exp.Uncache) -> str:
table = self.sql(expression, "this")
@@ -879,7 +883,11 @@ class Generator:
def with_sql(self, expression: exp.With) -> str:
sql = self.expressions(expression, flat=True)
- recursive = "RECURSIVE " if expression.args.get("recursive") else ""
+ recursive = (
+ "RECURSIVE "
+ if self.CTE_RECURSIVE_KEYWORD_REQUIRED and expression.args.get("recursive")
+ else ""
+ )
return f"WITH {recursive}{sql}"
@@ -1022,7 +1030,7 @@ class Generator:
where = self.sql(expression, "expression").strip()
return f"{this} FILTER({where})"
- agg = expression.this.copy()
+ agg = expression.this
agg_arg = agg.this
cond = expression.expression.this
agg_arg.replace(exp.If(this=cond.copy(), true=agg_arg.copy()))
@@ -1088,9 +1096,9 @@ class Generator:
for p in expression.expressions:
p_loc = self.PROPERTIES_LOCATION[p.__class__]
if p_loc == exp.Properties.Location.POST_WITH:
- with_properties.append(p.copy())
+ with_properties.append(p)
elif p_loc == exp.Properties.Location.POST_SCHEMA:
- root_properties.append(p.copy())
+ root_properties.append(p)
return self.root_properties(
exp.Properties(expressions=root_properties)
@@ -1124,7 +1132,7 @@ class Generator:
for p in properties.expressions:
p_loc = self.PROPERTIES_LOCATION[p.__class__]
if p_loc != exp.Properties.Location.UNSUPPORTED:
- properties_locs[p_loc].append(p.copy())
+ properties_locs[p_loc].append(p)
else:
self.unsupported(f"Unsupported property {p.key}")
@@ -1238,6 +1246,29 @@ class Generator:
for_ = " FOR NONE"
return f"WITH{no}{concurrent} ISOLATED LOADING{for_}"
+ def partitionboundspec_sql(self, expression: exp.PartitionBoundSpec) -> str:
+ if isinstance(expression.this, list):
+ return f"IN ({self.expressions(expression, key='this', flat=True)})"
+ if expression.this:
+ modulus = self.sql(expression, "this")
+ remainder = self.sql(expression, "expression")
+ return f"WITH (MODULUS {modulus}, REMAINDER {remainder})"
+
+ from_expressions = self.expressions(expression, key="from_expressions", flat=True)
+ to_expressions = self.expressions(expression, key="to_expressions", flat=True)
+ return f"FROM ({from_expressions}) TO ({to_expressions})"
+
+ def partitionedofproperty_sql(self, expression: exp.PartitionedOfProperty) -> str:
+ this = self.sql(expression, "this")
+
+ for_values_or_default = expression.expression
+ if isinstance(for_values_or_default, exp.PartitionBoundSpec):
+ for_values_or_default = f" FOR VALUES {self.sql(for_values_or_default)}"
+ else:
+ for_values_or_default = " DEFAULT"
+
+ return f"PARTITION OF {this}{for_values_or_default}"
+
def lockingproperty_sql(self, expression: exp.LockingProperty) -> str:
kind = expression.args.get("kind")
this = f" {self.sql(expression, 'this')}" if expression.this else ""
@@ -1385,7 +1416,12 @@ class Generator:
index = self.sql(expression, "index")
index = f" AT {index}" if index else ""
- return f"{table}{version}{file_format}{alias}{index}{hints}{pivots}{joins}{laterals}"
+ ordinality = expression.args.get("ordinality") or ""
+ if ordinality:
+ ordinality = f" WITH ORDINALITY{alias}"
+ alias = ""
+
+ return f"{table}{version}{file_format}{alias}{index}{hints}{pivots}{joins}{laterals}{ordinality}"
def tablesample_sql(
self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
@@ -1489,7 +1525,6 @@ class Generator:
return f"{values} AS {alias}" if alias else values
# Converts `VALUES...` expression into a series of select unions.
- expression = expression.copy()
alias_node = expression.args.get("alias")
column_names = alias_node and alias_node.columns
@@ -1972,8 +2007,7 @@ class Generator:
if self.UNNEST_WITH_ORDINALITY:
if alias and isinstance(offset, exp.Expression):
- alias = alias.copy()
- alias.append("columns", offset.copy())
+ alias.append("columns", offset)
if alias and self.UNNEST_COLUMN_ONLY:
columns = alias.columns
@@ -2138,7 +2172,6 @@ class Generator:
return f"PRIMARY KEY ({expressions}){options}"
def if_sql(self, expression: exp.If) -> str:
- expression = expression.copy()
return self.case_sql(exp.Case(ifs=[expression], default=expression.args.get("false")))
def matchagainst_sql(self, expression: exp.MatchAgainst) -> str:
@@ -2367,7 +2400,9 @@ class Generator:
def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
format_sql = self.sql(expression, "format")
format_sql = f" FORMAT {format_sql}" if format_sql else ""
- return f"{safe_prefix or ''}CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')}{format_sql})"
+ to_sql = self.sql(expression, "to")
+ to_sql = f" {to_sql}" if to_sql else ""
+ return f"{safe_prefix or ''}CAST({self.sql(expression, 'this')} AS{to_sql}{format_sql})"
def currentdate_sql(self, expression: exp.CurrentDate) -> str:
zone = self.sql(expression, "this")
@@ -2510,7 +2545,7 @@ class Generator:
def intdiv_sql(self, expression: exp.IntDiv) -> str:
return self.sql(
exp.Cast(
- this=exp.Div(this=expression.this.copy(), expression=expression.expression.copy()),
+ this=exp.Div(this=expression.this, expression=expression.expression),
to=exp.DataType(this=exp.DataType.Type.INT),
)
)
@@ -2779,7 +2814,6 @@ class Generator:
hints = table.args.get("hints")
if hints and table.alias and isinstance(hints[0], exp.WithTableHint):
# T-SQL syntax is MERGE ... <target_table> [WITH (<merge_hint>)] [[AS] table_alias]
- table = table.copy()
table_alias = f" AS {self.sql(table.args['alias'].pop())}"
this = self.sql(table)
@@ -2787,7 +2821,9 @@ class Generator:
on = f"ON {self.sql(expression, 'on')}"
expressions = self.expressions(expression, sep=" ")
- return f"MERGE INTO {this}{table_alias} {using} {on} {expressions}"
+ return self.prepend_ctes(
+ expression, f"MERGE INTO {this}{table_alias} {using} {on} {expressions}"
+ )
def tochar_sql(self, expression: exp.ToChar) -> str:
if expression.args.get("format"):
@@ -2896,12 +2932,12 @@ class Generator:
case = exp.Case().when(
expression.this.is_(exp.null()).not_(copy=False),
- expression.args["true"].copy(),
+ expression.args["true"],
copy=False,
)
else_cond = expression.args.get("false")
if else_cond:
- case.else_(else_cond.copy(), copy=False)
+ case.else_(else_cond, copy=False)
return self.sql(case)
@@ -2931,15 +2967,6 @@ class Generator:
if not isinstance(expression, exp.Literal):
from sqlglot.optimizer.simplify import simplify
- expression = simplify(expression.copy())
+ expression = simplify(expression)
return expression
-
-
-def cached_generator(
- cache: t.Optional[t.Dict[int, str]] = None
-) -> t.Callable[[exp.Expression], str]:
- """Returns a cached generator."""
- cache = {} if cache is None else cache
- generator = Generator(normalize=True, identify="safe")
- return lambda e: generator.generate(e, cache)