diff options
Diffstat (limited to 'sqlglot/dialects/dialect.py')
-rw-r--r-- | sqlglot/dialects/dialect.py | 23 |
1 files changed, 18 insertions, 5 deletions
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 05e81ce..1d0584c 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -346,7 +346,9 @@ def inline_array_sql(self: Generator, expression: exp.Array) -> str: def no_ilike_sql(self: Generator, expression: exp.ILike) -> str: return self.like_sql( - exp.Like(this=exp.Lower(this=expression.this), expression=expression.expression) + exp.Like( + this=exp.Lower(this=expression.this.copy()), expression=expression.expression.copy() + ) ) @@ -410,7 +412,7 @@ def str_position_sql(self: Generator, expression: exp.StrPosition) -> str: def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str: this = self.sql(expression, "this") - struct_key = self.sql(exp.Identifier(this=expression.expression, quoted=True)) + struct_key = self.sql(exp.Identifier(this=expression.expression.copy(), quoted=True)) return f"{this}.{struct_key}" @@ -571,6 +573,17 @@ def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str: return self.sql(exp.cast(expression.this, "date")) +# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8 +def encode_decode_sql( + self: Generator, expression: exp.Expression, name: str, replace: bool = True +) -> str: + charset = expression.args.get("charset") + if charset and charset.name.lower() != "utf-8": + self.unsupported(f"Expected utf-8 character set, got {charset}.") + + return self.func(name, expression.this, expression.args.get("replace") if replace else None) + + def min_or_least(self: Generator, expression: exp.Min) -> str: name = "LEAST" if expression.expressions else "MIN" return rename_func(name)(self, expression) @@ -588,7 +601,7 @@ def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: cond = expression.this.expressions[0] self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") - return self.func("sum", exp.func("if", cond, 1, 0)) + return self.func("sum", exp.func("if", cond.copy(), 1, 0)) def trim_sql(self: Generator, expression: exp.Trim) -> str: @@ -625,6 +638,7 @@ def ts_or_ds_to_date_sql(dialect: str) -> t.Callable: def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat) -> str: + expression = expression.copy() this, *rest_args = expression.expressions for arg in rest_args: this = exp.DPipe(this=this, expression=arg) @@ -674,11 +688,10 @@ def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectTyp return names -def simplify_literal(expression: E, copy: bool = True) -> E: +def simplify_literal(expression: E) -> E: if not isinstance(expression.expression, exp.Literal): from sqlglot.optimizer.simplify import simplify - expression = exp.maybe_copy(expression, copy) simplify(expression.expression) return expression |