diff options
Diffstat (limited to 'sqlglot/generator.py')
-rw-r--r-- | sqlglot/generator.py | 80 |
1 files changed, 72 insertions, 8 deletions
diff --git a/sqlglot/generator.py b/sqlglot/generator.py index a6f4772..6871dd8 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -110,6 +110,10 @@ class Generator: # Whether or not MERGE ... WHEN MATCHED BY SOURCE is allowed MATCHED_BY_SOURCE = True + # Whether or not limit and fetch are supported + # "ALL", "LIMIT", "FETCH" + LIMIT_FETCH = "ALL" + TYPE_MAPPING = { exp.DataType.Type.NCHAR: "CHAR", exp.DataType.Type.NVARCHAR: "VARCHAR", @@ -209,6 +213,7 @@ class Generator: "_leading_comma", "_max_text_width", "_comments", + "_cache", ) def __init__( @@ -265,19 +270,28 @@ class Generator: self._leading_comma = leading_comma self._max_text_width = max_text_width self._comments = comments + self._cache = None - def generate(self, expression: t.Optional[exp.Expression]) -> str: + def generate( + self, + expression: t.Optional[exp.Expression], + cache: t.Optional[t.Dict[int, str]] = None, + ) -> str: """ Generates a SQL string by interpreting the given syntax tree. Args expression: the syntax tree. + cache: an optional sql string cache. this leverages the hash of an expression which is slow, so only use this if you set _hash on each node. Returns the SQL string. """ + if cache is not None: + self._cache = cache self.unsupported_messages = [] sql = self.sql(expression).strip() + self._cache = None if self.unsupported_level == ErrorLevel.IGNORE: return sql @@ -387,6 +401,12 @@ class Generator: if key: return self.sql(expression.args.get(key)) + if self._cache is not None: + expression_id = hash(expression) + + if expression_id in self._cache: + return self._cache[expression_id] + transform = self.TRANSFORMS.get(expression.__class__) if callable(transform): @@ -407,7 +427,11 @@ class Generator: else: raise ValueError(f"Expected an Expression. Received {type(expression)}: {expression}") - return self.maybe_comment(sql, expression) if self._comments and comment else sql + sql = self.maybe_comment(sql, expression) if self._comments and comment else sql + + if self._cache is not None: + self._cache[expression_id] = sql + return sql def uncache_sql(self, expression: exp.Uncache) -> str: table = self.sql(expression, "this") @@ -697,7 +721,8 @@ class Generator: temporary = " TEMPORARY" if expression.args.get("temporary") else "" materialized = " MATERIALIZED" if expression.args.get("materialized") else "" cascade = " CASCADE" if expression.args.get("cascade") else "" - return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{cascade}" + constraints = " CONSTRAINTS" if expression.args.get("constraints") else "" + return f"DROP{temporary}{materialized} {kind}{exists_sql}{this}{cascade}{constraints}" def except_sql(self, expression: exp.Except) -> str: return self.prepend_ctes( @@ -733,9 +758,9 @@ class Generator: def identifier_sql(self, expression: exp.Identifier) -> str: text = expression.name - text = text.lower() if self.normalize else text + text = text.lower() if self.normalize and not expression.quoted else text text = text.replace(self.identifier_end, self._escaped_identifier_end) - if expression.args.get("quoted") or should_identify(text, self.identify): + if expression.quoted or should_identify(text, self.identify): text = f"{self.identifier_start}{text}{self.identifier_end}" return text @@ -1191,6 +1216,9 @@ class Generator: ) return f"SET{expressions}" + def pragma_sql(self, expression: exp.Pragma) -> str: + return f"PRAGMA {self.sql(expression, 'this')}" + def lock_sql(self, expression: exp.Lock) -> str: if self.LOCKING_READS_SUPPORTED: lock_type = "UPDATE" if expression.args["update"] else "SHARE" @@ -1299,6 +1327,15 @@ class Generator: return f"{self.seg('MATCH_RECOGNIZE')} {self.wrap(body)}" def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str: + limit = expression.args.get("limit") + + if self.LIMIT_FETCH == "LIMIT" and isinstance(limit, exp.Fetch): + limit = exp.Limit(expression=limit.args.get("count")) + elif self.LIMIT_FETCH == "FETCH" and isinstance(limit, exp.Limit): + limit = exp.Fetch(direction="FIRST", count=limit.expression) + + fetch = isinstance(limit, exp.Fetch) + return csv( *sqls, *[self.sql(sql) for sql in expression.args.get("joins") or []], @@ -1315,14 +1352,16 @@ class Generator: self.sql(expression, "sort"), self.sql(expression, "cluster"), self.sql(expression, "order"), - self.sql(expression, "limit"), - self.sql(expression, "offset"), + self.sql(expression, "offset") if fetch else self.sql(limit), + self.sql(limit) if fetch else self.sql(expression, "offset"), self.sql(expression, "lock"), self.sql(expression, "sample"), sep="", ) def select_sql(self, expression: exp.Select) -> str: + kind = expression.args.get("kind") + kind = f" AS {kind}" if kind else "" hint = self.sql(expression, "hint") distinct = self.sql(expression, "distinct") distinct = f" {distinct}" if distinct else "" @@ -1330,7 +1369,7 @@ class Generator: expressions = f"{self.sep()}{expressions}" if expressions else expressions sql = self.query_modifiers( expression, - f"SELECT{hint}{distinct}{expressions}", + f"SELECT{kind}{hint}{distinct}{expressions}", self.sql(expression, "into", comment=False), self.sql(expression, "from", comment=False), ) @@ -1552,6 +1591,25 @@ class Generator: exp.Case(ifs=[expression.copy()], default=expression.args.get("false")) ) + def jsonkeyvalue_sql(self, expression: exp.JSONKeyValue) -> str: + return f"{self.sql(expression, 'this')}: {self.sql(expression, 'expression')}" + + def jsonobject_sql(self, expression: exp.JSONObject) -> str: + expressions = self.expressions(expression) + null_handling = expression.args.get("null_handling") + null_handling = f" {null_handling}" if null_handling else "" + unique_keys = expression.args.get("unique_keys") + if unique_keys is not None: + unique_keys = f" {'WITH' if unique_keys else 'WITHOUT'} UNIQUE KEYS" + else: + unique_keys = "" + return_type = self.sql(expression, "return_type") + return_type = f" RETURNING {return_type}" if return_type else "" + format_json = " FORMAT JSON" if expression.args.get("format_json") else "" + encoding = self.sql(expression, "encoding") + encoding = f" ENCODING {encoding}" if encoding else "" + return f"JSON_OBJECT({expressions}{null_handling}{unique_keys}{return_type}{format_json}{encoding})" + def in_sql(self, expression: exp.In) -> str: query = expression.args.get("query") unnest = expression.args.get("unnest") @@ -1808,12 +1866,18 @@ class Generator: def ilike_sql(self, expression: exp.ILike) -> str: return self.binary(expression, "ILIKE") + def ilikeany_sql(self, expression: exp.ILikeAny) -> str: + return self.binary(expression, "ILIKE ANY") + def is_sql(self, expression: exp.Is) -> str: return self.binary(expression, "IS") def like_sql(self, expression: exp.Like) -> str: return self.binary(expression, "LIKE") + def likeany_sql(self, expression: exp.LikeAny) -> str: + return self.binary(expression, "LIKE ANY") + def similarto_sql(self, expression: exp.SimilarTo) -> str: return self.binary(expression, "SIMILAR TO") |