summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/dialect.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dialects/dialect.py')
-rw-r--r--sqlglot/dialects/dialect.py34
1 files changed, 31 insertions, 3 deletions
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 25490cb..b267521 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -8,7 +8,7 @@ from sqlglot.generator import Generator
from sqlglot.helper import flatten, seq_get
from sqlglot.parser import Parser
from sqlglot.time import format_time
-from sqlglot.tokens import Tokenizer
+from sqlglot.tokens import Token, Tokenizer
from sqlglot.trie import new_trie
E = t.TypeVar("E", bound=exp.Expression)
@@ -160,12 +160,12 @@ class Dialect(metaclass=_Dialect):
return expression
def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
- return self.parser(**opts).parse(self.tokenizer.tokenize(sql), sql)
+ return self.parser(**opts).parse(self.tokenize(sql), sql)
def parse_into(
self, expression_type: exp.IntoType, sql: str, **opts
) -> t.List[t.Optional[exp.Expression]]:
- return self.parser(**opts).parse_into(expression_type, self.tokenizer.tokenize(sql), sql)
+ return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)
def generate(self, expression: t.Optional[exp.Expression], **opts) -> str:
return self.generator(**opts).generate(expression)
@@ -173,6 +173,9 @@ class Dialect(metaclass=_Dialect):
def transpile(self, sql: str, **opts) -> t.List[str]:
return [self.generate(expression, **opts) for expression in self.parse(sql)]
+ def tokenize(self, sql: str) -> t.List[Token]:
+ return self.tokenizer.tokenize(sql)
+
@property
def tokenizer(self) -> Tokenizer:
if not hasattr(self, "_tokenizer"):
@@ -385,6 +388,21 @@ def parse_date_delta(
return inner_func
+def date_trunc_to_time(args: t.Sequence) -> exp.DateTrunc | exp.TimestampTrunc:
+ unit = seq_get(args, 0)
+ this = seq_get(args, 1)
+
+ if isinstance(this, exp.Cast) and this.is_type(exp.DataType.Type.DATE):
+ return exp.DateTrunc(unit=unit, this=this)
+ return exp.TimestampTrunc(this=this, unit=unit)
+
+
+def timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
+ return self.func(
+ "DATE_TRUNC", exp.Literal.string(expression.text("unit") or "day"), expression.this
+ )
+
+
def locate_to_strposition(args: t.Sequence) -> exp.Expression:
return exp.StrPosition(
this=seq_get(args, 1),
@@ -412,6 +430,16 @@ def min_or_least(self: Generator, expression: exp.Min) -> str:
return rename_func(name)(self, expression)
+def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
+ cond = expression.this
+
+ if isinstance(expression.this, exp.Distinct):
+ cond = expression.this.expressions[0]
+ self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")
+
+ return self.func("sum", exp.func("if", cond, 1, 0))
+
+
def trim_sql(self: Generator, expression: exp.Trim) -> str:
target = self.sql(expression, "this")
trim_type = self.sql(expression, "position")