diff options
Diffstat (limited to 'sqlglot/generator.py')
-rw-r--r-- | sqlglot/generator.py | 67 |
1 files changed, 55 insertions, 12 deletions
diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 6375d92..b398d8e 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -16,7 +16,7 @@ class Generator: """ Generator interprets the given syntax tree and produces a SQL string as an output. - Args + Args: time_mapping (dict): the dictionary of custom time mappings in which the key represents a python time format and the output the target time format time_trie (trie): a trie of the time_mapping keys @@ -84,6 +84,13 @@ class Generator: exp.DataType.Type.NVARCHAR: "VARCHAR", exp.DataType.Type.MEDIUMTEXT: "TEXT", exp.DataType.Type.LONGTEXT: "TEXT", + exp.DataType.Type.MEDIUMBLOB: "BLOB", + exp.DataType.Type.LONGBLOB: "BLOB", + } + + STAR_MAPPING = { + "except": "EXCEPT", + "replace": "REPLACE", } TOKEN_MAPPING: t.Dict[TokenType, str] = {} @@ -106,6 +113,8 @@ class Generator: exp.TableFormatProperty, } + WITH_SINGLE_ALTER_TABLE_ACTION = (exp.AlterColumn, exp.RenameTable, exp.AddConstraint) + WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.Binary) SENTINEL_LINE_BREAK = "__SQLGLOT__LB__" @@ -241,15 +250,17 @@ class Generator: return sql sep = "\n" if self.pretty else " " - comments = sep.join(f"/*{self.pad_comment(comment)}*/" for comment in comments if comment) + comments_sql = sep.join( + f"/*{self.pad_comment(comment)}*/" for comment in comments if comment + ) - if not comments: + if not comments_sql: return sql if isinstance(expression, self.WITH_SEPARATED_COMMENTS): - return f"{comments}{self.sep()}{sql}" + return f"{comments_sql}{self.sep()}{sql}" - return f"{sql} {comments}" + return f"{sql} {comments_sql}" def wrap(self, expression: exp.Expression | str) -> str: this_sql = self.indent( @@ -433,8 +444,9 @@ class Generator: def create_sql(self, expression: exp.Create) -> str: this = self.sql(expression, "this") kind = self.sql(expression, "kind").upper() + begin = " BEGIN" if expression.args.get("begin") else "" expression_sql = self.sql(expression, "expression") - expression_sql = f" AS{self.sep()}{expression_sql}" if expression_sql else "" + expression_sql = f" AS{begin}{self.sep()}{expression_sql}" if expression_sql else "" temporary = " TEMPORARY" if expression.args.get("temporary") else "" transient = ( " TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else "" @@ -741,12 +753,14 @@ class Generator: laterals = self.expressions(expression, key="laterals", sep="") joins = self.expressions(expression, key="joins", sep="") pivots = self.expressions(expression, key="pivots", sep="") + system_time = expression.args.get("system_time") + system_time = f" {self.sql(expression, 'system_time')}" if system_time else "" if alias and pivots: pivots = f"{pivots}{alias}" alias = "" - return f"{table}{alias}{hints}{laterals}{joins}{pivots}" + return f"{table}{system_time}{alias}{hints}{laterals}{joins}{pivots}" def tablesample_sql(self, expression: exp.TableSample) -> str: if self.alias_post_tablesample and expression.this.alias: @@ -1009,9 +1023,9 @@ class Generator: def star_sql(self, expression: exp.Star) -> str: except_ = self.expressions(expression, key="except", flat=True) - except_ = f"{self.seg('EXCEPT')} ({except_})" if except_ else "" + except_ = f"{self.seg(self.STAR_MAPPING['except'])} ({except_})" if except_ else "" replace = self.expressions(expression, key="replace", flat=True) - replace = f"{self.seg('REPLACE')} ({replace})" if replace else "" + replace = f"{self.seg(self.STAR_MAPPING['replace'])} ({replace})" if replace else "" return f"*{except_}{replace}" def structkwarg_sql(self, expression: exp.StructKwarg) -> str: @@ -1193,6 +1207,12 @@ class Generator: update = f" ON UPDATE {update}" if update else "" return f"FOREIGN KEY ({expressions}){reference}{delete}{update}" + def primarykey_sql(self, expression: exp.ForeignKey) -> str: + expressions = self.expressions(expression, flat=True) + options = self.expressions(expression, "options", flat=True, sep=" ") + options = f" {options}" if options else "" + return f"PRIMARY KEY ({expressions}){options}" + def unique_sql(self, expression: exp.Unique) -> str: columns = self.expressions(expression, key="expressions") return f"UNIQUE ({columns})" @@ -1229,10 +1249,16 @@ class Generator: unit = f" {unit}" if unit else "" return f"INTERVAL{this}{unit}" + def return_sql(self, expression: exp.Return) -> str: + return f"RETURN {self.sql(expression, 'this')}" + def reference_sql(self, expression: exp.Reference) -> str: this = self.sql(expression, "this") expressions = self.expressions(expression, flat=True) - return f"REFERENCES {this}({expressions})" + expressions = f"({expressions})" if expressions else "" + options = self.expressions(expression, "options", flat=True, sep=" ") + options = f" {options}" if options else "" + return f"REFERENCES {this}{expressions}{options}" def anonymous_sql(self, expression: exp.Anonymous) -> str: args = self.format_args(*expression.expressions) @@ -1362,7 +1388,7 @@ class Generator: actions = self.expressions(expression, "actions", prefix="ADD COLUMNS ") elif isinstance(actions[0], exp.Drop): actions = self.expressions(expression, "actions") - elif isinstance(actions[0], (exp.AlterColumn, exp.RenameTable)): + elif isinstance(actions[0], self.WITH_SINGLE_ALTER_TABLE_ACTION): actions = self.sql(actions[0]) else: self.unsupported(f"Unsupported ALTER TABLE action {actions[0].__class__.__name__}") @@ -1370,6 +1396,17 @@ class Generator: exists = " IF EXISTS" if expression.args.get("exists") else "" return f"ALTER TABLE{exists} {self.sql(expression, 'this')} {actions}" + def addconstraint_sql(self, expression: exp.AddConstraint) -> str: + this = self.sql(expression, "this") + expression_ = self.sql(expression, "expression") + add_constraint = f"ADD CONSTRAINT {this}" if this else "ADD" + + enforced = expression.args.get("enforced") + if enforced is not None: + return f"{add_constraint} CHECK ({expression_}){' ENFORCED' if enforced else ''}" + + return f"{add_constraint} {expression_}" + def distinct_sql(self, expression: exp.Distinct) -> str: this = self.expressions(expression, flat=True) this = f" {this}" if this else "" @@ -1550,13 +1587,19 @@ class Generator: expression, f"{this}{op}{self.sep()}{self.sql(expression, 'expression')}" ) + def tag_sql(self, expression: exp.Tag) -> str: + return f"{expression.args.get('prefix')}{self.sql(expression.this)}{expression.args.get('postfix')}" + def token_sql(self, token_type: TokenType) -> str: return self.TOKEN_MAPPING.get(token_type, token_type.name) def userdefinedfunction_sql(self, expression: exp.UserDefinedFunction) -> str: this = self.sql(expression, "this") expressions = self.no_identify(lambda: self.expressions(expression)) - return f"{this}({expressions})" + expressions = ( + self.wrap(expressions) if expression.args.get("wrapped") else f" {expressions}" + ) + return f"{this}{expressions}" def userdefinedfunctionkwarg_sql(self, expression: exp.UserDefinedFunctionKwarg) -> str: this = self.sql(expression, "this") |