Diffstat (limited to 'sqlglot/dialects/spark.py')
-rw-r--r--  sqlglot/dialects/spark.py | 17 ++++++++++++-----
1 file changed, 12 insertions(+), 5 deletions(-)
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index 03ec211..dd3e0c8 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -86,6 +86,11 @@ class Spark(Hive):
"WEEKOFYEAR": lambda args: exp.WeekOfYear(
this=exp.TsOrDsToDate(this=seq_get(args, 0)),
),
+ "DATE_TRUNC": lambda args: exp.TimestampTrunc(
+ this=seq_get(args, 1),
+ unit=exp.var(seq_get(args, 0)),
+ ),
+ "TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
}

FUNCTION_PARSERS = {
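The two new FUNCTIONS entries normalize Spark's argument orders into sqlglot's canonical nodes: DATE_TRUNC(fmt, ts) carries the unit first, while TRUNC(ts, fmt) carries it last. A minimal sketch of the expected parse results, using only sqlglot's public parse_one API (the identifiers ts and dt are illustrative):

import sqlglot
from sqlglot import exp

# DATE_TRUNC: Spark passes the unit first and the value second.
node = sqlglot.parse_one("DATE_TRUNC('month', ts)", read="spark")
assert isinstance(node, exp.TimestampTrunc)

# TRUNC: the value comes first and the unit second.
node = sqlglot.parse_one("TRUNC(dt, 'year')", read="spark")
assert isinstance(node, exp.DateTrunc)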
@@ -133,7 +138,7 @@ class Spark(Hive):
exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
exp.BitwiseLeftShift: rename_func("SHIFTLEFT"),
exp.BitwiseRightShift: rename_func("SHIFTRIGHT"),
- exp.DateTrunc: rename_func("TRUNC"),
+ exp.DateTrunc: lambda self, e: self.func("TRUNC", e.this, e.args.get("unit")),
exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */",
exp.StrToDate: _str_to_date,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
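rename_func("TRUNC") emitted DateTrunc's arguments in whatever order the node stores them, here unit before value (see the parser entry above), which does not match Spark's TRUNC(date, fmt) signature; the explicit lambda pins the value-then-unit order. A round-trip sketch based only on the mappings in this diff:

import sqlglot

# With the parser entry above and this transform, TRUNC round-trips
# with its arguments in Spark's expected order.
print(sqlglot.transpile("TRUNC(dt, 'year')", read="spark", write="spark")[0])
# expected output: TRUNC(dt, 'year')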
@@ -142,7 +147,9 @@ class Spark(Hive):
exp.Map: _map_sql,
exp.Reduce: rename_func("AGGREGATE"),
exp.StructKwarg: lambda self, e: f"{self.sql(e, 'this')}: {self.sql(e, 'expression')}",
- exp.TimestampTrunc: lambda self, e: f"DATE_TRUNC({self.sql(e, 'unit')}, {self.sql(e, 'this')})",
+ exp.TimestampTrunc: lambda self, e: self.func(
+ "DATE_TRUNC", exp.Literal.string(e.text("unit")), e.this
+ ),
exp.Trim: trim_sql,
exp.VariancePop: rename_func("VAR_POP"),
exp.DateFromParts: rename_func("MAKE_DATE"),
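The TimestampTrunc transform re-quotes the unit as a string literal, since the parser stored it as a Var; together with the parser entry above, Spark's DATE_TRUNC should round-trip unchanged. A sketch:

import sqlglot

# The unit Var is rendered back as the quoted first argument.
print(sqlglot.transpile("DATE_TRUNC('month', ts)", read="spark", write="spark")[0])
# expected output: DATE_TRUNC('month', ts)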
@@ -157,16 +164,16 @@ class Spark(Hive):
TRANSFORMS.pop(exp.ILike)

WRAP_DERIVED_VALUES = False
- CREATE_FUNCTION_AS = False
+ CREATE_FUNCTION_RETURN_AS = False

def cast_sql(self, expression: exp.Cast) -> str:
if isinstance(expression.this, exp.Cast) and expression.this.is_type(
exp.DataType.Type.JSON
):
schema = f"'{self.sql(expression, 'to')}'"
- return f"FROM_JSON({self.format_args(self.sql(expression.this, 'this'), schema)})"
+ return self.func("FROM_JSON", expression.this.this, schema)
if expression.to.is_type(exp.DataType.Type.JSON):
- return f"TO_JSON({self.sql(expression, 'this')})"
+ return self.func("TO_JSON", expression.this)
return super(Spark.Generator, self).cast_sql(expression)
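Both JSON branches now go through self.func, which takes care of rendering and joining the arguments. A hedged sketch of the generated SQL (the identifier x is illustrative):

import sqlglot

# Casting to JSON maps onto Spark's TO_JSON.
print(sqlglot.transpile("CAST(x AS JSON)", write="spark")[0])
# expected output: TO_JSON(x)

# Casting a JSON value to a schema maps onto FROM_JSON, with the target
# type passed as a quoted schema string.
print(sqlglot.transpile("CAST(CAST(x AS JSON) AS STRUCT<a: INT>)", read="spark", write="spark")[0])
# expected output: FROM_JSON(x, 'STRUCT<a: INT>')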