diff options
Diffstat (limited to '')
-rw-r--r-- | sqlglot/generator.py | 55 |
1 files changed, 39 insertions, 16 deletions
diff --git a/sqlglot/generator.py b/sqlglot/generator.py index b7e26bb..0d6778a 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -11,6 +11,9 @@ from sqlglot.helper import apply_index_offset, csv, seq_get from sqlglot.time import format_time from sqlglot.tokens import Tokenizer, TokenType +if t.TYPE_CHECKING: + from sqlglot._typing import E + logger = logging.getLogger("sqlglot") @@ -141,6 +144,9 @@ class Generator: # Whether or not limit and fetch are supported (possible values: "ALL", "LIMIT", "FETCH") LIMIT_FETCH = "ALL" + # Whether or not limit and fetch allows expresions or just limits + LIMIT_ONLY_LITERALS = False + # Whether or not a table is allowed to be renamed with a db RENAME_TABLE_WITH_DB = True @@ -341,6 +347,12 @@ class Generator: exp.With, ) + # Expressions that should not have their comments generated in maybe_comment + EXCLUDE_COMMENTS: t.Tuple[t.Type[exp.Expression], ...] = ( + exp.Binary, + exp.Union, + ) + # Expressions that can remain unwrapped when appearing in the context of an INTERVAL UNWRAPPED_INTERVAL_VALUES: t.Tuple[t.Type[exp.Expression], ...] = ( exp.Column, @@ -501,7 +513,7 @@ class Generator: else None ) - if not comments or isinstance(expression, exp.Binary): + if not comments or isinstance(expression, self.EXCLUDE_COMMENTS): return sql comments_sql = " ".join( @@ -879,6 +891,10 @@ class Generator: alias = self.sql(expression, "this") columns = self.expressions(expression, key="columns", flat=True) columns = f"({columns})" if columns else "" + + if not alias and not self.UNNEST_COLUMN_ONLY: + alias = "_t" + return f"{alias}{columns}" def bitstring_sql(self, expression: exp.BitString) -> str: @@ -1611,9 +1627,6 @@ class Generator: def lateral_sql(self, expression: exp.Lateral) -> str: this = self.sql(expression, "this") - if isinstance(expression.this, exp.Subquery): - return f"LATERAL {this}" - if expression.args.get("view"): alias = expression.args["alias"] columns = self.expressions(alias, key="columns", flat=True) @@ -1629,18 +1642,19 @@ class Generator: def limit_sql(self, expression: exp.Limit, top: bool = False) -> str: this = self.sql(expression, "this") args = ", ".join( - sql - for sql in ( - self.sql(expression, "offset"), - self.sql(expression, "expression"), - ) - if sql + self.sql(self._simplify_unless_literal(e) if self.LIMIT_ONLY_LITERALS else e) + for e in (expression.args.get(k) for k in ("offset", "expression")) + if e ) return f"{this}{self.seg('TOP' if top else 'LIMIT')} {args}" def offset_sql(self, expression: exp.Offset) -> str: this = self.sql(expression, "this") - return f"{this}{self.seg('OFFSET')} {self.sql(expression, 'expression')}" + expression = expression.expression + expression = ( + self._simplify_unless_literal(expression) if self.LIMIT_ONLY_LITERALS else expression + ) + return f"{this}{self.seg('OFFSET')} {self.sql(expression)}" def setitem_sql(self, expression: exp.SetItem) -> str: kind = self.sql(expression, "kind") @@ -1895,12 +1909,13 @@ class Generator: def schema_sql(self, expression: exp.Schema) -> str: this = self.sql(expression, "this") - this = f"{this} " if this else "" sql = self.schema_columns_sql(expression) - return f"{this}{sql}" + return f"{this} {sql}" if this and sql else this or sql def schema_columns_sql(self, expression: exp.Schema) -> str: - return f"({self.sep('')}{self.expressions(expression)}{self.seg(')', sep='')}" + if expression.expressions: + return f"({self.sep('')}{self.expressions(expression)}{self.seg(')', sep='')}" + return "" def star_sql(self, expression: exp.Star) -> str: except_ = self.expressions(expression, key="except", flat=True) @@ -2708,8 +2723,8 @@ class Generator: self.unsupported(f"Unsupported property {expression.__class__.__name__}") return f"{property_name} {self.sql(expression, 'this')}" - def set_operation(self, expression: exp.Expression, op: str) -> str: - this = self.sql(expression, "this") + def set_operation(self, expression: exp.Union, op: str) -> str: + this = self.maybe_comment(self.sql(expression, "this"), comments=expression.comments) op = self.seg(op) return self.query_modifiers( expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}" @@ -2912,6 +2927,14 @@ class Generator: parameters = self.sql(expression, "params_struct") return self.func("PREDICT", model, table, parameters or None) + def _simplify_unless_literal(self, expression: E) -> E: + if not isinstance(expression, exp.Literal): + from sqlglot.optimizer.simplify import simplify + + expression = simplify(expression.copy()) + + return expression + def cached_generator( cache: t.Optional[t.Dict[int, str]] = None |