summaryrefslogtreecommitdiffstats
path: root/sqlglot/generator.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/generator.py')
-rw-r--r--sqlglot/generator.py82
1 files changed, 59 insertions, 23 deletions
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index 3935133..6375d92 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -82,6 +82,8 @@ class Generator:
TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
+ exp.DataType.Type.MEDIUMTEXT: "TEXT",
+ exp.DataType.Type.LONGTEXT: "TEXT",
}
TOKEN_MAPPING: t.Dict[TokenType, str] = {}
@@ -105,6 +107,7 @@ class Generator:
}
WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.Binary)
+ SENTINEL_LINE_BREAK = "__SQLGLOT__LB__"
__slots__ = (
"time_mapping",
@@ -211,6 +214,8 @@ class Generator:
elif self.unsupported_level == ErrorLevel.RAISE and self.unsupported_messages:
raise UnsupportedError(concat_messages(self.unsupported_messages, self.max_unsupported))
+ if self.pretty:
+ sql = sql.replace(self.SENTINEL_LINE_BREAK, "\n")
return sql
def unsupported(self, message: str) -> None:
@@ -401,7 +406,17 @@ class Generator:
def generatedasidentitycolumnconstraint_sql(
self, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
- return f"GENERATED {'ALWAYS' if expression.this else 'BY DEFAULT'} AS IDENTITY"
+ start = expression.args.get("start")
+ start = f"START WITH {start}" if start else ""
+ increment = expression.args.get("increment")
+ increment = f"INCREMENT BY {increment}" if increment else ""
+ sequence_opts = ""
+ if start or increment:
+ sequence_opts = f"{start} {increment}"
+ sequence_opts = f" ({sequence_opts.strip()})"
+ return (
+ f"GENERATED {'ALWAYS' if expression.this else 'BY DEFAULT'} AS IDENTITY{sequence_opts}"
+ )
def notnullcolumnconstraint_sql(self, expression: exp.NotNullColumnConstraint) -> str:
return f"{'' if expression.args.get('allow_null') else 'NOT '}NULL"
@@ -475,10 +490,13 @@ class Generator:
materialized,
)
)
+ no_schema_binding = (
+ " WITH NO SCHEMA BINDING" if expression.args.get("no_schema_binding") else ""
+ )
post_expression_modifiers = "".join((data, statistics, no_primary_index))
- expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties}{expression_sql}{post_expression_modifiers}{index_sql}"
+ expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties}{expression_sql}{post_expression_modifiers}{index_sql}{no_schema_binding}"
return self.prepend_ctes(expression, expression_sql)
def describe_sql(self, expression: exp.Describe) -> str:
@@ -517,13 +535,19 @@ class Generator:
type_sql = self.TYPE_MAPPING.get(type_value, type_value.value)
nested = ""
interior = self.expressions(expression, flat=True)
+ values = ""
if interior:
- nested = (
- f"{self.STRUCT_DELIMITER[0]}{interior}{self.STRUCT_DELIMITER[1]}"
- if expression.args.get("nested")
- else f"({interior})"
- )
- return f"{type_sql}{nested}"
+ if expression.args.get("nested"):
+ nested = f"{self.STRUCT_DELIMITER[0]}{interior}{self.STRUCT_DELIMITER[1]}"
+ if expression.args.get("values") is not None:
+ delimiters = ("[", "]") if type_value == exp.DataType.Type.ARRAY else ("(", ")")
+ values = (
+ f"{delimiters[0]}{self.expressions(expression, 'values')}{delimiters[1]}"
+ )
+ else:
+ nested = f"({interior})"
+
+ return f"{type_sql}{nested}{values}"
def directory_sql(self, expression: exp.Directory) -> str:
local = "LOCAL " if expression.args.get("local") else ""
@@ -622,10 +646,14 @@ class Generator:
return self.sep() + self.expressions(properties, indent=False, sep=" ")
return ""
- def properties(self, properties: exp.Properties, prefix: str = "", sep: str = ", ") -> str:
+ def properties(
+ self, properties: exp.Properties, prefix: str = "", sep: str = ", ", suffix: str = ""
+ ) -> str:
if properties.expressions:
expressions = self.expressions(properties, sep=sep, indent=False)
- return f"{prefix}{' ' if prefix else ''}{self.wrap(expressions)}"
+ return (
+ f"{prefix}{' ' if prefix and prefix != ' ' else ''}{self.wrap(expressions)}{suffix}"
+ )
return ""
def with_properties(self, properties: exp.Properties) -> str:
@@ -763,14 +791,15 @@ class Generator:
return self.prepend_ctes(expression, sql)
def values_sql(self, expression: exp.Values) -> str:
- alias = self.sql(expression, "alias")
args = self.expressions(expression)
- if not alias:
- return f"VALUES{self.seg('')}{args}"
- alias = f" AS {alias}" if alias else alias
- if self.WRAP_DERIVED_VALUES:
- return f"(VALUES{self.seg('')}{args}){alias}"
- return f"VALUES{self.seg('')}{args}{alias}"
+ alias = self.sql(expression, "alias")
+ values = f"VALUES{self.seg('')}{args}"
+ values = (
+ f"({values})"
+ if self.WRAP_DERIVED_VALUES and (alias or isinstance(expression.parent, exp.From))
+ else values
+ )
+ return f"{values} AS {alias}" if alias else values
def var_sql(self, expression: exp.Var) -> str:
return self.sql(expression, "this")
@@ -868,6 +897,8 @@ class Generator:
if self._replace_backslash:
text = text.replace("\\", "\\\\")
text = text.replace(self.quote_end, self._escaped_quote_end)
+ if self.pretty:
+ text = text.replace("\n", self.SENTINEL_LINE_BREAK)
text = f"{self.quote_start}{text}{self.quote_end}"
return text
@@ -1036,7 +1067,9 @@ class Generator:
alias = self.sql(expression, "alias")
alias = f" AS {alias}" if alias else alias
ordinality = " WITH ORDINALITY" if expression.args.get("ordinality") else ""
- return f"UNNEST({args}){ordinality}{alias}"
+ offset = expression.args.get("offset")
+ offset = f" WITH OFFSET AS {self.sql(offset)}" if offset else ""
+ return f"UNNEST({args}){ordinality}{alias}{offset}"
def where_sql(self, expression: exp.Where) -> str:
this = self.indent(self.sql(expression, "this"))
@@ -1132,15 +1165,14 @@ class Generator:
return f"EXTRACT({this} FROM {expression_sql})"
def trim_sql(self, expression: exp.Trim) -> str:
- target = self.sql(expression, "this")
trim_type = self.sql(expression, "position")
if trim_type == "LEADING":
- return f"LTRIM({target})"
+ return f"{self.normalize_func('LTRIM')}({self.format_args(expression.this)})"
elif trim_type == "TRAILING":
- return f"RTRIM({target})"
+ return f"{self.normalize_func('RTRIM')}({self.format_args(expression.this)})"
else:
- return f"TRIM({target})"
+ return f"{self.normalize_func('TRIM')}({self.format_args(expression.this, expression.expression)})"
def concat_sql(self, expression: exp.Concat) -> str:
if len(expression.expressions) == 1:
@@ -1317,6 +1349,10 @@ class Generator:
return f"ALTER COLUMN {this} DROP DEFAULT"
+ def renametable_sql(self, expression: exp.RenameTable) -> str:
+ this = self.sql(expression, "this")
+ return f"RENAME TO {this}"
+
def altertable_sql(self, expression: exp.AlterTable) -> str:
actions = expression.args["actions"]
@@ -1326,7 +1362,7 @@ class Generator:
actions = self.expressions(expression, "actions", prefix="ADD COLUMNS ")
elif isinstance(actions[0], exp.Drop):
actions = self.expressions(expression, "actions")
- elif isinstance(actions[0], exp.AlterColumn):
+ elif isinstance(actions[0], (exp.AlterColumn, exp.RenameTable)):
actions = self.sql(actions[0])
else:
self.unsupported(f"Unsupported ALTER TABLE action {actions[0].__class__.__name__}")