diff options
Diffstat (limited to 'sqlglot/generator.py')
-rw-r--r-- | sqlglot/generator.py | 64 |
1 files changed, 48 insertions, 16 deletions
diff --git a/sqlglot/generator.py b/sqlglot/generator.py index bd12d54..d7dcea0 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -25,6 +25,12 @@ class Generator: quote_end (str): specifies which ending character to use to delimit quotes. Default: '. identifier_start (str): specifies which starting character to use to delimit identifiers. Default: ". identifier_end (str): specifies which ending character to use to delimit identifiers. Default: ". + bit_start (str): specifies which starting character to use to delimit bit literals. Default: None. + bit_end (str): specifies which ending character to use to delimit bit literals. Default: None. + hex_start (str): specifies which starting character to use to delimit hex literals. Default: None. + hex_end (str): specifies which ending character to use to delimit hex literals. Default: None. + byte_start (str): specifies which starting character to use to delimit byte literals. Default: None. + byte_end (str): specifies which ending character to use to delimit byte literals. Default: None. identify (bool | str): 'always': always quote, 'safe': quote identifiers if they don't contain an upcase, True defaults to always. normalize (bool): if set to True all identifiers will lower cased string_escape (str): specifies a string escape character. Default: '. @@ -227,6 +233,12 @@ class Generator: "quote_end", "identifier_start", "identifier_end", + "bit_start", + "bit_end", + "hex_start", + "hex_end", + "byte_start", + "byte_end", "identify", "normalize", "string_escape", @@ -258,6 +270,12 @@ class Generator: quote_end=None, identifier_start=None, identifier_end=None, + bit_start=None, + bit_end=None, + hex_start=None, + hex_end=None, + byte_start=None, + byte_end=None, identify=False, normalize=False, string_escape=None, @@ -284,6 +302,12 @@ class Generator: self.quote_end = quote_end or "'" self.identifier_start = identifier_start or '"' self.identifier_end = identifier_end or '"' + self.bit_start = bit_start + self.bit_end = bit_end + self.hex_start = hex_start + self.hex_end = hex_end + self.byte_start = byte_start + self.byte_end = byte_end self.identify = identify self.normalize = normalize self.string_escape = string_escape or "'" @@ -361,7 +385,7 @@ class Generator: expression: t.Optional[exp.Expression] = None, comments: t.Optional[t.List[str]] = None, ) -> str: - comments = (comments or (expression and expression.comments)) if self._comments else None # type: ignore + comments = ((expression and expression.comments) if comments is None else comments) if self._comments else None # type: ignore if not comments or isinstance(expression, exp.Binary): return sql @@ -510,12 +534,12 @@ class Generator: position = self.sql(expression, "position") return f"{position}{this}" - def columndef_sql(self, expression: exp.ColumnDef) -> str: + def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str: column = self.sql(expression, "this") kind = self.sql(expression, "kind") constraints = self.expressions(expression, key="constraints", sep=" ", flat=True) exists = "IF NOT EXISTS " if expression.args.get("exists") else "" - kind = f" {kind}" if kind else "" + kind = f"{sep}{kind}" if kind else "" constraints = f" {constraints}" if constraints else "" position = self.sql(expression, "position") position = f" {position}" if position else "" @@ -524,7 +548,7 @@ class Generator: def columnconstraint_sql(self, expression: exp.ColumnConstraint) -> str: this = self.sql(expression, "this") - kind_sql = self.sql(expression, "kind") + kind_sql = self.sql(expression, "kind").strip() return f"CONSTRAINT {this} {kind_sql}" if this else kind_sql def autoincrementcolumnconstraint_sql(self, _) -> str: @@ -716,13 +740,22 @@ class Generator: return f"{alias}{columns}" def bitstring_sql(self, expression: exp.BitString) -> str: - return self.sql(expression, "this") + this = self.sql(expression, "this") + if self.bit_start: + return f"{self.bit_start}{this}{self.bit_end}" + return f"{int(this, 2)}" def hexstring_sql(self, expression: exp.HexString) -> str: - return self.sql(expression, "this") + this = self.sql(expression, "this") + if self.hex_start: + return f"{self.hex_start}{this}{self.hex_end}" + return f"{int(this, 16)}" def bytestring_sql(self, expression: exp.ByteString) -> str: - return self.sql(expression, "this") + this = self.sql(expression, "this") + if self.byte_start: + return f"{self.byte_start}{this}{self.byte_end}" + return this def datatype_sql(self, expression: exp.DataType) -> str: type_value = expression.this @@ -1115,10 +1148,12 @@ class Generator: return f"{table}{system_time}{alias}{hints}{laterals}{joins}{pivots}" - def tablesample_sql(self, expression: exp.TableSample, seed_prefix: str = "SEED") -> str: + def tablesample_sql( + self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS " + ) -> str: if self.alias_post_tablesample and expression.this.alias: this = self.sql(expression.this, "this") - alias = f" AS {self.sql(expression.this, 'alias')}" + alias = f"{sep}{self.sql(expression.this, 'alias')}" else: this = self.sql(expression, "this") alias = "" @@ -1447,16 +1482,16 @@ class Generator: ) def select_sql(self, expression: exp.Select) -> str: - kind = expression.args.get("kind") - kind = f" AS {kind}" if kind else "" hint = self.sql(expression, "hint") distinct = self.sql(expression, "distinct") distinct = f" {distinct}" if distinct else "" + kind = expression.args.get("kind") + kind = f" AS {kind}" if kind else "" expressions = self.expressions(expression) expressions = f"{self.sep()}{expressions}" if expressions else expressions sql = self.query_modifiers( expression, - f"SELECT{kind}{hint}{distinct}{expressions}", + f"SELECT{hint}{distinct}{kind}{expressions}", self.sql(expression, "into", comment=False), self.sql(expression, "from", comment=False), ) @@ -1475,9 +1510,6 @@ class Generator: replace = f"{self.seg(self.STAR_MAPPING['replace'])} ({replace})" if replace else "" return f"*{except_}{replace}" - def structkwarg_sql(self, expression: exp.StructKwarg) -> str: - return f"{self.sql(expression, 'this')} {self.sql(expression, 'expression')}" - def parameter_sql(self, expression: exp.Parameter) -> str: this = self.sql(expression, "this") this = f"{{{this}}}" if expression.args.get("wrapped") else f"{this}" @@ -1806,7 +1838,7 @@ class Generator: return self.binary(expression, op) sqls = tuple( - self.maybe_comment(self.sql(e), e, e.parent.comments) if i != 1 else self.sql(e) + self.maybe_comment(self.sql(e), e, e.parent.comments or []) if i != 1 else self.sql(e) for i, e in enumerate(expression.flatten(unnest=False)) ) |