summaryrefslogtreecommitdiffstats
path: root/sqlglot/generator.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/generator.py')
-rw-r--r--sqlglot/generator.py67
1 files changed, 55 insertions, 12 deletions
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index 6375d92..b398d8e 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -16,7 +16,7 @@ class Generator:
"""
Generator interprets the given syntax tree and produces a SQL string as an output.
- Args
+ Args:
time_mapping (dict): the dictionary of custom time mappings in which the key
represents a python time format and the output the target time format
time_trie (trie): a trie of the time_mapping keys
@@ -84,6 +84,13 @@ class Generator:
exp.DataType.Type.NVARCHAR: "VARCHAR",
exp.DataType.Type.MEDIUMTEXT: "TEXT",
exp.DataType.Type.LONGTEXT: "TEXT",
+ exp.DataType.Type.MEDIUMBLOB: "BLOB",
+ exp.DataType.Type.LONGBLOB: "BLOB",
+ }
+
+ STAR_MAPPING = {
+ "except": "EXCEPT",
+ "replace": "REPLACE",
}
TOKEN_MAPPING: t.Dict[TokenType, str] = {}
@@ -106,6 +113,8 @@ class Generator:
exp.TableFormatProperty,
}
+ WITH_SINGLE_ALTER_TABLE_ACTION = (exp.AlterColumn, exp.RenameTable, exp.AddConstraint)
+
WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.Binary)
SENTINEL_LINE_BREAK = "__SQLGLOT__LB__"
@@ -241,15 +250,17 @@ class Generator:
return sql
sep = "\n" if self.pretty else " "
- comments = sep.join(f"/*{self.pad_comment(comment)}*/" for comment in comments if comment)
+ comments_sql = sep.join(
+ f"/*{self.pad_comment(comment)}*/" for comment in comments if comment
+ )
- if not comments:
+ if not comments_sql:
return sql
if isinstance(expression, self.WITH_SEPARATED_COMMENTS):
- return f"{comments}{self.sep()}{sql}"
+ return f"{comments_sql}{self.sep()}{sql}"
- return f"{sql} {comments}"
+ return f"{sql} {comments_sql}"
def wrap(self, expression: exp.Expression | str) -> str:
this_sql = self.indent(
@@ -433,8 +444,9 @@ class Generator:
def create_sql(self, expression: exp.Create) -> str:
this = self.sql(expression, "this")
kind = self.sql(expression, "kind").upper()
+ begin = " BEGIN" if expression.args.get("begin") else ""
expression_sql = self.sql(expression, "expression")
- expression_sql = f" AS{self.sep()}{expression_sql}" if expression_sql else ""
+ expression_sql = f" AS{begin}{self.sep()}{expression_sql}" if expression_sql else ""
temporary = " TEMPORARY" if expression.args.get("temporary") else ""
transient = (
" TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else ""
@@ -741,12 +753,14 @@ class Generator:
laterals = self.expressions(expression, key="laterals", sep="")
joins = self.expressions(expression, key="joins", sep="")
pivots = self.expressions(expression, key="pivots", sep="")
+ system_time = expression.args.get("system_time")
+ system_time = f" {self.sql(expression, 'system_time')}" if system_time else ""
if alias and pivots:
pivots = f"{pivots}{alias}"
alias = ""
- return f"{table}{alias}{hints}{laterals}{joins}{pivots}"
+ return f"{table}{system_time}{alias}{hints}{laterals}{joins}{pivots}"
def tablesample_sql(self, expression: exp.TableSample) -> str:
if self.alias_post_tablesample and expression.this.alias:
@@ -1009,9 +1023,9 @@ class Generator:
def star_sql(self, expression: exp.Star) -> str:
except_ = self.expressions(expression, key="except", flat=True)
- except_ = f"{self.seg('EXCEPT')} ({except_})" if except_ else ""
+ except_ = f"{self.seg(self.STAR_MAPPING['except'])} ({except_})" if except_ else ""
replace = self.expressions(expression, key="replace", flat=True)
- replace = f"{self.seg('REPLACE')} ({replace})" if replace else ""
+ replace = f"{self.seg(self.STAR_MAPPING['replace'])} ({replace})" if replace else ""
return f"*{except_}{replace}"
def structkwarg_sql(self, expression: exp.StructKwarg) -> str:
@@ -1193,6 +1207,12 @@ class Generator:
update = f" ON UPDATE {update}" if update else ""
return f"FOREIGN KEY ({expressions}){reference}{delete}{update}"
+ def primarykey_sql(self, expression: exp.ForeignKey) -> str:
+ expressions = self.expressions(expression, flat=True)
+ options = self.expressions(expression, "options", flat=True, sep=" ")
+ options = f" {options}" if options else ""
+ return f"PRIMARY KEY ({expressions}){options}"
+
def unique_sql(self, expression: exp.Unique) -> str:
columns = self.expressions(expression, key="expressions")
return f"UNIQUE ({columns})"
@@ -1229,10 +1249,16 @@ class Generator:
unit = f" {unit}" if unit else ""
return f"INTERVAL{this}{unit}"
+ def return_sql(self, expression: exp.Return) -> str:
+ return f"RETURN {self.sql(expression, 'this')}"
+
def reference_sql(self, expression: exp.Reference) -> str:
this = self.sql(expression, "this")
expressions = self.expressions(expression, flat=True)
- return f"REFERENCES {this}({expressions})"
+ expressions = f"({expressions})" if expressions else ""
+ options = self.expressions(expression, "options", flat=True, sep=" ")
+ options = f" {options}" if options else ""
+ return f"REFERENCES {this}{expressions}{options}"
def anonymous_sql(self, expression: exp.Anonymous) -> str:
args = self.format_args(*expression.expressions)
@@ -1362,7 +1388,7 @@ class Generator:
actions = self.expressions(expression, "actions", prefix="ADD COLUMNS ")
elif isinstance(actions[0], exp.Drop):
actions = self.expressions(expression, "actions")
- elif isinstance(actions[0], (exp.AlterColumn, exp.RenameTable)):
+ elif isinstance(actions[0], self.WITH_SINGLE_ALTER_TABLE_ACTION):
actions = self.sql(actions[0])
else:
self.unsupported(f"Unsupported ALTER TABLE action {actions[0].__class__.__name__}")
@@ -1370,6 +1396,17 @@ class Generator:
exists = " IF EXISTS" if expression.args.get("exists") else ""
return f"ALTER TABLE{exists} {self.sql(expression, 'this')} {actions}"
+ def addconstraint_sql(self, expression: exp.AddConstraint) -> str:
+ this = self.sql(expression, "this")
+ expression_ = self.sql(expression, "expression")
+ add_constraint = f"ADD CONSTRAINT {this}" if this else "ADD"
+
+ enforced = expression.args.get("enforced")
+ if enforced is not None:
+ return f"{add_constraint} CHECK ({expression_}){' ENFORCED' if enforced else ''}"
+
+ return f"{add_constraint} {expression_}"
+
def distinct_sql(self, expression: exp.Distinct) -> str:
this = self.expressions(expression, flat=True)
this = f" {this}" if this else ""
@@ -1550,13 +1587,19 @@ class Generator:
expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}"
)
+ def tag_sql(self, expression: exp.Tag) -> str:
+ return f"{expression.args.get('prefix')}{self.sql(expression.this)}{expression.args.get('postfix')}"
+
def token_sql(self, token_type: TokenType) -> str:
return self.TOKEN_MAPPING.get(token_type, token_type.name)
def userdefinedfunction_sql(self, expression: exp.UserDefinedFunction) -> str:
this = self.sql(expression, "this")
expressions = self.no_identify(lambda: self.expressions(expression))
- return f"{this}({expressions})"
+ expressions = (
+ self.wrap(expressions) if expression.args.get("wrapped") else f" {expressions}"
+ )
+ return f"{this}{expressions}"
def userdefinedfunctionkwarg_sql(self, expression: exp.UserDefinedFunctionKwarg) -> str:
this = self.sql(expression, "this")