diff options
Diffstat (limited to 'sqlglot/generator.py')
-rw-r--r-- | sqlglot/generator.py | 31 |
1 files changed, 23 insertions, 8 deletions
diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 5936649..a6f4772 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -5,7 +5,7 @@ import typing as t from sqlglot import exp from sqlglot.errors import ErrorLevel, UnsupportedError, concat_messages -from sqlglot.helper import apply_index_offset, csv, seq_get +from sqlglot.helper import apply_index_offset, csv, seq_get, should_identify from sqlglot.time import format_time from sqlglot.tokens import TokenType @@ -25,8 +25,7 @@ class Generator: quote_end (str): specifies which ending character to use to delimit quotes. Default: '. identifier_start (str): specifies which starting character to use to delimit identifiers. Default: ". identifier_end (str): specifies which ending character to use to delimit identifiers. Default: ". - identify (bool): if set to True all identifiers will be delimited by the corresponding - character. + identify (bool | str): 'always': always quote, 'safe': quote identifiers if they don't contain an upcase, True defaults to always. normalize (bool): if set to True all identifiers will lower cased string_escape (str): specifies a string escape character. Default: '. identifier_escape (str): specifies an identifier escape character. Default: ". @@ -57,10 +56,10 @@ class Generator: TRANSFORMS = { exp.DateAdd: lambda self, e: self.func( - "DATE_ADD", e.this, e.expression, e.args.get("unit") + "DATE_ADD", e.this, e.expression, exp.Literal.string(e.text("unit")) ), exp.TsOrDsAdd: lambda self, e: self.func( - "TS_OR_DS_ADD", e.this, e.expression, e.args.get("unit") + "TS_OR_DS_ADD", e.this, e.expression, exp.Literal.string(e.text("unit")) ), exp.VarMap: lambda self, e: self.func("MAP", e.args["keys"], e.args["values"]), exp.CharacterSetProperty: lambda self, e: f"{'DEFAULT ' if e.args.get('default') else ''}CHARACTER SET={self.sql(e, 'this')}", @@ -736,7 +735,7 @@ class Generator: text = expression.name text = text.lower() if self.normalize else text text = text.replace(self.identifier_end, self._escaped_identifier_end) - if expression.args.get("quoted") or self.identify: + if expression.args.get("quoted") or should_identify(text, self.identify): text = f"{self.identifier_start}{text}{self.identifier_end}" return text @@ -1176,6 +1175,22 @@ class Generator: this = self.sql(expression, "this") return f"{this}{self.seg('OFFSET')} {self.sql(expression, 'expression')}" + def setitem_sql(self, expression: exp.SetItem) -> str: + kind = self.sql(expression, "kind") + kind = f"{kind} " if kind else "" + this = self.sql(expression, "this") + expressions = self.expressions(expression) + collate = self.sql(expression, "collate") + collate = f" COLLATE {collate}" if collate else "" + global_ = "GLOBAL " if expression.args.get("global") else "" + return f"{global_}{kind}{this}{expressions}{collate}" + + def set_sql(self, expression: exp.Set) -> str: + expressions = ( + f" {self.expressions(expression, flat=True)}" if expression.expressions else "" + ) + return f"SET{expressions}" + def lock_sql(self, expression: exp.Lock) -> str: if self.LOCKING_READS_SUPPORTED: lock_type = "UPDATE" if expression.args["update"] else "SHARE" @@ -1359,8 +1374,8 @@ class Generator: sql = self.query_modifiers( expression, self.wrap(expression), - self.expressions(expression, key="pivots", sep=" "), alias, + self.expressions(expression, key="pivots", sep=" "), ) return self.prepend_ctes(expression, sql) @@ -1668,7 +1683,7 @@ class Generator: expression_sql = self.sql(expression, "expression") return f"COMMENT{exists_sql}ON {kind} {this} IS {expression_sql}" - def transaction_sql(self, *_) -> str: + def transaction_sql(self, expression: exp.Transaction) -> str: return "BEGIN" def commit_sql(self, expression: exp.Commit) -> str: |