diff options
Diffstat (limited to 'sqlglot/generator.py')
-rw-r--r-- | sqlglot/generator.py | 92 |
1 files changed, 70 insertions, 22 deletions
diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 2b4c575..0c1578a 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -361,10 +361,11 @@ class Generator: column = self.sql(expression, "this") kind = self.sql(expression, "kind") constraints = self.expressions(expression, key="constraints", sep=" ", flat=True) + exists = "IF NOT EXISTS " if expression.args.get("exists") else "" if not constraints: - return f"{column} {kind}" - return f"{column} {kind} {constraints}" + return f"{exists}{column} {kind}" + return f"{exists}{column} {kind} {constraints}" def columnconstraint_sql(self, expression: exp.ColumnConstraint) -> str: this = self.sql(expression, "this") @@ -549,6 +550,9 @@ class Generator: text = f"{self.identifier_start}{text}{self.identifier_end}" return text + def national_sql(self, expression: exp.National) -> str: + return f"N{self.sql(expression, 'this')}" + def partition_sql(self, expression: exp.Partition) -> str: keys = csv( *[ @@ -633,6 +637,9 @@ class Generator: def introducer_sql(self, expression: exp.Introducer) -> str: return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}" + def pseudotype_sql(self, expression: exp.PseudoType) -> str: + return expression.name.upper() + def rowformatdelimitedproperty_sql(self, expression: exp.RowFormatDelimitedProperty) -> str: fields = expression.args.get("fields") fields = f" FIELDS TERMINATED BY {fields}" if fields else "" @@ -793,19 +800,17 @@ class Generator: if isinstance(expression.this, exp.Subquery): return f"LATERAL {this}" - alias = expression.args["alias"] - table = alias.name - columns = self.expressions(alias, key="columns", flat=True) - if expression.args.get("view"): - table = f" {table}" if table else table + alias = expression.args["alias"] + columns = self.expressions(alias, key="columns", flat=True) + table = f" {alias.name}" if alias.name else "" columns = f" AS {columns}" if columns else "" op_sql = self.seg(f"LATERAL VIEW{' OUTER' if expression.args.get('outer') else ''}") return f"{op_sql}{self.sep()}{this}{table}{columns}" - table = f" AS {table}" if table else table - columns = f"({columns})" if columns else "" - return f"LATERAL {this}{table}{columns}" + alias = self.sql(expression, "alias") + alias = f" AS {alias}" if alias else "" + return f"LATERAL {this}{alias}" def limit_sql(self, expression: exp.Limit) -> str: this = self.sql(expression, "this") @@ -891,13 +896,15 @@ class Generator: def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str: return csv( *sqls, - *[self.sql(sql) for sql in expression.args.get("joins", [])], - *[self.sql(sql) for sql in expression.args.get("laterals", [])], + *[self.sql(sql) for sql in expression.args.get("joins") or []], + *[self.sql(sql) for sql in expression.args.get("laterals") or []], self.sql(expression, "where"), self.sql(expression, "group"), self.sql(expression, "having"), self.sql(expression, "qualify"), - self.sql(expression, "window"), + self.seg("WINDOW ") + self.expressions(expression, "windows", flat=True) + if expression.args.get("windows") + else "", self.sql(expression, "distribute"), self.sql(expression, "sort"), self.sql(expression, "cluster"), @@ -1008,11 +1015,7 @@ class Generator: spec_sql = " " + self.window_spec_sql(spec) if spec else "" alias = self.sql(expression, "alias") - - if expression.arg_key == "window": - this = this = f"{self.seg('WINDOW')} {this} AS" - else: - this = f"{this} OVER" + this = f"{this} {'AS' if expression.arg_key == 'windows' else 'OVER'}" if not partition and not order and not spec and alias: return f"{this} {alias}" @@ -1141,9 +1144,11 @@ class Generator: return f"(SELECT {self.sql(unnest)})" def interval_sql(self, expression: exp.Interval) -> str: + this = self.sql(expression, "this") + this = f" {this}" if this else "" unit = self.sql(expression, "unit") unit = f" {unit}" if unit else "" - return f"INTERVAL {self.sql(expression, 'this')}{unit}" + return f"INTERVAL{this}{unit}" def reference_sql(self, expression: exp.Reference) -> str: this = self.sql(expression, "this") @@ -1245,6 +1250,43 @@ class Generator: savepoint = f" TO {savepoint}" if savepoint else "" return f"ROLLBACK{savepoint}" + def altercolumn_sql(self, expression: exp.AlterColumn) -> str: + this = self.sql(expression, "this") + + dtype = self.sql(expression, "dtype") + if dtype: + collate = self.sql(expression, "collate") + collate = f" COLLATE {collate}" if collate else "" + using = self.sql(expression, "using") + using = f" USING {using}" if using else "" + return f"ALTER COLUMN {this} TYPE {dtype}{collate}{using}" + + default = self.sql(expression, "default") + if default: + return f"ALTER COLUMN {this} SET DEFAULT {default}" + + if not expression.args.get("drop"): + self.unsupported("Unsupported ALTER COLUMN syntax") + + return f"ALTER COLUMN {this} DROP DEFAULT" + + def altertable_sql(self, expression: exp.AlterTable) -> str: + actions = expression.args["actions"] + + if isinstance(actions[0], exp.ColumnDef): + actions = self.expressions(expression, "actions", prefix="ADD COLUMN ") + elif isinstance(actions[0], exp.Schema): + actions = self.expressions(expression, "actions", prefix="ADD COLUMNS ") + elif isinstance(actions[0], exp.Drop): + actions = self.expressions(expression, "actions") + elif isinstance(actions[0], exp.AlterColumn): + actions = self.sql(actions[0]) + else: + self.unsupported(f"Unsupported ALTER TABLE action {actions[0].__class__.__name__}") + + exists = " IF EXISTS" if expression.args.get("exists") else "" + return f"ALTER TABLE{exists} {self.sql(expression, 'this')} {actions}" + def distinct_sql(self, expression: exp.Distinct) -> str: this = self.expressions(expression, flat=True) this = f" {this}" if this else "" @@ -1327,6 +1369,9 @@ class Generator: def or_sql(self, expression: exp.Or) -> str: return self.connector_sql(expression, "OR") + def slice_sql(self, expression: exp.Slice) -> str: + return self.binary(expression, ":") + def sub_sql(self, expression: exp.Sub) -> str: return self.binary(expression, "-") @@ -1369,6 +1414,7 @@ class Generator: flat: bool = False, indent: bool = True, sep: str = ", ", + prefix: str = "", ) -> str: expressions = expression.args.get(key or "expressions") @@ -1391,11 +1437,13 @@ class Generator: if self.pretty: if self._leading_comma: - result_sqls.append(f"{sep if i > 0 else pad}{sql}{comments}") + result_sqls.append(f"{sep if i > 0 else pad}{prefix}{sql}{comments}") else: - result_sqls.append(f"{sql}{stripped_sep if i + 1 < num_sqls else ''}{comments}") + result_sqls.append( + f"{prefix}{sql}{stripped_sep if i + 1 < num_sqls else ''}{comments}" + ) else: - result_sqls.append(f"{sql}{comments}{sep if i + 1 < num_sqls else ''}") + result_sqls.append(f"{prefix}{sql}{comments}{sep if i + 1 < num_sqls else ''}") result_sql = "\n".join(result_sqls) if self.pretty else "".join(result_sqls) return self.indent(result_sql, skip_first=False) if indent else result_sql |