diff options
Diffstat (limited to 'sqlglot/dialects/dialect.py')
-rw-r--r-- | sqlglot/dialects/dialect.py | 34 |
1 files changed, 31 insertions, 3 deletions
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 25490cb..b267521 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -8,7 +8,7 @@ from sqlglot.generator import Generator from sqlglot.helper import flatten, seq_get from sqlglot.parser import Parser from sqlglot.time import format_time -from sqlglot.tokens import Tokenizer +from sqlglot.tokens import Token, Tokenizer from sqlglot.trie import new_trie E = t.TypeVar("E", bound=exp.Expression) @@ -160,12 +160,12 @@ class Dialect(metaclass=_Dialect): return expression def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: - return self.parser(**opts).parse(self.tokenizer.tokenize(sql), sql) + return self.parser(**opts).parse(self.tokenize(sql), sql) def parse_into( self, expression_type: exp.IntoType, sql: str, **opts ) -> t.List[t.Optional[exp.Expression]]: - return self.parser(**opts).parse_into(expression_type, self.tokenizer.tokenize(sql), sql) + return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) def generate(self, expression: t.Optional[exp.Expression], **opts) -> str: return self.generator(**opts).generate(expression) @@ -173,6 +173,9 @@ class Dialect(metaclass=_Dialect): def transpile(self, sql: str, **opts) -> t.List[str]: return [self.generate(expression, **opts) for expression in self.parse(sql)] + def tokenize(self, sql: str) -> t.List[Token]: + return self.tokenizer.tokenize(sql) + @property def tokenizer(self) -> Tokenizer: if not hasattr(self, "_tokenizer"): @@ -385,6 +388,21 @@ def parse_date_delta( return inner_func +def date_trunc_to_time(args: t.Sequence) -> exp.DateTrunc | exp.TimestampTrunc: + unit = seq_get(args, 0) + this = seq_get(args, 1) + + if isinstance(this, exp.Cast) and this.is_type(exp.DataType.Type.DATE): + return exp.DateTrunc(unit=unit, this=this) + return exp.TimestampTrunc(this=this, unit=unit) + + +def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: + return self.func( + "DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this + ) + + def locate_to_strposition(args: t.Sequence) -> exp.Expression: return exp.StrPosition( this=seq_get(args, 1), @@ -412,6 +430,16 @@ def min_or_least(self: Generator, expression: exp.Min) -> str: return rename_func(name)(self, expression) +def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: + cond = expression.this + + if isinstance(expression.this, exp.Distinct): + cond = expression.this.expressions[0] + self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") + + return self.func("sum", exp.func("if", cond, 1, 0)) + + def trim_sql(self: Generator, expression: exp.Trim) -> str: target = self.sql(expression, "this") trim_type = self.sql(expression, "position") |