summaryrefslogtreecommitdiffstats
path: root/sqlglot/generator.py
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2022-12-02 09:16:32 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2022-12-02 09:16:32 +0000
commitb3c7fe6a73484a4d2177c30f951cd11a4916ed56 (patch)
tree7192898cb782bbb0b9b13bd8d6341fe4434f0f31 /sqlglot/generator.py
parentReleasing debian version 10.0.8-1. (diff)
downloadsqlglot-b3c7fe6a73484a4d2177c30f951cd11a4916ed56.tar.xz
sqlglot-b3c7fe6a73484a4d2177c30f951cd11a4916ed56.zip
Merging upstream version 10.1.3.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot/generator.py')
-rw-r--r--sqlglot/generator.py120
1 files changed, 69 insertions, 51 deletions
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index ffb34eb..47774fc 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -1,19 +1,16 @@
from __future__ import annotations
import logging
-import re
import typing as t
from sqlglot import exp
-from sqlglot.errors import ErrorLevel, UnsupportedError, concat_errors
+from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages
from sqlglot.helper import apply_index_offset, csv
from sqlglot.time import format_time
from sqlglot.tokens import TokenType
logger = logging.getLogger("sqlglot")
-NEWLINE_RE = re.compile("\r\n?|\n")
-
class Generator:
"""
@@ -58,11 +55,11 @@ class Generator:
"""
TRANSFORMS = {
- exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'value')}",
exp.DateAdd: lambda self, e: f"DATE_ADD({self.format_args(e.this, e.expression, e.args.get('unit'))})",
exp.DateDiff: lambda self, e: f"DATEDIFF({self.format_args(e.this, e.expression)})",
exp.TsOrDsAdd: lambda self, e: f"TS_OR_DS_ADD({self.format_args(e.this, e.expression, e.args.get('unit'))})",
exp.VarMap: lambda self, e: f"MAP({self.format_args(e.args['keys'], e.args['values'])})",
+ exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args['default'] else ''}CHARACTER SET={self.sql(e, 'this')}",
exp.LanguageProperty: lambda self, e: self.naked_property(e),
exp.LocationProperty: lambda self, e: self.naked_property(e),
exp.ReturnsProperty: lambda self, e: self.naked_property(e),
@@ -97,16 +94,17 @@ class Generator:
exp.DistStyleProperty,
exp.DistKeyProperty,
exp.SortKeyProperty,
+ exp.LikeProperty,
}
WITH_PROPERTIES = {
- exp.AnonymousProperty,
+ exp.Property,
exp.FileFormatProperty,
exp.PartitionedByProperty,
exp.TableFormatProperty,
}
- WITH_SEPARATED_COMMENTS = (exp.Select,)
+ WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.Binary)
__slots__ = (
"time_mapping",
@@ -211,7 +209,7 @@ class Generator:
for msg in self.unsupported_messages:
logger.warning(msg)
elif self.unsupported_level == ErrorLevel.RAISE and self.unsupported_messages:
- raise UnsupportedError(concat_errors(self.unsupported_messages, self.max_unsupported))
+ raise UnsupportedError(concat_messages(self.unsupported_messages, self.max_unsupported))
return sql
@@ -226,25 +224,24 @@ class Generator:
def seg(self, sql, sep=" "):
return f"{self.sep(sep)}{sql}"
- def maybe_comment(self, sql, expression, single_line=False):
- comment = expression.comment if self._comments else None
-
- if not comment:
- return sql
-
+ def pad_comment(self, comment):
comment = " " + comment if comment[0].strip() else comment
comment = comment + " " if comment[-1].strip() else comment
+ return comment
- if isinstance(expression, self.WITH_SEPARATED_COMMENTS):
- return f"/*{comment}*/{self.sep()}{sql}"
+ def maybe_comment(self, sql, expression):
+ comments = expression.comments if self._comments else None
- if not self.pretty:
- return f"{sql} /*{comment}*/"
+ if not comments:
+ return sql
+
+ sep = "\n" if self.pretty else " "
+ comments = sep.join(f"/*{self.pad_comment(comment)}*/" for comment in comments)
- if not NEWLINE_RE.search(comment):
- return f"{sql} --{comment.rstrip()}" if single_line else f"{sql} /*{comment}*/"
+ if isinstance(expression, self.WITH_SEPARATED_COMMENTS):
+ return f"{comments}{self.sep()}{sql}"
- return f"/*{comment}*/\n{sql}" if sql else f" /*{comment}*/"
+ return f"{sql} {comments}"
def wrap(self, expression):
this_sql = self.indent(
@@ -387,8 +384,11 @@ class Generator:
def notnullcolumnconstraint_sql(self, _):
return "NOT NULL"
- def primarykeycolumnconstraint_sql(self, _):
- return "PRIMARY KEY"
+ def primarykeycolumnconstraint_sql(self, expression):
+ desc = expression.args.get("desc")
+ if desc is not None:
+ return f"PRIMARY KEY{' DESC' if desc else ' ASC'}"
+ return f"PRIMARY KEY"
def uniquecolumnconstraint_sql(self, _):
return "UNIQUE"
@@ -546,36 +546,33 @@ class Generator:
def root_properties(self, properties):
if properties.expressions:
- return self.sep() + self.expressions(
- properties,
- indent=False,
- sep=" ",
- )
+ return self.sep() + self.expressions(properties, indent=False, sep=" ")
return ""
def properties(self, properties, prefix="", sep=", "):
if properties.expressions:
- expressions = self.expressions(
- properties,
- sep=sep,
- indent=False,
- )
+ expressions = self.expressions(properties, sep=sep, indent=False)
return f"{self.seg(prefix)}{' ' if prefix else ''}{self.wrap(expressions)}"
return ""
def with_properties(self, properties):
- return self.properties(
- properties,
- prefix="WITH",
- )
+ return self.properties(properties, prefix="WITH")
def property_sql(self, expression):
- if isinstance(expression.this, exp.Literal):
- key = expression.this.this
- else:
- key = expression.name
- value = self.sql(expression, "value")
- return f"{key}={value}"
+ property_cls = expression.__class__
+ if property_cls == exp.Property:
+ return f"{expression.name}={self.sql(expression, 'value')}"
+
+ property_name = exp.Properties.PROPERTY_TO_NAME.get(property_cls)
+ if not property_name:
+ self.unsupported(f"Unsupported property {property_name}")
+
+ return f"{property_name}={self.sql(expression, 'this')}"
+
+ def likeproperty_sql(self, expression):
+ options = " ".join(f"{e.name} {self.sql(e, 'value')}" for e in expression.expressions)
+ options = f" {options}" if options else ""
+ return f"LIKE {self.sql(expression, 'this')}{options}"
def insert_sql(self, expression):
overwrite = expression.args.get("overwrite")
@@ -700,6 +697,11 @@ class Generator:
def var_sql(self, expression):
return self.sql(expression, "this")
+ def into_sql(self, expression):
+ temporary = " TEMPORARY" if expression.args.get("temporary") else ""
+ unlogged = " UNLOGGED" if expression.args.get("unlogged") else ""
+ return f"{self.seg('INTO')}{temporary or unlogged} {self.sql(expression, 'this')}"
+
def from_sql(self, expression):
expressions = self.expressions(expression, flat=True)
return f"{self.seg('FROM')} {expressions}"
@@ -883,6 +885,7 @@ class Generator:
sql = self.query_modifiers(
expression,
f"SELECT{hint}{distinct}{expressions}",
+ self.sql(expression, "into", comment=False),
self.sql(expression, "from", comment=False),
)
return self.prepend_ctes(expression, sql)
@@ -1061,6 +1064,11 @@ class Generator:
else:
return f"TRIM({target})"
+ def concat_sql(self, expression):
+ if len(expression.expressions) == 1:
+ return self.sql(expression.expressions[0])
+ return self.function_fallback_sql(expression)
+
def check_sql(self, expression):
this = self.sql(expression, key="this")
return f"CHECK ({this})"
@@ -1125,7 +1133,10 @@ class Generator:
return self.prepend_ctes(expression, sql)
def neg_sql(self, expression):
- return f"-{self.sql(expression, 'this')}"
+ # This makes sure we don't convert "- - 5" to "--5", which is a comment
+ this_sql = self.sql(expression, "this")
+ sep = " " if this_sql[0] == "-" else ""
+ return f"-{sep}{this_sql}"
def not_sql(self, expression):
return f"NOT {self.sql(expression, 'this')}"
@@ -1191,8 +1202,12 @@ class Generator:
def transaction_sql(self, *_):
return "BEGIN"
- def commit_sql(self, *_):
- return "COMMIT"
+ def commit_sql(self, expression):
+ chain = expression.args.get("chain")
+ if chain is not None:
+ chain = " AND CHAIN" if chain else " AND NO CHAIN"
+
+ return f"COMMIT{chain or ''}"
def rollback_sql(self, expression):
savepoint = expression.args.get("savepoint")
@@ -1334,15 +1349,15 @@ class Generator:
result_sqls = []
for i, e in enumerate(expressions):
sql = self.sql(e, comment=False)
- comment = self.maybe_comment("", e, single_line=True)
+ comments = self.maybe_comment("", e)
if self.pretty:
if self._leading_comma:
- result_sqls.append(f"{sep if i > 0 else pad}{sql}{comment}")
+ result_sqls.append(f"{sep if i > 0 else pad}{sql}{comments}")
else:
- result_sqls.append(f"{sql}{stripped_sep if i + 1 < num_sqls else ''}{comment}")
+ result_sqls.append(f"{sql}{stripped_sep if i + 1 < num_sqls else ''}{comments}")
else:
- result_sqls.append(f"{sql}{comment}{sep if i + 1 < num_sqls else ''}")
+ result_sqls.append(f"{sql}{comments}{sep if i + 1 < num_sqls else ''}")
result_sqls = "\n".join(result_sqls) if self.pretty else "".join(result_sqls)
return self.indent(result_sqls, skip_first=False) if indent else result_sqls
@@ -1354,7 +1369,10 @@ class Generator:
return f"{self.seg(op)}{self.sep() if expressions_sql else ''}{expressions_sql}"
def naked_property(self, expression):
- return f"{expression.name} {self.sql(expression, 'value')}"
+ property_name = exp.Properties.PROPERTY_TO_NAME.get(expression.__class__)
+ if not property_name:
+ self.unsupported(f"Unsupported property {expression.__class__.__name__}")
+ return f"{property_name} {self.sql(expression, 'this')}"
def set_operation(self, expression, op):
this = self.sql(expression, "this")