Diffstat (limited to 'sqlglot/dialects/spark2.py')
-rw-r--r--    sqlglot/dialects/spark2.py    61
1 file changed, 30 insertions(+), 31 deletions(-)
diff --git a/sqlglot/dialects/spark2.py b/sqlglot/dialects/spark2.py
index f909e8c..dcaa524 100644
--- a/sqlglot/dialects/spark2.py
+++ b/sqlglot/dialects/spark2.py
@@ -2,9 +2,11 @@ from __future__ import annotations
 import typing as t
-from sqlglot import exp, parser, transforms
+from sqlglot import exp, transforms
 from sqlglot.dialects.dialect import (
+    binary_from_function,
     create_with_partitions_sql,
+    format_time_lambda,
     pivot_column_names,
     rename_func,
     trim_sql,
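
Note (not part of the patch): the two helpers imported above replace hand-written lambdas further down in this diff. Judging from the SHIFTLEFT/SHIFTRIGHT lambdas they replace, binary_from_function behaves roughly like the sketch below; the body is inferred, not copied from sqlglot.

from sqlglot import exp
from sqlglot.helper import seq_get

# Inferred shape of the helper, based on the lambdas it replaces (assumption):
# build a two-argument expression node from a parsed argument list.
def binary_from_function(expr_type):
    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))

shift = binary_from_function(exp.BitwiseLeftShift)([exp.column("x"), exp.Literal.number(2)])
print(shift.sql())  # roughly: x << 2
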
@@ -108,47 +110,36 @@ class Spark2(Hive):
     class Parser(Hive.Parser):
         FUNCTIONS = {
             **Hive.Parser.FUNCTIONS,
-            "MAP_FROM_ARRAYS": exp.Map.from_arg_list,
-            "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
-            "SHIFTLEFT": lambda args: exp.BitwiseLeftShift(
-                this=seq_get(args, 0),
-                expression=seq_get(args, 1),
-            ),
-            "SHIFTRIGHT": lambda args: exp.BitwiseRightShift(
-                this=seq_get(args, 0),
-                expression=seq_get(args, 1),
-            ),
-            "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
-            "IIF": exp.If.from_arg_list,
             "AGGREGATE": exp.Reduce.from_arg_list,
-            "DAYOFWEEK": lambda args: exp.DayOfWeek(
-                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
-            ),
-            "DAYOFMONTH": lambda args: exp.DayOfMonth(
-                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
-            ),
-            "DAYOFYEAR": lambda args: exp.DayOfYear(
-                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
-            ),
-            "WEEKOFYEAR": lambda args: exp.WeekOfYear(
-                this=exp.TsOrDsToDate(this=seq_get(args, 0)),
-            ),
-            "DATE_TRUNC": lambda args: exp.TimestampTrunc(
-                this=seq_get(args, 1),
-                unit=exp.var(seq_get(args, 0)),
-            ),
-            "TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
+            "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
             "BOOLEAN": _parse_as_cast("boolean"),
             "DATE": _parse_as_cast("date"),
+            "DATE_TRUNC": lambda args: exp.TimestampTrunc(
+                this=seq_get(args, 1), unit=exp.var(seq_get(args, 0))
+            ),
+            "DAYOFMONTH": lambda args: exp.DayOfMonth(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
+            "DAYOFWEEK": lambda args: exp.DayOfWeek(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
+            "DAYOFYEAR": lambda args: exp.DayOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
             "DOUBLE": _parse_as_cast("double"),
             "FLOAT": _parse_as_cast("float"),
+            "IIF": exp.If.from_arg_list,
             "INT": _parse_as_cast("int"),
+            "MAP_FROM_ARRAYS": exp.Map.from_arg_list,
+            "RLIKE": exp.RegexpLike.from_arg_list,
+            "SHIFTLEFT": binary_from_function(exp.BitwiseLeftShift),
+            "SHIFTRIGHT": binary_from_function(exp.BitwiseRightShift),
             "STRING": _parse_as_cast("string"),
             "TIMESTAMP": _parse_as_cast("timestamp"),
+            "TO_TIMESTAMP": lambda args: _parse_as_cast("timestamp")(args)
+            if len(args) == 1
+            else format_time_lambda(exp.StrToTime, "spark")(args),
+            "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
+            "TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
+            "WEEKOFYEAR": lambda args: exp.WeekOfYear(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
         }
         FUNCTION_PARSERS = {
-            **parser.Parser.FUNCTION_PARSERS,
+            **Hive.Parser.FUNCTION_PARSERS,
             "BROADCAST": lambda self: self._parse_join_hint("BROADCAST"),
             "BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"),
             "MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"),
@@ -207,6 +198,13 @@ class Spark2(Hive):
             exp.Map: _map_sql,
             exp.Pivot: transforms.preprocess([_unqualify_pivot_columns]),
             exp.Reduce: rename_func("AGGREGATE"),
+            exp.RegexpReplace: lambda self, e: self.func(
+                "REGEXP_REPLACE",
+                e.this,
+                e.expression,
+                e.args["replacement"],
+                e.args.get("position"),
+            ),
             exp.StrToDate: _str_to_date,
             exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
             exp.TimestampTrunc: lambda self, e: self.func(
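
A sketch (not part of the patch) of what the new generator entries are expected to emit; the exact output strings are assumptions:

import sqlglot

# exp.RegexpReplace is rendered via self.func, so absent optional args such as position are dropped.
print(sqlglot.transpile("SELECT REGEXP_REPLACE(col, 'a+', 'b')", read="spark2", write="spark2")[0])
# expected roughly: SELECT REGEXP_REPLACE(col, 'a+', 'b')

# exp.StrToTime round-trips through the TO_TIMESTAMP(value, format) form shown above.
print(sqlglot.transpile("SELECT TO_TIMESTAMP(x, 'yyyy-MM-dd')", read="spark2", write="spark2")[0])
# expected roughly: SELECT TO_TIMESTAMP(x, 'yyyy-MM-dd')
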
@@ -224,6 +222,7 @@ class Spark2(Hive):
         TRANSFORMS.pop(exp.ArraySort)
         TRANSFORMS.pop(exp.ILike)
         TRANSFORMS.pop(exp.Left)
+        TRANSFORMS.pop(exp.MonthsBetween)
         TRANSFORMS.pop(exp.Right)
         WRAP_DERIVED_VALUES = False
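
Dropping the inherited Hive transform for exp.MonthsBetween means spark2 falls back to the default rendering for that expression; a sketch of the intended effect (not part of the patch, output is an assumption):

import sqlglot

# With the Hive-specific rewrite removed, MONTHS_BETWEEN should pass through unchanged.
print(sqlglot.transpile("SELECT MONTHS_BETWEEN(d1, d2)", read="spark2", write="spark2")[0])
# expected roughly: SELECT MONTHS_BETWEEN(d1, d2)
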