From 918abde014f9e5c75dfbe21110c379f7f70435c9 Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Sun, 12 Feb 2023 11:06:28 +0100 Subject: Merging upstream version 11.0.1. Signed-off-by: Daniel Baumann --- sqlglot/generator.py | 54 ++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 40 insertions(+), 14 deletions(-) (limited to 'sqlglot/generator.py') diff --git a/sqlglot/generator.py b/sqlglot/generator.py index b95e9bc..0d72fe3 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import re import typing as t from sqlglot import exp @@ -11,6 +12,8 @@ from sqlglot.tokens import TokenType logger = logging.getLogger("sqlglot") +BACKSLASH_RE = re.compile(r"\\(?!b|f|n|r|t|0)") + class Generator: """ @@ -28,7 +31,8 @@ class Generator: identify (bool): if set to True all identifiers will be delimited by the corresponding character. normalize (bool): if set to True all identifiers will lower cased - escape (str): specifies an escape character. Default: '. + string_escape (str): specifies a string escape character. Default: '. + identifier_escape (str): specifies an identifier escape character. Default: ". pad (int): determines padding in a formatted string. Default: 2. indent (int): determines the size of indentation in a formatted string. Default: 4. unnest_column_only (bool): if true unnest table aliases are considered only as column aliases @@ -85,6 +89,9 @@ class Generator: # Wrap derived values in parens, usually standard but spark doesn't support it WRAP_DERIVED_VALUES = True + # Whether or not create function uses an AS before the def. + CREATE_FUNCTION_AS = True + TYPE_MAPPING = { exp.DataType.Type.NCHAR: "CHAR", exp.DataType.Type.NVARCHAR: "VARCHAR", @@ -154,7 +161,8 @@ class Generator: "identifier_end", "identify", "normalize", - "escape", + "string_escape", + "identifier_escape", "pad", "index_offset", "unnest_column_only", @@ -167,6 +175,7 @@ class Generator: "_indent", "_replace_backslash", "_escaped_quote_end", + "_escaped_identifier_end", "_leading_comma", "_max_text_width", "_comments", @@ -183,7 +192,8 @@ class Generator: identifier_end=None, identify=False, normalize=False, - escape=None, + string_escape=None, + identifier_escape=None, pad=2, indent=2, index_offset=0, @@ -208,7 +218,8 @@ class Generator: self.identifier_end = identifier_end or '"' self.identify = identify self.normalize = normalize - self.escape = escape or "'" + self.string_escape = string_escape or "'" + self.identifier_escape = identifier_escape or '"' self.pad = pad self.index_offset = index_offset self.unnest_column_only = unnest_column_only @@ -219,8 +230,9 @@ class Generator: self.max_unsupported = max_unsupported self.null_ordering = null_ordering self._indent = indent - self._replace_backslash = self.escape == "\\" - self._escaped_quote_end = self.escape + self.quote_end + self._replace_backslash = self.string_escape == "\\" + self._escaped_quote_end = self.string_escape + self.quote_end + self._escaped_identifier_end = self.identifier_escape + self.identifier_end self._leading_comma = leading_comma self._max_text_width = max_text_width self._comments = comments @@ -441,6 +453,9 @@ class Generator: def generatedasidentitycolumnconstraint_sql( self, expression: exp.GeneratedAsIdentityColumnConstraint ) -> str: + this = "" + if expression.this is not None: + this = " ALWAYS " if expression.this else " BY DEFAULT " start = expression.args.get("start") start = f"START WITH {start}" if start else "" increment = expression.args.get("increment") @@ -449,9 +464,7 @@ class Generator: if start or increment: sequence_opts = f"{start} {increment}" sequence_opts = f" ({sequence_opts.strip()})" - return ( - f"GENERATED {'ALWAYS' if expression.this else 'BY DEFAULT'} AS IDENTITY{sequence_opts}" - ) + return f"GENERATED{this}AS IDENTITY{sequence_opts}" def notnullcolumnconstraint_sql(self, expression: exp.NotNullColumnConstraint) -> str: return f"{'' if expression.args.get('allow_null') else 'NOT '}NULL" @@ -496,7 +509,12 @@ class Generator: properties_sql = self.sql(properties_exp, "properties") begin = " BEGIN" if expression.args.get("begin") else "" expression_sql = self.sql(expression, "expression") - expression_sql = f" AS{begin}{self.sep()}{expression_sql}" if expression_sql else "" + if expression_sql: + expression_sql = f"{begin}{self.sep()}{expression_sql}" + + if self.CREATE_FUNCTION_AS or kind != "FUNCTION": + expression_sql = f" AS{expression_sql}" + temporary = " TEMPORARY" if expression.args.get("temporary") else "" transient = ( " TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else "" @@ -701,6 +719,7 @@ class Generator: def identifier_sql(self, expression: exp.Identifier) -> str: text = expression.name text = text.lower() if self.normalize else text + text = text.replace(self.identifier_end, self._escaped_identifier_end) if expression.args.get("quoted") or self.identify: text = f"{self.identifier_start}{text}{self.identifier_end}" return text @@ -1121,7 +1140,7 @@ class Generator: text = expression.this or "" if expression.is_string: if self._replace_backslash: - text = text.replace("\\", "\\\\") + text = BACKSLASH_RE.sub(r"\\\\", text) text = text.replace(self.quote_end, self._escaped_quote_end) if self.pretty: text = text.replace("\n", self.SENTINEL_LINE_BREAK) @@ -1486,9 +1505,16 @@ class Generator: return f"(SELECT {self.sql(unnest)})" def interval_sql(self, expression: exp.Interval) -> str: - this = self.sql(expression, "this") - this = f" {this}" if this else "" - unit = self.sql(expression, "unit") + this = expression.args.get("this") + if this: + this = ( + f" {this}" + if isinstance(this, exp.Literal) or isinstance(this, exp.Paren) + else f" ({this})" + ) + else: + this = "" + unit = expression.args.get("unit") unit = f" {unit}" if unit else "" return f"INTERVAL{this}{unit}" -- cgit v1.2.3