summaryrefslogtreecommitdiffstats
path: root/sqlglot/generator.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/generator.py')
-rw-r--r--sqlglot/generator.py54
1 files changed, 40 insertions, 14 deletions
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index b95e9bc..0d72fe3 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -1,6 +1,7 @@
from __future__ import annotations
import logging
+import re
import typing as t
from sqlglot import exp
@@ -11,6 +12,8 @@ from sqlglot.tokens import TokenType
logger = logging.getLogger("sqlglot")
+BACKSLASH_RE = re.compile(r"\\(?!b|f|n|r|t|0)")
+
class Generator:
"""
@@ -28,7 +31,8 @@ class Generator:
identify (bool): if set to True all identifiers will be delimited by the corresponding
character.
normalize (bool): if set to True all identifiers will lower cased
- escape (str): specifies an escape character. Default: '.
+ string_escape (str): specifies a string escape character. Default: '.
+ identifier_escape (str): specifies an identifier escape character. Default: ".
pad (int): determines padding in a formatted string. Default: 2.
indent (int): determines the size of indentation in a formatted string. Default: 4.
unnest_column_only (bool): if true unnest table aliases are considered only as column aliases
@@ -85,6 +89,9 @@ class Generator:
# Wrap derived values in parens, usually standard but spark doesn't support it
WRAP_DERIVED_VALUES = True
+ # Whether or not create function uses an AS before the def.
+ CREATE_FUNCTION_AS = True
+
TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
@@ -154,7 +161,8 @@ class Generator:
"identifier_end",
"identify",
"normalize",
- "escape",
+ "string_escape",
+ "identifier_escape",
"pad",
"index_offset",
"unnest_column_only",
@@ -167,6 +175,7 @@ class Generator:
"_indent",
"_replace_backslash",
"_escaped_quote_end",
+ "_escaped_identifier_end",
"_leading_comma",
"_max_text_width",
"_comments",
@@ -183,7 +192,8 @@ class Generator:
identifier_end=None,
identify=False,
normalize=False,
- escape=None,
+ string_escape=None,
+ identifier_escape=None,
pad=2,
indent=2,
index_offset=0,
@@ -208,7 +218,8 @@ class Generator:
self.identifier_end = identifier_end or '"'
self.identify = identify
self.normalize = normalize
- self.escape = escape or "'"
+ self.string_escape = string_escape or "'"
+ self.identifier_escape = identifier_escape or '"'
self.pad = pad
self.index_offset = index_offset
self.unnest_column_only = unnest_column_only
@@ -219,8 +230,9 @@ class Generator:
self.max_unsupported = max_unsupported
self.null_ordering = null_ordering
self._indent = indent
- self._replace_backslash = self.escape == "\\"
- self._escaped_quote_end = self.escape + self.quote_end
+ self._replace_backslash = self.string_escape == "\\"
+ self._escaped_quote_end = self.string_escape + self.quote_end
+ self._escaped_identifier_end = self.identifier_escape + self.identifier_end
self._leading_comma = leading_comma
self._max_text_width = max_text_width
self._comments = comments
@@ -441,6 +453,9 @@ class Generator:
def generatedasidentitycolumnconstraint_sql(
self, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
+ this = ""
+ if expression.this is not None:
+ this = " ALWAYS " if expression.this else " BY DEFAULT "
start = expression.args.get("start")
start = f"START WITH {start}" if start else ""
increment = expression.args.get("increment")
@@ -449,9 +464,7 @@ class Generator:
if start or increment:
sequence_opts = f"{start} {increment}"
sequence_opts = f" ({sequence_opts.strip()})"
- return (
- f"GENERATED {'ALWAYS' if expression.this else 'BY DEFAULT'} AS IDENTITY{sequence_opts}"
- )
+ return f"GENERATED{this}AS IDENTITY{sequence_opts}"
def notnullcolumnconstraint_sql(self, expression: exp.NotNullColumnConstraint) -> str:
return f"{'' if expression.args.get('allow_null') else 'NOT '}NULL"
@@ -496,7 +509,12 @@ class Generator:
properties_sql = self.sql(properties_exp, "properties")
begin = " BEGIN" if expression.args.get("begin") else ""
expression_sql = self.sql(expression, "expression")
- expression_sql = f" AS{begin}{self.sep()}{expression_sql}" if expression_sql else ""
+ if expression_sql:
+ expression_sql = f"{begin}{self.sep()}{expression_sql}"
+
+ if self.CREATE_FUNCTION_AS or kind != "FUNCTION":
+ expression_sql = f" AS{expression_sql}"
+
temporary = " TEMPORARY" if expression.args.get("temporary") else ""
transient = (
" TRANSIENT" if self.CREATE_TRANSIENT and expression.args.get("transient") else ""
@@ -701,6 +719,7 @@ class Generator:
def identifier_sql(self, expression: exp.Identifier) -> str:
text = expression.name
text = text.lower() if self.normalize else text
+ text = text.replace(self.identifier_end, self._escaped_identifier_end)
if expression.args.get("quoted") or self.identify:
text = f"{self.identifier_start}{text}{self.identifier_end}"
return text
@@ -1121,7 +1140,7 @@ class Generator:
text = expression.this or ""
if expression.is_string:
if self._replace_backslash:
- text = text.replace("\\", "\\\\")
+ text = BACKSLASH_RE.sub(r"\\\\", text)
text = text.replace(self.quote_end, self._escaped_quote_end)
if self.pretty:
text = text.replace("\n", self.SENTINEL_LINE_BREAK)
@@ -1486,9 +1505,16 @@ class Generator:
return f"(SELECT {self.sql(unnest)})"
def interval_sql(self, expression: exp.Interval) -> str:
- this = self.sql(expression, "this")
- this = f" {this}" if this else ""
- unit = self.sql(expression, "unit")
+ this = expression.args.get("this")
+ if this:
+ this = (
+ f" {this}"
+ if isinstance(this, exp.Literal) or isinstance(this, exp.Paren)
+ else f" ({this})"
+ )
+ else:
+ this = ""
+ unit = expression.args.get("unit")
unit = f" {unit}" if unit else ""
return f"INTERVAL{this}{unit}"