summaryrefslogtreecommitdiffstats
path: root/sqlglot/generator.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/generator.py')
-rw-r--r--sqlglot/generator.py45
1 files changed, 14 insertions, 31 deletions
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index ed0a681..95db795 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -2,6 +2,7 @@ from __future__ import annotations
import logging
import typing as t
+from collections import defaultdict
from sqlglot import exp
from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages
@@ -676,15 +677,13 @@ class Generator:
this = f" {this}" if this else ""
return f"UNIQUE{this}"
- def createable_sql(
- self, expression: exp.Create, locations: dict[exp.Properties.Location, list[exp.Property]]
- ) -> str:
+ def createable_sql(self, expression: exp.Create, locations: t.DefaultDict) -> str:
return self.sql(expression, "this")
def create_sql(self, expression: exp.Create) -> str:
kind = self.sql(expression, "kind").upper()
properties = expression.args.get("properties")
- properties_locs = self.locate_properties(properties) if properties else {}
+ properties_locs = self.locate_properties(properties) if properties else defaultdict()
this = self.createable_sql(expression, properties_locs)
@@ -970,9 +969,9 @@ class Generator:
for p in expression.expressions:
p_loc = self.PROPERTIES_LOCATION[p.__class__]
if p_loc == exp.Properties.Location.POST_WITH:
- with_properties.append(p)
+ with_properties.append(p.copy())
elif p_loc == exp.Properties.Location.POST_SCHEMA:
- root_properties.append(p)
+ root_properties.append(p.copy())
return self.root_properties(
exp.Properties(expressions=root_properties)
@@ -1001,30 +1000,13 @@ class Generator:
def with_properties(self, properties: exp.Properties) -> str:
return self.properties(properties, prefix=self.seg("WITH"))
- def locate_properties(
- self, properties: exp.Properties
- ) -> t.Dict[exp.Properties.Location, list[exp.Property]]:
- properties_locs: t.Dict[exp.Properties.Location, list[exp.Property]] = {
- key: [] for key in exp.Properties.Location
- }
-
+ def locate_properties(self, properties: exp.Properties) -> t.DefaultDict:
+ properties_locs = defaultdict(list)
for p in properties.expressions:
p_loc = self.PROPERTIES_LOCATION[p.__class__]
- if p_loc == exp.Properties.Location.POST_NAME:
- properties_locs[exp.Properties.Location.POST_NAME].append(p)
- elif p_loc == exp.Properties.Location.POST_INDEX:
- properties_locs[exp.Properties.Location.POST_INDEX].append(p)
- elif p_loc == exp.Properties.Location.POST_SCHEMA:
- properties_locs[exp.Properties.Location.POST_SCHEMA].append(p)
- elif p_loc == exp.Properties.Location.POST_WITH:
- properties_locs[exp.Properties.Location.POST_WITH].append(p)
- elif p_loc == exp.Properties.Location.POST_CREATE:
- properties_locs[exp.Properties.Location.POST_CREATE].append(p)
- elif p_loc == exp.Properties.Location.POST_ALIAS:
- properties_locs[exp.Properties.Location.POST_ALIAS].append(p)
- elif p_loc == exp.Properties.Location.POST_EXPRESSION:
- properties_locs[exp.Properties.Location.POST_EXPRESSION].append(p)
- elif p_loc == exp.Properties.Location.UNSUPPORTED:
+ if p_loc != exp.Properties.Location.UNSUPPORTED:
+ properties_locs[p_loc].append(p.copy())
+ else:
self.unsupported(f"Unsupported property {p.key}")
return properties_locs
@@ -1646,9 +1628,9 @@ class Generator:
with_offset_limit_modifiers = not isinstance(limit, exp.Limit) or not self.LIMIT_IS_TOP
if self.LIMIT_FETCH == "LIMIT" and isinstance(limit, exp.Fetch):
- limit = exp.Limit(expression=limit.args.get("count"))
+ limit = exp.Limit(expression=exp.maybe_copy(limit.args.get("count")))
elif self.LIMIT_FETCH == "FETCH" and isinstance(limit, exp.Limit):
- limit = exp.Fetch(direction="FIRST", count=limit.expression)
+ limit = exp.Fetch(direction="FIRST", count=exp.maybe_copy(limit.expression))
fetch = isinstance(limit, exp.Fetch)
@@ -1955,6 +1937,7 @@ class Generator:
return f"PRIMARY KEY ({expressions}){options}"
def if_sql(self, expression: exp.If) -> str:
+ expression = expression.copy()
return self.case_sql(exp.Case(ifs=[expression], default=expression.args.get("false")))
def matchagainst_sql(self, expression: exp.MatchAgainst) -> str:
@@ -2261,7 +2244,7 @@ class Generator:
def intdiv_sql(self, expression: exp.IntDiv) -> str:
return self.sql(
exp.Cast(
- this=exp.Div(this=expression.this, expression=expression.expression),
+ this=exp.Div(this=expression.this.copy(), expression=expression.expression.copy()),
to=exp.DataType(this=exp.DataType.Type.INT),
)
)