summaryrefslogtreecommitdiffstats
path: root/sqlglot/generator.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/generator.py')
-rw-r--r--sqlglot/generator.py205
1 files changed, 152 insertions, 53 deletions
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index 8a49d55..bd12d54 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -76,11 +76,13 @@ class Generator:
exp.SqlSecurityProperty: lambda self, e: f"SQL SECURITY {'DEFINER' if e.args.get('definer') else 'INVOKER'}",
exp.TemporaryProperty: lambda self, e: f"{'GLOBAL ' if e.args.get('global_') else ''}TEMPORARY",
exp.TransientProperty: lambda self, e: "TRANSIENT",
- exp.VolatilityProperty: lambda self, e: e.name,
+ exp.StabilityProperty: lambda self, e: e.name,
+ exp.VolatileProperty: lambda self, e: "VOLATILE",
exp.WithJournalTableProperty: lambda self, e: f"WITH JOURNAL TABLE={self.sql(e, 'this')}",
exp.CaseSpecificColumnConstraint: lambda self, e: f"{'NOT ' if e.args.get('not_') else ''}CASESPECIFIC",
exp.CharacterSetColumnConstraint: lambda self, e: f"CHARACTER SET {self.sql(e, 'this')}",
exp.DateFormatColumnConstraint: lambda self, e: f"FORMAT {self.sql(e, 'this')}",
+ exp.OnUpdateColumnConstraint: lambda self, e: f"ON UPDATE {self.sql(e, 'this')}",
exp.UppercaseColumnConstraint: lambda self, e: f"UPPERCASE",
exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}",
exp.PathColumnConstraint: lambda self, e: f"PATH {self.sql(e, 'this')}",
@@ -110,8 +112,19 @@ class Generator:
# Whether or not MERGE ... WHEN MATCHED BY SOURCE is allowed
MATCHED_BY_SOURCE = True
- # Whether or not limit and fetch are supported
- # "ALL", "LIMIT", "FETCH"
+ # Whether or not the INTERVAL expression works only with values like '1 day'
+ SINGLE_STRING_INTERVAL = False
+
+ # Whether or not the plural form of date parts like day (i.e. "days") is supported in INTERVALs
+ INTERVAL_ALLOWS_PLURAL_FORM = True
+
+ # Whether or not the TABLESAMPLE clause supports a method name, like BERNOULLI
+ TABLESAMPLE_WITH_METHOD = True
+
+ # Whether or not to treat the number in TABLESAMPLE (50) as a percentage
+ TABLESAMPLE_SIZE_IS_PERCENT = False
+
+ # Whether or not limit and fetch are supported (possible values: "ALL", "LIMIT", "FETCH")
LIMIT_FETCH = "ALL"
TYPE_MAPPING = {
@@ -129,6 +142,18 @@ class Generator:
"replace": "REPLACE",
}
+ TIME_PART_SINGULARS = {
+ "microseconds": "microsecond",
+ "seconds": "second",
+ "minutes": "minute",
+ "hours": "hour",
+ "days": "day",
+ "weeks": "week",
+ "months": "month",
+ "quarters": "quarter",
+ "years": "year",
+ }
+
TOKEN_MAPPING: t.Dict[TokenType, str] = {}
STRUCT_DELIMITER = ("<", ">")
@@ -168,6 +193,7 @@ class Generator:
exp.PartitionedByProperty: exp.Properties.Location.POST_WITH,
exp.Property: exp.Properties.Location.POST_WITH,
exp.ReturnsProperty: exp.Properties.Location.POST_SCHEMA,
+ exp.RowFormatProperty: exp.Properties.Location.POST_SCHEMA,
exp.RowFormatDelimitedProperty: exp.Properties.Location.POST_SCHEMA,
exp.RowFormatSerdeProperty: exp.Properties.Location.POST_SCHEMA,
exp.SchemaCommentProperty: exp.Properties.Location.POST_SCHEMA,
@@ -175,15 +201,22 @@ class Generator:
exp.SetProperty: exp.Properties.Location.POST_CREATE,
exp.SortKeyProperty: exp.Properties.Location.POST_SCHEMA,
exp.SqlSecurityProperty: exp.Properties.Location.POST_CREATE,
+ exp.StabilityProperty: exp.Properties.Location.POST_SCHEMA,
exp.TableFormatProperty: exp.Properties.Location.POST_WITH,
exp.TemporaryProperty: exp.Properties.Location.POST_CREATE,
exp.TransientProperty: exp.Properties.Location.POST_CREATE,
- exp.VolatilityProperty: exp.Properties.Location.POST_SCHEMA,
+ exp.VolatileProperty: exp.Properties.Location.POST_CREATE,
exp.WithDataProperty: exp.Properties.Location.POST_EXPRESSION,
exp.WithJournalTableProperty: exp.Properties.Location.POST_NAME,
}
- WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.Binary)
+ JOIN_HINTS = True
+ TABLE_HINTS = True
+
+ RESERVED_KEYWORDS: t.Set[str] = set()
+ WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.With)
+ UNWRAPPED_INTERVAL_VALUES = (exp.Literal, exp.Paren, exp.Column)
+
SENTINEL_LINE_BREAK = "__SQLGLOT__LB__"
__slots__ = (
@@ -322,10 +355,15 @@ class Generator:
comment = comment + " " if comment[-1].strip() else comment
return comment
- def maybe_comment(self, sql: str, expression: exp.Expression) -> str:
- comments = expression.comments if self._comments else None
+ def maybe_comment(
+ self,
+ sql: str,
+ expression: t.Optional[exp.Expression] = None,
+ comments: t.Optional[t.List[str]] = None,
+ ) -> str:
+ comments = (comments or (expression and expression.comments)) if self._comments else None # type: ignore
- if not comments:
+ if not comments or isinstance(expression, exp.Binary):
return sql
sep = "\n" if self.pretty else " "
@@ -621,7 +659,6 @@ class Generator:
replace = " OR REPLACE" if expression.args.get("replace") else ""
unique = " UNIQUE" if expression.args.get("unique") else ""
- volatile = " VOLATILE" if expression.args.get("volatile") else ""
postcreate_props_sql = ""
if properties_locs.get(exp.Properties.Location.POST_CREATE):
@@ -632,7 +669,7 @@ class Generator:
wrapped=False,
)
- modifiers = "".join((replace, unique, volatile, postcreate_props_sql))
+ modifiers = "".join((replace, unique, postcreate_props_sql))
postexpression_props_sql = ""
if properties_locs.get(exp.Properties.Location.POST_EXPRESSION):
@@ -684,6 +721,9 @@ class Generator:
def hexstring_sql(self, expression: exp.HexString) -> str:
return self.sql(expression, "this")
+ def bytestring_sql(self, expression: exp.ByteString) -> str:
+ return self.sql(expression, "this")
+
def datatype_sql(self, expression: exp.DataType) -> str:
type_value = expression.this
type_sql = self.TYPE_MAPPING.get(type_value, type_value.value)
@@ -695,9 +735,7 @@ class Generator:
nested = f"{self.STRUCT_DELIMITER[0]}{interior}{self.STRUCT_DELIMITER[1]}"
if expression.args.get("values") is not None:
delimiters = ("[", "]") if type_value == exp.DataType.Type.ARRAY else ("(", ")")
- values = (
- f"{delimiters[0]}{self.expressions(expression, 'values')}{delimiters[1]}"
- )
+ values = f"{delimiters[0]}{self.expressions(expression, key='values')}{delimiters[1]}"
else:
nested = f"({interior})"
@@ -713,7 +751,7 @@ class Generator:
this = self.sql(expression, "this")
this = f" FROM {this}" if this else ""
using_sql = (
- f" USING {self.expressions(expression, 'using', sep=', USING ')}"
+ f" USING {self.expressions(expression, key='using', sep=', USING ')}"
if expression.args.get("using")
else ""
)
@@ -730,7 +768,10 @@ class Generator:
materialized = " MATERIALIZED" if expression.args.get("materialized") else ""
cascade = " CASCADE" if expression.args.get("cascade") else ""
constraints = " CONSTRAINTS" if expression.args.get("constraints") else ""
- return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{cascade}{constraints}"
+ purge = " PURGE" if expression.args.get("purge") else ""
+ return (
+ f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{cascade}{constraints}{purge}"
+ )
def except_sql(self, expression: exp.Except) -> str:
return self.prepend_ctes(
@@ -746,7 +787,10 @@ class Generator:
direction = f" {direction.upper()}" if direction else ""
count = expression.args.get("count")
count = f" {count}" if count else ""
- return f"{self.seg('FETCH')}{direction}{count} ROWS ONLY"
+ if expression.args.get("percent"):
+ count = f"{count} PERCENT"
+ with_ties_or_only = "WITH TIES" if expression.args.get("with_ties") else "ONLY"
+ return f"{self.seg('FETCH')}{direction}{count} ROWS {with_ties_or_only}"
def filter_sql(self, expression: exp.Filter) -> str:
this = self.sql(expression, "this")
@@ -766,12 +810,24 @@ class Generator:
def identifier_sql(self, expression: exp.Identifier) -> str:
text = expression.name
- text = text.lower() if self.normalize and not expression.quoted else text
+ lower = text.lower()
+ text = lower if self.normalize and not expression.quoted else text
text = text.replace(self.identifier_end, self._escaped_identifier_end)
- if expression.quoted or should_identify(text, self.identify):
+ if (
+ expression.quoted
+ or should_identify(text, self.identify)
+ or lower in self.RESERVED_KEYWORDS
+ ):
text = f"{self.identifier_start}{text}{self.identifier_end}"
return text
+ def inputoutputformat_sql(self, expression: exp.InputOutputFormat) -> str:
+ input_format = self.sql(expression, "input_format")
+ input_format = f"INPUTFORMAT {input_format}" if input_format else ""
+ output_format = self.sql(expression, "output_format")
+ output_format = f"OUTPUTFORMAT {output_format}" if output_format else ""
+ return self.sep().join((input_format, output_format))
+
def national_sql(self, expression: exp.National) -> str:
return f"N{self.sql(expression, 'this')}"
@@ -984,9 +1040,10 @@ class Generator:
self.sql(expression, "partition") if expression.args.get("partition") else ""
)
expression_sql = self.sql(expression, "expression")
+ conflict = self.sql(expression, "conflict")
returning = self.sql(expression, "returning")
sep = self.sep() if partition_sql else ""
- sql = f"INSERT{alternative}{this}{exists}{partition_sql}{sep}{expression_sql}{returning}"
+ sql = f"INSERT{alternative}{this}{exists}{partition_sql}{sep}{expression_sql}{conflict}{returning}"
return self.prepend_ctes(expression, sql)
def intersect_sql(self, expression: exp.Intersect) -> str:
@@ -1004,6 +1061,19 @@ class Generator:
def pseudotype_sql(self, expression: exp.PseudoType) -> str:
return expression.name.upper()
+ def onconflict_sql(self, expression: exp.OnConflict) -> str:
+ conflict = "ON DUPLICATE KEY" if expression.args.get("duplicate") else "ON CONFLICT"
+ constraint = self.sql(expression, "constraint")
+ if constraint:
+ constraint = f"ON CONSTRAINT {constraint}"
+ key = self.expressions(expression, key="key", flat=True)
+ do = "" if expression.args.get("duplicate") else " DO "
+ nothing = "NOTHING" if expression.args.get("nothing") else ""
+ expressions = self.expressions(expression, flat=True)
+ if expressions:
+ expressions = f"UPDATE SET {expressions}"
+ return f"{self.seg(conflict)} {constraint}{key}{do}{nothing}{expressions}"
+
def returning_sql(self, expression: exp.Returning) -> str:
return f"{self.seg('RETURNING')} {self.expressions(expression, flat=True)}"
@@ -1036,7 +1106,7 @@ class Generator:
alias = self.sql(expression, "alias")
alias = f"{sep}{alias}" if alias else ""
hints = self.expressions(expression, key="hints", sep=", ", flat=True)
- hints = f" WITH ({hints})" if hints else ""
+ hints = f" WITH ({hints})" if hints and self.TABLE_HINTS else ""
laterals = self.expressions(expression, key="laterals", sep="")
joins = self.expressions(expression, key="joins", sep="")
pivots = self.expressions(expression, key="pivots", sep="")
@@ -1053,7 +1123,7 @@ class Generator:
this = self.sql(expression, "this")
alias = ""
method = self.sql(expression, "method")
- method = f"{method.upper()} " if method else ""
+ method = f"{method.upper()} " if method and self.TABLESAMPLE_WITH_METHOD else ""
numerator = self.sql(expression, "bucket_numerator")
denominator = self.sql(expression, "bucket_denominator")
field = self.sql(expression, "bucket_field")
@@ -1064,6 +1134,8 @@ class Generator:
rows = self.sql(expression, "rows")
rows = f"{rows} ROWS" if rows else ""
size = self.sql(expression, "size")
+ if size and self.TABLESAMPLE_SIZE_IS_PERCENT:
+ size = f"{size} PERCENT"
seed = self.sql(expression, "seed")
seed = f" {seed_prefix} ({seed})" if seed else ""
kind = expression.args.get("kind", "TABLESAMPLE")
@@ -1154,6 +1226,7 @@ class Generator:
"NATURAL" if expression.args.get("natural") else None,
expression.side,
expression.kind,
+ expression.hint if self.JOIN_HINTS else None,
"JOIN",
)
if op
@@ -1311,16 +1384,20 @@ class Generator:
def matchrecognize_sql(self, expression: exp.MatchRecognize) -> str:
partition = self.partition_by_sql(expression)
order = self.sql(expression, "order")
- measures = self.sql(expression, "measures")
- measures = self.seg(f"MEASURES {measures}") if measures else ""
+ measures = self.expressions(expression, key="measures")
+ measures = self.seg(f"MEASURES{self.seg(measures)}") if measures else ""
rows = self.sql(expression, "rows")
rows = self.seg(rows) if rows else ""
after = self.sql(expression, "after")
after = self.seg(after) if after else ""
pattern = self.sql(expression, "pattern")
pattern = self.seg(f"PATTERN ({pattern})") if pattern else ""
- define = self.sql(expression, "define")
- define = self.seg(f"DEFINE {define}") if define else ""
+ definition_sqls = [
+ f"{self.sql(definition, 'alias')} AS {self.sql(definition, 'this')}"
+ for definition in expression.args.get("define", [])
+ ]
+ definitions = self.expressions(sqls=definition_sqls)
+ define = self.seg(f"DEFINE{self.seg(definitions)}") if definitions else ""
body = "".join(
(
partition,
@@ -1332,7 +1409,9 @@ class Generator:
define,
)
)
- return f"{self.seg('MATCH_RECOGNIZE')} {self.wrap(body)}"
+ alias = self.sql(expression, "alias")
+ alias = f" {alias}" if alias else ""
+ return f"{self.seg('MATCH_RECOGNIZE')} {self.wrap(body)}{alias}"
def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str:
limit = expression.args.get("limit")
@@ -1353,7 +1432,7 @@ class Generator:
self.sql(expression, "group"),
self.sql(expression, "having"),
self.sql(expression, "qualify"),
- self.seg("WINDOW ") + self.expressions(expression, "windows", flat=True)
+ self.seg("WINDOW ") + self.expressions(expression, key="windows", flat=True)
if expression.args.get("windows")
else "",
self.sql(expression, "distribute"),
@@ -1471,15 +1550,21 @@ class Generator:
partition_sql = partition + " " if partition and order else partition
spec = expression.args.get("spec")
- spec_sql = " " + self.window_spec_sql(spec) if spec else ""
+ spec_sql = " " + self.windowspec_sql(spec) if spec else ""
alias = self.sql(expression, "alias")
- this = f"{this} {'AS' if expression.arg_key == 'windows' else 'OVER'}"
+ over = self.sql(expression, "over") or "OVER"
+ this = f"{this} {'AS' if expression.arg_key == 'windows' else over}"
+
+ first = expression.args.get("first")
+ if first is not None:
+ first = " FIRST " if first else " LAST "
+ first = first or ""
if not partition and not order and not spec and alias:
return f"{this} {alias}"
- window_args = alias + partition_sql + order_sql + spec_sql
+ window_args = alias + first + partition_sql + order_sql + spec_sql
return f"{this} ({window_args.strip()})"
@@ -1487,7 +1572,7 @@ class Generator:
partition = self.expressions(expression, key="partition_by", flat=True)
return f"PARTITION BY {partition}" if partition else ""
- def window_spec_sql(self, expression: exp.WindowSpec) -> str:
+ def windowspec_sql(self, expression: exp.WindowSpec) -> str:
kind = self.sql(expression, "kind")
start = csv(self.sql(expression, "start"), self.sql(expression, "start_side"), sep=" ")
end = (
@@ -1508,7 +1593,7 @@ class Generator:
return f"{this} BETWEEN {low} AND {high}"
def bracket_sql(self, expression: exp.Bracket) -> str:
- expressions = apply_index_offset(expression.expressions, self.index_offset)
+ expressions = apply_index_offset(expression.this, expression.expressions, self.index_offset)
expressions_sql = ", ".join(self.sql(e) for e in expressions)
return f"{self.sql(expression, 'this')}[{expressions_sql}]"
@@ -1550,6 +1635,11 @@ class Generator:
expressions = self.expressions(expression, flat=True)
return f"CONSTRAINT {this} {expressions}"
+ def nextvaluefor_sql(self, expression: exp.NextValueFor) -> str:
+ order = expression.args.get("order")
+ order = f" OVER ({self.order_sql(order, flat=True)})" if order else ""
+ return f"NEXT VALUE FOR {self.sql(expression, 'this')}{order}"
+
def extract_sql(self, expression: exp.Extract) -> str:
this = self.sql(expression, "this")
expression_sql = self.sql(expression, "expression")
@@ -1586,7 +1676,7 @@ class Generator:
def primarykey_sql(self, expression: exp.ForeignKey) -> str:
expressions = self.expressions(expression, flat=True)
- options = self.expressions(expression, "options", flat=True, sep=" ")
+ options = self.expressions(expression, key="options", flat=True, sep=" ")
options = f" {options}" if options else ""
return f"PRIMARY KEY ({expressions}){options}"
@@ -1644,17 +1734,20 @@ class Generator:
return f"(SELECT {self.sql(unnest)})"
def interval_sql(self, expression: exp.Interval) -> str:
- this = expression.args.get("this")
- if this:
- this = (
- f" {this}"
- if isinstance(this, exp.Literal) or isinstance(this, exp.Paren)
- else f" ({this})"
- )
- else:
- this = ""
unit = self.sql(expression, "unit")
+ if not self.INTERVAL_ALLOWS_PLURAL_FORM:
+ unit = self.TIME_PART_SINGULARS.get(unit.lower(), unit)
unit = f" {unit}" if unit else ""
+
+ if self.SINGLE_STRING_INTERVAL:
+ this = expression.this.name if expression.this else ""
+ return f"INTERVAL '{this}{unit}'"
+
+ this = self.sql(expression, "this")
+ if this:
+ unwrapped = isinstance(expression.this, self.UNWRAPPED_INTERVAL_VALUES)
+ this = f" {this}" if unwrapped else f" ({this})"
+
return f"INTERVAL{this}{unit}"
def return_sql(self, expression: exp.Return) -> str:
@@ -1664,7 +1757,7 @@ class Generator:
this = self.sql(expression, "this")
expressions = self.expressions(expression, flat=True)
expressions = f"({expressions})" if expressions else ""
- options = self.expressions(expression, "options", flat=True, sep=" ")
+ options = self.expressions(expression, key="options", flat=True, sep=" ")
options = f" {options}" if options else ""
return f"REFERENCES {this}{expressions}{options}"
@@ -1690,9 +1783,9 @@ class Generator:
return f"NOT {self.sql(expression, 'this')}"
def alias_sql(self, expression: exp.Alias) -> str:
- to_sql = self.sql(expression, "alias")
- to_sql = f" AS {to_sql}" if to_sql else ""
- return f"{self.sql(expression, 'this')}{to_sql}"
+ alias = self.sql(expression, "alias")
+ alias = f" AS {alias}" if alias else ""
+ return f"{self.sql(expression, 'this')}{alias}"
def aliases_sql(self, expression: exp.Aliases) -> str:
return f"{self.sql(expression, 'this')} AS ({self.expressions(expression, flat=True)})"
@@ -1712,7 +1805,11 @@ class Generator:
if not self.pretty:
return self.binary(expression, op)
- sqls = tuple(self.sql(e) for e in expression.flatten(unnest=False))
+ sqls = tuple(
+ self.maybe_comment(self.sql(e), e, e.parent.comments) if i != 1 else self.sql(e)
+ for i, e in enumerate(expression.flatten(unnest=False))
+ )
+
sep = "\n" if self.text_width(sqls) > self._max_text_width else " "
return f"{sep}{op} ".join(sqls)
@@ -1797,13 +1894,13 @@ class Generator:
actions = expression.args["actions"]
if isinstance(actions[0], exp.ColumnDef):
- actions = self.expressions(expression, "actions", prefix="ADD COLUMN ")
+ actions = self.expressions(expression, key="actions", prefix="ADD COLUMN ")
elif isinstance(actions[0], exp.Schema):
- actions = self.expressions(expression, "actions", prefix="ADD COLUMNS ")
+ actions = self.expressions(expression, key="actions", prefix="ADD COLUMNS ")
elif isinstance(actions[0], exp.Delete):
- actions = self.expressions(expression, "actions", flat=True)
+ actions = self.expressions(expression, key="actions", flat=True)
else:
- actions = self.expressions(expression, "actions")
+ actions = self.expressions(expression, key="actions")
exists = " IF EXISTS" if expression.args.get("exists") else ""
return f"ALTER TABLE{exists} {self.sql(expression, 'this')} {actions}"
@@ -1935,6 +2032,7 @@ class Generator:
return f"USE{kind}{this}"
def binary(self, expression: exp.Binary, op: str) -> str:
+ op = self.maybe_comment(op, comments=expression.comments)
return f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}"
def function_fallback_sql(self, expression: exp.Func) -> str:
@@ -1965,14 +2063,15 @@ class Generator:
def expressions(
self,
- expression: exp.Expression,
+ expression: t.Optional[exp.Expression] = None,
key: t.Optional[str] = None,
+ sqls: t.Optional[t.List[str]] = None,
flat: bool = False,
indent: bool = True,
sep: str = ", ",
prefix: str = "",
) -> str:
- expressions = expression.args.get(key or "expressions")
+ expressions = expression.args.get(key or "expressions") if expression else sqls
if not expressions:
return ""