diff options
Diffstat (limited to 'sqlglot/generator.py')
-rw-r--r-- | sqlglot/generator.py | 45 |
1 files changed, 14 insertions, 31 deletions
diff --git a/sqlglot/generator.py b/sqlglot/generator.py index ed0a681..95db795 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -2,6 +2,7 @@ from __future__ import annotations import logging import typing as t +from collections import defaultdict from sqlglot import exp from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages @@ -676,15 +677,13 @@ class Generator: this = f" {this}" if this else "" return f"UNIQUE{this}" - def createable_sql( - self, expression: exp.Create, locations: dict[exp.Properties.Location, list[exp.Property]] - ) -> str: + def createable_sql(self, expression: exp.Create, locations: t.DefaultDict) -> str: return self.sql(expression, "this") def create_sql(self, expression: exp.Create) -> str: kind = self.sql(expression, "kind").upper() properties = expression.args.get("properties") - properties_locs = self.locate_properties(properties) if properties else {} + properties_locs = self.locate_properties(properties) if properties else defaultdict() this = self.createable_sql(expression, properties_locs) @@ -970,9 +969,9 @@ class Generator: for p in expression.expressions: p_loc = self.PROPERTIES_LOCATION[p.__class__] if p_loc == exp.Properties.Location.POST_WITH: - with_properties.append(p) + with_properties.append(p.copy()) elif p_loc == exp.Properties.Location.POST_SCHEMA: - root_properties.append(p) + root_properties.append(p.copy()) return self.root_properties( exp.Properties(expressions=root_properties) @@ -1001,30 +1000,13 @@ class Generator: def with_properties(self, properties: exp.Properties) -> str: return self.properties(properties, prefix=self.seg("WITH")) - def locate_properties( - self, properties: exp.Properties - ) -> t.Dict[exp.Properties.Location, list[exp.Property]]: - properties_locs: t.Dict[exp.Properties.Location, list[exp.Property]] = { - key: [] for key in exp.Properties.Location - } - + def locate_properties(self, properties: exp.Properties) -> t.DefaultDict: + properties_locs = defaultdict(list) for p in properties.expressions: p_loc = self.PROPERTIES_LOCATION[p.__class__] - if p_loc == exp.Properties.Location.POST_NAME: - properties_locs[exp.Properties.Location.POST_NAME].append(p) - elif p_loc == exp.Properties.Location.POST_INDEX: - properties_locs[exp.Properties.Location.POST_INDEX].append(p) - elif p_loc == exp.Properties.Location.POST_SCHEMA: - properties_locs[exp.Properties.Location.POST_SCHEMA].append(p) - elif p_loc == exp.Properties.Location.POST_WITH: - properties_locs[exp.Properties.Location.POST_WITH].append(p) - elif p_loc == exp.Properties.Location.POST_CREATE: - properties_locs[exp.Properties.Location.POST_CREATE].append(p) - elif p_loc == exp.Properties.Location.POST_ALIAS: - properties_locs[exp.Properties.Location.POST_ALIAS].append(p) - elif p_loc == exp.Properties.Location.POST_EXPRESSION: - properties_locs[exp.Properties.Location.POST_EXPRESSION].append(p) - elif p_loc == exp.Properties.Location.UNSUPPORTED: + if p_loc != exp.Properties.Location.UNSUPPORTED: + properties_locs[p_loc].append(p.copy()) + else: self.unsupported(f"Unsupported property {p.key}") return properties_locs @@ -1646,9 +1628,9 @@ class Generator: with_offset_limit_modifiers = not isinstance(limit, exp.Limit) or not self.LIMIT_IS_TOP if self.LIMIT_FETCH == "LIMIT" and isinstance(limit, exp.Fetch): - limit = exp.Limit(expression=limit.args.get("count")) + limit = exp.Limit(expression=exp.maybe_copy(limit.args.get("count"))) elif self.LIMIT_FETCH == "FETCH" and isinstance(limit, exp.Limit): - limit = exp.Fetch(direction="FIRST", count=limit.expression) + limit = exp.Fetch(direction="FIRST", count=exp.maybe_copy(limit.expression)) fetch = isinstance(limit, exp.Fetch) @@ -1955,6 +1937,7 @@ class Generator: return f"PRIMARY KEY ({expressions}){options}" def if_sql(self, expression: exp.If) -> str: + expression = expression.copy() return self.case_sql(exp.Case(ifs=[expression], default=expression.args.get("false"))) def matchagainst_sql(self, expression: exp.MatchAgainst) -> str: @@ -2261,7 +2244,7 @@ class Generator: def intdiv_sql(self, expression: exp.IntDiv) -> str: return self.sql( exp.Cast( - this=exp.Div(this=expression.this, expression=expression.expression), + this=exp.Div(this=expression.this.copy(), expression=expression.expression.copy()), to=exp.DataType(this=exp.DataType.Type.INT), ) ) |