diff options
Diffstat (limited to 'sqlglot/generator.py')
-rw-r--r-- | sqlglot/generator.py | 120 |
1 files changed, 69 insertions, 51 deletions
diff --git a/sqlglot/generator.py b/sqlglot/generator.py index ffb34eb..47774fc 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -1,19 +1,16 @@ from __future__ import annotations import logging -import re import typing as t from sqlglot import exp -from sqlglot.errors import ErrorLevel, UnsupportedError, concat_errors +from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages from sqlglot.helper import apply_index_offset, csv from sqlglot.time import format_time from sqlglot.tokens import TokenType logger = logging.getLogger("sqlglot") -NEWLINE_RE = re.compile("\r\n?|\n") - class Generator: """ @@ -58,11 +55,11 @@ class Generator: """ TRANSFORMS = { - exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'value')}", exp.DateAdd: lambda self, e: f"DATE_ADD({self.format_args(e.this, e.expression, e.args.get('unit'))})", exp.DateDiff: lambda self, e: f"DATEDIFF({self.format_args(e.this, e.expression)})", exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.format_args(e.this, e.expression, e.args.get('unit'))})", exp.VarMap: lambda self, e: f"MAP({self.format_args(e.args['keys'], e.args['values'])})", + exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'this')}", exp.LanguageProperty: lambda self, e: self.naked_property(e), exp.LocationProperty: lambda self, e: self.naked_property(e), exp.ReturnsProperty: lambda self, e: self.naked_property(e), @@ -97,16 +94,17 @@ class Generator: exp.DistStyleProperty, exp.DistKeyProperty, exp.SortKeyProperty, + exp.LikeProperty, } WITH_PROPERTIES = { - exp.AnonymousProperty, + exp.Property, exp.FileFormatProperty, exp.PartitionedByProperty, exp.TableFormatProperty, } - WITH_SEPARATED_COMMENTS = (exp.Select,) + WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.Binary) __slots__ = ( "time_mapping", @@ -211,7 +209,7 @@ class Generator: for msg in self.unsupported_messages: logger.warning(msg) elif self.unsupported_level == ErrorLevel.RAISE and self.unsupported_messages: - raise UnsupportedError(concat_errors(self.unsupported_messages, self.max_unsupported)) + raise UnsupportedError(concat_messages(self.unsupported_messages, self.max_unsupported)) return sql @@ -226,25 +224,24 @@ class Generator: def seg(self, sql, sep=" "): return f"{self.sep(sep)}{sql}" - def maybe_comment(self, sql, expression, single_line=False): - comment = expression.comment if self._comments else None - - if not comment: - return sql - + def pad_comment(self, comment): comment = " " + comment if comment[0].strip() else comment comment = comment + " " if comment[-1].strip() else comment + return comment - if isinstance(expression, self.WITH_SEPARATED_COMMENTS): - return f"/*{comment}*/{self.sep()}{sql}" + def maybe_comment(self, sql, expression): + comments = expression.comments if self._comments else None - if not self.pretty: - return f"{sql} /*{comment}*/" + if not comments: + return sql + + sep = "\n" if self.pretty else " " + comments = sep.join(f"/*{self.pad_comment(comment)}*/" for comment in comments) - if not NEWLINE_RE.search(comment): - return f"{sql} --{comment.rstrip()}" if single_line else f"{sql} /*{comment}*/" + if isinstance(expression, self.WITH_SEPARATED_COMMENTS): + return f"{comments}{self.sep()}{sql}" - return f"/*{comment}*/\n{sql}" if sql else f" /*{comment}*/" + return f"{sql} {comments}" def wrap(self, expression): this_sql = self.indent( @@ -387,8 +384,11 @@ class Generator: def notnullcolumnconstraint_sql(self, _): return "NOT NULL" - def primarykeycolumnconstraint_sql(self, _): - return "PRIMARY KEY" + def primarykeycolumnconstraint_sql(self, expression): + desc = expression.args.get("desc") + if desc is not None: + return f"PRIMARY KEY{' DESC' if desc else ' ASC'}" + return f"PRIMARY KEY" def uniquecolumnconstraint_sql(self, _): return "UNIQUE" @@ -546,36 +546,33 @@ class Generator: def root_properties(self, properties): if properties.expressions: - return self.sep() + self.expressions( - properties, - indent=False, - sep=" ", - ) + return self.sep() + self.expressions(properties, indent=False, sep=" ") return "" def properties(self, properties, prefix="", sep=", "): if properties.expressions: - expressions = self.expressions( - properties, - sep=sep, - indent=False, - ) + expressions = self.expressions(properties, sep=sep, indent=False) return f"{self.seg(prefix)}{' ' if prefix else ''}{self.wrap(expressions)}" return "" def with_properties(self, properties): - return self.properties( - properties, - prefix="WITH", - ) + return self.properties(properties, prefix="WITH") def property_sql(self, expression): - if isinstance(expression.this, exp.Literal): - key = expression.this.this - else: - key = expression.name - value = self.sql(expression, "value") - return f"{key}={value}" + property_cls = expression.__class__ + if property_cls == exp.Property: + return f"{expression.name}={self.sql(expression, 'value')}" + + property_name = exp.Properties.PROPERTY_TO_NAME.get(property_cls) + if not property_name: + self.unsupported(f"Unsupported property {property_name}") + + return f"{property_name}={self.sql(expression, 'this')}" + + def likeproperty_sql(self, expression): + options = " ".join(f"{e.name} {self.sql(e, 'value')}" for e in expression.expressions) + options = f" {options}" if options else "" + return f"LIKE {self.sql(expression, 'this')}{options}" def insert_sql(self, expression): overwrite = expression.args.get("overwrite") @@ -700,6 +697,11 @@ class Generator: def var_sql(self, expression): return self.sql(expression, "this") + def into_sql(self, expression): + temporary = " TEMPORARY" if expression.args.get("temporary") else "" + unlogged = " UNLOGGED" if expression.args.get("unlogged") else "" + return f"{self.seg('INTO')}{temporary or unlogged} {self.sql(expression, 'this')}" + def from_sql(self, expression): expressions = self.expressions(expression, flat=True) return f"{self.seg('FROM')} {expressions}" @@ -883,6 +885,7 @@ class Generator: sql = self.query_modifiers( expression, f"SELECT{hint}{distinct}{expressions}", + self.sql(expression, "into", comment=False), self.sql(expression, "from", comment=False), ) return self.prepend_ctes(expression, sql) @@ -1061,6 +1064,11 @@ class Generator: else: return f"TRIM({target})" + def concat_sql(self, expression): + if len(expression.expressions) == 1: + return self.sql(expression.expressions[0]) + return self.function_fallback_sql(expression) + def check_sql(self, expression): this = self.sql(expression, key="this") return f"CHECK ({this})" @@ -1125,7 +1133,10 @@ class Generator: return self.prepend_ctes(expression, sql) def neg_sql(self, expression): - return f"-{self.sql(expression, 'this')}" + # This makes sure we don't convert "- - 5" to "--5", which is a comment + this_sql = self.sql(expression, "this") + sep = " " if this_sql[0] == "-" else "" + return f"-{sep}{this_sql}" def not_sql(self, expression): return f"NOT {self.sql(expression, 'this')}" @@ -1191,8 +1202,12 @@ class Generator: def transaction_sql(self, *_): return "BEGIN" - def commit_sql(self, *_): - return "COMMIT" + def commit_sql(self, expression): + chain = expression.args.get("chain") + if chain is not None: + chain = " AND CHAIN" if chain else " AND NO CHAIN" + + return f"COMMIT{chain or ''}" def rollback_sql(self, expression): savepoint = expression.args.get("savepoint") @@ -1334,15 +1349,15 @@ class Generator: result_sqls = [] for i, e in enumerate(expressions): sql = self.sql(e, comment=False) - comment = self.maybe_comment("", e, single_line=True) + comments = self.maybe_comment("", e) if self.pretty: if self._leading_comma: - result_sqls.append(f"{sep if i > 0 else pad}{sql}{comment}") + result_sqls.append(f"{sep if i > 0 else pad}{sql}{comments}") else: - result_sqls.append(f"{sql}{stripped_sep if i + 1 < num_sqls else ''}{comment}") + result_sqls.append(f"{sql}{stripped_sep if i + 1 < num_sqls else ''}{comments}") else: - result_sqls.append(f"{sql}{comment}{sep if i + 1 < num_sqls else ''}") + result_sqls.append(f"{sql}{comments}{sep if i + 1 < num_sqls else ''}") result_sqls = "\n".join(result_sqls) if self.pretty else "".join(result_sqls) return self.indent(result_sqls, skip_first=False) if indent else result_sqls @@ -1354,7 +1369,10 @@ class Generator: return f"{self.seg(op)}{self.sep() if expressions_sql else ''}{expressions_sql}" def naked_property(self, expression): - return f"{expression.name} {self.sql(expression, 'value')}" + property_name = exp.Properties.PROPERTY_TO_NAME.get(expression.__class__) + if not property_name: + self.unsupported(f"Unsupported property {expression.__class__.__name__}") + return f"{property_name} {self.sql(expression, 'this')}" def set_operation(self, expression, op): this = self.sql(expression, "this") |