summaryrefslogtreecommitdiffstats
path: root/sqlglot/generator.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/generator.py')
-rw-r--r--sqlglot/generator.py214
1 files changed, 147 insertions, 67 deletions
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index ca14425..11d9073 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -1,4 +1,8 @@
+from __future__ import annotations
+
import logging
+import re
+import typing as t
from sqlglot import exp
from sqlglot.errors import ErrorLevel, UnsupportedError, concat_errors
@@ -8,6 +12,8 @@ from sqlglot.tokens import TokenType
logger = logging.getLogger("sqlglot")
+NEWLINE_RE = re.compile("\r\n?|\n")
+
class Generator:
"""
@@ -47,8 +53,7 @@ class Generator:
The default is on the smaller end because the length only represents a segment and not the true
line length.
Default: 80
- annotations: Whether or not to show annotations in the SQL when `pretty` is True.
- Annotations can only be shown in pretty mode otherwise they may clobber resulting sql.
+ comments: Whether or not to preserve comments in the output SQL code.
Default: True
"""
@@ -65,14 +70,16 @@ class Generator:
exp.VolatilityProperty: lambda self, e: self.sql(e.name),
}
- # whether 'CREATE ... TRANSIENT ... TABLE' is allowed
- # can override in dialects
+ # Whether 'CREATE ... TRANSIENT ... TABLE' is allowed
CREATE_TRANSIENT = False
- # whether or not null ordering is supported in order by
+
+ # Whether or not null ordering is supported in order by
NULL_ORDERING_SUPPORTED = True
- # always do union distinct or union all
+
+ # Always do union distinct or union all
EXPLICIT_UNION = False
- # wrap derived values in parens, usually standard but spark doesn't support it
+
+ # Wrap derived values in parens, usually standard but spark doesn't support it
WRAP_DERIVED_VALUES = True
TYPE_MAPPING = {
@@ -80,7 +87,7 @@ class Generator:
exp.DataType.Type.NVARCHAR: "VARCHAR",
}
- TOKEN_MAPPING = {}
+ TOKEN_MAPPING: t.Dict[TokenType, str] = {}
STRUCT_DELIMITER = ("<", ">")
@@ -96,6 +103,8 @@ class Generator:
exp.TableFormatProperty,
}
+ WITH_SEPARATED_COMMENTS = (exp.Select,)
+
__slots__ = (
"time_mapping",
"time_trie",
@@ -122,7 +131,7 @@ class Generator:
"_escaped_quote_end",
"_leading_comma",
"_max_text_width",
- "_annotations",
+ "_comments",
)
def __init__(
@@ -148,7 +157,7 @@ class Generator:
max_unsupported=3,
leading_comma=False,
max_text_width=80,
- annotations=True,
+ comments=True,
):
import sqlglot
@@ -177,7 +186,7 @@ class Generator:
self._escaped_quote_end = self.escape + self.quote_end
self._leading_comma = leading_comma
self._max_text_width = max_text_width
- self._annotations = annotations
+ self._comments = comments
def generate(self, expression):
"""
@@ -204,7 +213,6 @@ class Generator:
return sql
def unsupported(self, message):
-
if self.unsupported_level == ErrorLevel.IMMEDIATE:
raise UnsupportedError(message)
self.unsupported_messages.append(message)
@@ -215,9 +223,31 @@ class Generator:
def seg(self, sql, sep=" "):
return f"{self.sep(sep)}{sql}"
+ def maybe_comment(self, sql, expression, single_line=False):
+ comment = expression.comment if self._comments else None
+
+ if not comment:
+ return sql
+
+ comment = " " + comment if comment[0].strip() else comment
+ comment = comment + " " if comment[-1].strip() else comment
+
+ if isinstance(expression, self.WITH_SEPARATED_COMMENTS):
+ return f"/*{comment}*/{self.sep()}{sql}"
+
+ if not self.pretty:
+ return f"{sql} /*{comment}*/"
+
+ if not NEWLINE_RE.search(comment):
+ return f"{sql} --{comment.rstrip()}" if single_line else f"{sql} /*{comment}*/"
+
+ return f"/*{comment}*/\n{sql}"
+
def wrap(self, expression):
this_sql = self.indent(
- self.sql(expression) if isinstance(expression, (exp.Select, exp.Union)) else self.sql(expression, "this"),
+ self.sql(expression)
+ if isinstance(expression, (exp.Select, exp.Union))
+ else self.sql(expression, "this"),
level=1,
pad=0,
)
@@ -251,7 +281,7 @@ class Generator:
for i, line in enumerate(lines)
)
- def sql(self, expression, key=None):
+ def sql(self, expression, key=None, comment=True):
if not expression:
return ""
@@ -264,29 +294,24 @@ class Generator:
transform = self.TRANSFORMS.get(expression.__class__)
if callable(transform):
- return transform(self, expression)
- if transform:
- return transform
-
- if not isinstance(expression, exp.Expression):
+ sql = transform(self, expression)
+ elif transform:
+ sql = transform
+ elif isinstance(expression, exp.Expression):
+ exp_handler_name = f"{expression.key}_sql"
+
+ if hasattr(self, exp_handler_name):
+ sql = getattr(self, exp_handler_name)(expression)
+ elif isinstance(expression, exp.Func):
+ sql = self.function_fallback_sql(expression)
+ elif isinstance(expression, exp.Property):
+ sql = self.property_sql(expression)
+ else:
+ raise ValueError(f"Unsupported expression type {expression.__class__.__name__}")
+ else:
raise ValueError(f"Expected an Expression. Received {type(expression)}: {expression}")
- exp_handler_name = f"{expression.key}_sql"
- if hasattr(self, exp_handler_name):
- return getattr(self, exp_handler_name)(expression)
-
- if isinstance(expression, exp.Func):
- return self.function_fallback_sql(expression)
-
- if isinstance(expression, exp.Property):
- return self.property_sql(expression)
-
- raise ValueError(f"Unsupported expression type {expression.__class__.__name__}")
-
- def annotation_sql(self, expression):
- if self._annotations and self.pretty:
- return f"{self.sql(expression, 'expression')} # {expression.name}"
- return self.sql(expression, "expression")
+ return self.maybe_comment(sql, expression) if self._comments and comment else sql
def uncache_sql(self, expression):
table = self.sql(expression, "this")
@@ -371,7 +396,9 @@ class Generator:
expression_sql = self.sql(expression, "expression")
expression_sql = f"AS{self.sep()}{expression_sql}" if expression_sql else ""
temporary = " TEMPORARY" if expression.args.get("temporary") else ""
- transient = " TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else ""
+ transient = (
+ " TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else ""
+ )
replace = " OR REPLACE" if expression.args.get("replace") else ""
exists_sql = " IF NOT EXISTS" if expression.args.get("exists") else ""
unique = " UNIQUE" if expression.args.get("unique") else ""
@@ -434,7 +461,9 @@ class Generator:
def delete_sql(self, expression):
this = self.sql(expression, "this")
using_sql = (
- f" USING {self.expressions(expression, 'using', sep=', USING ')}" if expression.args.get("using") else ""
+ f" USING {self.expressions(expression, 'using', sep=', USING ')}"
+ if expression.args.get("using")
+ else ""
)
where_sql = self.sql(expression, "where")
sql = f"DELETE FROM {this}{using_sql}{where_sql}"
@@ -481,15 +510,18 @@ class Generator:
return f"{this} ON {table} {columns}"
def identifier_sql(self, expression):
- value = expression.name
- value = value.lower() if self.normalize else value
+ text = expression.name
+ text = text.lower() if self.normalize else text
if expression.args.get("quoted") or self.identify:
- return f"{self.identifier_start}{value}{self.identifier_end}"
- return value
+ text = f"{self.identifier_start}{text}{self.identifier_end}"
+ return text
def partition_sql(self, expression):
keys = csv(
- *[f"{k.args['this']}='{v.args['this']}'" if v else k.args["this"] for k, v in expression.args.get("this")]
+ *[
+ f"""{prop.name}='{prop.text("value")}'""" if prop.text("value") else prop.name
+ for prop in expression.this
+ ]
)
return f"PARTITION({keys})"
@@ -504,9 +536,9 @@ class Generator:
elif p_class in self.ROOT_PROPERTIES:
root_properties.append(p)
- return self.root_properties(exp.Properties(expressions=root_properties)) + self.with_properties(
- exp.Properties(expressions=with_properties)
- )
+ return self.root_properties(
+ exp.Properties(expressions=root_properties)
+ ) + self.with_properties(exp.Properties(expressions=with_properties))
def root_properties(self, properties):
if properties.expressions:
@@ -551,7 +583,9 @@ class Generator:
this = f"{this}{self.sql(expression, 'this')}"
exists = " IF EXISTS " if expression.args.get("exists") else " "
- partition_sql = self.sql(expression, "partition") if expression.args.get("partition") else ""
+ partition_sql = (
+ self.sql(expression, "partition") if expression.args.get("partition") else ""
+ )
expression_sql = self.sql(expression, "expression")
sep = self.sep() if partition_sql else ""
sql = f"INSERT {this}{exists}{partition_sql}{sep}{expression_sql}"
@@ -669,7 +703,9 @@ class Generator:
def group_sql(self, expression):
group_by = self.op_expressions("GROUP BY", expression)
grouping_sets = self.expressions(expression, key="grouping_sets", indent=False)
- grouping_sets = f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" if grouping_sets else ""
+ grouping_sets = (
+ f"{self.seg('GROUPING SETS')} {self.wrap(grouping_sets)}" if grouping_sets else ""
+ )
cube = self.expressions(expression, key="cube", indent=False)
cube = f"{self.seg('CUBE')} {self.wrap(cube)}" if cube else ""
rollup = self.expressions(expression, key="rollup", indent=False)
@@ -711,10 +747,10 @@ class Generator:
this_sql = self.sql(expression, "this")
return f"{expression_sql}{op_sql} {this_sql}{on_sql}"
- def lambda_sql(self, expression):
+ def lambda_sql(self, expression, arrow_sep="->"):
args = self.expressions(expression, flat=True)
args = f"({args})" if len(args.split(",")) > 1 else args
- return self.no_identify(lambda: f"{args} -> {self.sql(expression, 'this')}")
+ return self.no_identify(lambda: f"{args} {arrow_sep} {self.sql(expression, 'this')}")
def lateral_sql(self, expression):
this = self.sql(expression, "this")
@@ -748,7 +784,7 @@ class Generator:
if self._replace_backslash:
text = text.replace("\\", "\\\\")
text = text.replace(self.quote_end, self._escaped_quote_end)
- return f"{self.quote_start}{text}{self.quote_end}"
+ text = f"{self.quote_start}{text}{self.quote_end}"
return text
def loaddata_sql(self, expression):
@@ -796,13 +832,21 @@ class Generator:
sort_order = " DESC" if desc else ""
nulls_sort_change = ""
- if nulls_first and ((asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last):
+ if nulls_first and (
+ (asc and nulls_are_large) or (desc and nulls_are_small) or nulls_are_last
+ ):
nulls_sort_change = " NULLS FIRST"
- elif nulls_last and ((asc and nulls_are_small) or (desc and nulls_are_large)) and not nulls_are_last:
+ elif (
+ nulls_last
+ and ((asc and nulls_are_small) or (desc and nulls_are_large))
+ and not nulls_are_last
+ ):
nulls_sort_change = " NULLS LAST"
if nulls_sort_change and not self.NULL_ORDERING_SUPPORTED:
- self.unsupported("Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect")
+ self.unsupported(
+ "Sorting in an ORDER BY on NULLS FIRST/NULLS LAST is not supported by this dialect"
+ )
nulls_sort_change = ""
return f"{self.sql(expression, 'this')}{sort_order}{nulls_sort_change}"
@@ -835,7 +879,7 @@ class Generator:
sql = self.query_modifiers(
expression,
f"SELECT{hint}{distinct}{expressions}",
- self.sql(expression, "from"),
+ self.sql(expression, "from", comment=False),
)
return self.prepend_ctes(expression, sql)
@@ -858,6 +902,13 @@ class Generator:
def parameter_sql(self, expression):
return f"@{self.sql(expression, 'this')}"
+ def sessionparameter_sql(self, expression):
+ this = self.sql(expression, "this")
+ kind = expression.text("kind")
+ if kind:
+ kind = f"{kind}."
+ return f"@@{kind}{this}"
+
def placeholder_sql(self, expression):
return f":{expression.name}" if expression.name else "?"
@@ -931,7 +982,10 @@ class Generator:
def window_spec_sql(self, expression):
kind = self.sql(expression, "kind")
start = csv(self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" ")
- end = csv(self.sql(expression, "end"), self.sql(expression, "end_side"), sep=" ") or "CURRENT ROW"
+ end = (
+ csv(self.sql(expression, "end"), self.sql(expression, "end_side"), sep=" ")
+ or "CURRENT ROW"
+ )
return f"{kind} BETWEEN {start} AND {end}"
def withingroup_sql(self, expression):
@@ -1020,7 +1074,9 @@ class Generator:
return f"UNIQUE ({columns})"
def if_sql(self, expression):
- return self.case_sql(exp.Case(ifs=[expression.copy()], default=expression.args.get("false")))
+ return self.case_sql(
+ exp.Case(ifs=[expression.copy()], default=expression.args.get("false"))
+ )
def in_sql(self, expression):
query = expression.args.get("query")
@@ -1196,6 +1252,12 @@ class Generator:
def neq_sql(self, expression):
return self.binary(expression, "<>")
+ def nullsafeeq_sql(self, expression):
+ return self.binary(expression, "IS NOT DISTINCT FROM")
+
+ def nullsafeneq_sql(self, expression):
+ return self.binary(expression, "IS DISTINCT FROM")
+
def or_sql(self, expression):
return self.connector_sql(expression, "OR")
@@ -1205,6 +1267,9 @@ class Generator:
def trycast_sql(self, expression):
return f"TRY_CAST({self.sql(expression, 'this')} AS {self.sql(expression, 'to')})"
+ def use_sql(self, expression):
+ return f"USE {self.sql(expression, 'this')}"
+
def binary(self, expression, op):
return f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}"
@@ -1240,17 +1305,27 @@ class Generator:
if flat:
return sep.join(self.sql(e) for e in expressions)
- sql = (self.sql(e) for e in expressions)
- # the only time leading_comma changes the output is if pretty print is enabled
- if self._leading_comma and self.pretty:
- pad = " " * self.pad
- expressions = "\n".join(f"{sep}{s}" if i > 0 else f"{pad}{s}" for i, s in enumerate(sql))
- else:
- expressions = self.sep(sep).join(sql)
+ num_sqls = len(expressions)
+
+ # Computed once up front: `pad` is used with the leading_comma option, `stripped_sep` with pretty output otherwise
+ pad = " " * self.pad
+ stripped_sep = sep.strip()
- if indent:
- return self.indent(expressions, skip_first=False)
- return expressions
+ result_sqls = []
+ for i, e in enumerate(expressions):
+ sql = self.sql(e, comment=False)
+ comment = self.maybe_comment("", e, single_line=True)
+
+ if self.pretty:
+ if self._leading_comma:
+ result_sqls.append(f"{sep if i > 0 else pad}{sql}{comment}")
+ else:
+ result_sqls.append(f"{sql}{stripped_sep if i + 1 < num_sqls else ''}{comment}")
+ else:
+ result_sqls.append(f"{sql}{comment}{sep if i + 1 < num_sqls else ''}")
+
+ result_sqls = "\n".join(result_sqls) if self.pretty else "".join(result_sqls)
+ return self.indent(result_sqls, skip_first=False) if indent else result_sqls
def op_expressions(self, op, expression, flat=False):
expressions_sql = self.expressions(expression, flat=flat)
@@ -1264,7 +1339,9 @@ class Generator:
def set_operation(self, expression, op):
this = self.sql(expression, "this")
op = self.seg(op)
- return self.query_modifiers(expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}")
+ return self.query_modifiers(
+ expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}"
+ )
def token_sql(self, token_type):
return self.TOKEN_MAPPING.get(token_type, token_type.name)
@@ -1283,3 +1360,6 @@ class Generator:
this = self.sql(expression, "this")
expressions = self.expressions(expression, flat=True)
return f"{this}({expressions})"
+
+ def kwarg_sql(self, expression):
+ return self.binary(expression, "=>")