summaryrefslogtreecommitdiffstats
path: root/sqlglot/generator.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/generator.py')
-rw-r--r--sqlglot/generator.py80
1 files changed, 72 insertions, 8 deletions
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index a6f4772..6871dd8 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -110,6 +110,10 @@ class Generator:
# Whether or not MERGE ... WHEN MATCHED BY SOURCE is allowed
MATCHED_BY_SOURCE = True
+ # Whether or not limit and fetch are supported
+ # "ALL", "LIMIT", "FETCH"
+ LIMIT_FETCH = "ALL"
+
TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
@@ -209,6 +213,7 @@ class Generator:
"_leading_comma",
"_max_text_width",
"_comments",
+ "_cache",
)
def __init__(
@@ -265,19 +270,28 @@ class Generator:
self._leading_comma = leading_comma
self._max_text_width = max_text_width
self._comments = comments
+ self._cache = None
- def generate(self, expression: t.Optional[exp.Expression]) -> str:
+ def generate(
+ self,
+ expression: t.Optional[exp.Expression],
+ cache: t.Optional[t.Dict[int, str]] = None,
+ ) -> str:
"""
Generates a SQL string by interpreting the given syntax tree.
Args
expression: the syntax tree.
+ cache: an optional sql string cache. this leverages the hash of an expression which is slow, so only use this if you set _hash on each node.
Returns
the SQL string.
"""
+ if cache is not None:
+ self._cache = cache
self.unsupported_messages = []
sql = self.sql(expression).strip()
+ self._cache = None
if self.unsupported_level == ErrorLevel.IGNORE:
return sql
@@ -387,6 +401,12 @@ class Generator:
if key:
return self.sql(expression.args.get(key))
+ if self._cache is not None:
+ expression_id = hash(expression)
+
+ if expression_id in self._cache:
+ return self._cache[expression_id]
+
transform = self.TRANSFORMS.get(expression.__class__)
if callable(transform):
@@ -407,7 +427,11 @@ class Generator:
else:
raise ValueError(f"Expected an Expression. Received {type(expression)}: {expression}")
- return self.maybe_comment(sql, expression) if self._comments and comment else sql
+ sql = self.maybe_comment(sql, expression) if self._comments and comment else sql
+
+ if self._cache is not None:
+ self._cache[expression_id] = sql
+ return sql
def uncache_sql(self, expression: exp.Uncache) -> str:
table = self.sql(expression, "this")
@@ -697,7 +721,8 @@ class Generator:
temporary = " TEMPORARY" if expression.args.get("temporary") else ""
materialized = " MATERIALIZED" if expression.args.get("materialized") else ""
cascade = " CASCADE" if expression.args.get("cascade") else ""
- return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{cascade}"
+ constraints = " CONSTRAINTS" if expression.args.get("constraints") else ""
+ return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{cascade}{constraints}"
def except_sql(self, expression: exp.Except) -> str:
return self.prepend_ctes(
@@ -733,9 +758,9 @@ class Generator:
def identifier_sql(self, expression: exp.Identifier) -> str:
text = expression.name
- text = text.lower() if self.normalize else text
+ text = text.lower() if self.normalize and not expression.quoted else text
text = text.replace(self.identifier_end, self._escaped_identifier_end)
- if expression.args.get("quoted") or should_identify(text, self.identify):
+ if expression.quoted or should_identify(text, self.identify):
text = f"{self.identifier_start}{text}{self.identifier_end}"
return text
@@ -1191,6 +1216,9 @@ class Generator:
)
return f"SET{expressions}"
+ def pragma_sql(self, expression: exp.Pragma) -> str:
+ return f"PRAGMA {self.sql(expression, 'this')}"
+
def lock_sql(self, expression: exp.Lock) -> str:
if self.LOCKING_READS_SUPPORTED:
lock_type = "UPDATE" if expression.args["update"] else "SHARE"
@@ -1299,6 +1327,15 @@ class Generator:
return f"{self.seg('MATCH_RECOGNIZE')} {self.wrap(body)}"
def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str:
+ limit = expression.args.get("limit")
+
+ if self.LIMIT_FETCH == "LIMIT" and isinstance(limit, exp.Fetch):
+ limit = exp.Limit(expression=limit.args.get("count"))
+ elif self.LIMIT_FETCH == "FETCH" and isinstance(limit, exp.Limit):
+ limit = exp.Fetch(direction="FIRST", count=limit.expression)
+
+ fetch = isinstance(limit, exp.Fetch)
+
return csv(
*sqls,
*[self.sql(sql) for sql in expression.args.get("joins") or []],
@@ -1315,14 +1352,16 @@ class Generator:
self.sql(expression, "sort"),
self.sql(expression, "cluster"),
self.sql(expression, "order"),
- self.sql(expression, "limit"),
- self.sql(expression, "offset"),
+ self.sql(expression, "offset") if fetch else self.sql(limit),
+ self.sql(limit) if fetch else self.sql(expression, "offset"),
self.sql(expression, "lock"),
self.sql(expression, "sample"),
sep="",
)
def select_sql(self, expression: exp.Select) -> str:
+ kind = expression.args.get("kind")
+ kind = f" AS {kind}" if kind else ""
hint = self.sql(expression, "hint")
distinct = self.sql(expression, "distinct")
distinct = f" {distinct}" if distinct else ""
@@ -1330,7 +1369,7 @@ class Generator:
expressions = f"{self.sep()}{expressions}" if expressions else expressions
sql = self.query_modifiers(
expression,
- f"SELECT{hint}{distinct}{expressions}",
+ f"SELECT{kind}{hint}{distinct}{expressions}",
self.sql(expression, "into", comment=False),
self.sql(expression, "from", comment=False),
)
@@ -1552,6 +1591,25 @@ class Generator:
exp.Case(ifs=[expression.copy()], default=expression.args.get("false"))
)
+ def jsonkeyvalue_sql(self, expression: exp.JSONKeyValue) -> str:
+ return f"{self.sql(expression, 'this')}: {self.sql(expression, 'expression')}"
+
+ def jsonobject_sql(self, expression: exp.JSONObject) -> str:
+ expressions = self.expressions(expression)
+ null_handling = expression.args.get("null_handling")
+ null_handling = f" {null_handling}" if null_handling else ""
+ unique_keys = expression.args.get("unique_keys")
+ if unique_keys is not None:
+ unique_keys = f" {'WITH' if unique_keys else 'WITHOUT'} UNIQUE KEYS"
+ else:
+ unique_keys = ""
+ return_type = self.sql(expression, "return_type")
+ return_type = f" RETURNING {return_type}" if return_type else ""
+ format_json = " FORMAT JSON" if expression.args.get("format_json") else ""
+ encoding = self.sql(expression, "encoding")
+ encoding = f" ENCODING {encoding}" if encoding else ""
+ return f"JSON_OBJECT({expressions}{null_handling}{unique_keys}{return_type}{format_json}{encoding})"
+
def in_sql(self, expression: exp.In) -> str:
query = expression.args.get("query")
unnest = expression.args.get("unnest")
@@ -1808,12 +1866,18 @@ class Generator:
def ilike_sql(self, expression: exp.ILike) -> str:
return self.binary(expression, "ILIKE")
+ def ilikeany_sql(self, expression: exp.ILikeAny) -> str:
+ return self.binary(expression, "ILIKE ANY")
+
def is_sql(self, expression: exp.Is) -> str:
return self.binary(expression, "IS")
def like_sql(self, expression: exp.Like) -> str:
return self.binary(expression, "LIKE")
+ def likeany_sql(self, expression: exp.LikeAny) -> str:
+ return self.binary(expression, "LIKE ANY")
+
def similarto_sql(self, expression: exp.SimilarTo) -> str:
return self.binary(expression, "SIMILAR TO")