From d1f00706bff58b863b0a1c5bf4adf39d36049d4c Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Fri, 11 Nov 2022 09:54:35 +0100 Subject: Merging upstream version 10.0.1. Signed-off-by: Daniel Baumann --- sqlglot/dialects/spark.py | 37 ++++++++++++++++++++----------------- 1 file changed, 20 insertions(+), 17 deletions(-) (limited to 'sqlglot/dialects/spark.py') diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py index 572f411..4e404b8 100644 --- a/sqlglot/dialects/spark.py +++ b/sqlglot/dialects/spark.py @@ -1,8 +1,9 @@ -from sqlglot import exp +from __future__ import annotations + +from sqlglot import exp, parser from sqlglot.dialects.dialect import create_with_partitions_sql, rename_func from sqlglot.dialects.hive import Hive -from sqlglot.helper import list_get -from sqlglot.parser import Parser +from sqlglot.helper import seq_get def _create_sql(self, e): @@ -46,36 +47,36 @@ def _unix_to_time(self, expression): class Spark(Hive): class Parser(Hive.Parser): FUNCTIONS = { - **Hive.Parser.FUNCTIONS, + **Hive.Parser.FUNCTIONS, # type: ignore "MAP_FROM_ARRAYS": exp.Map.from_arg_list, "TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list, "LEFT": lambda args: exp.Substring( - this=list_get(args, 0), + this=seq_get(args, 0), start=exp.Literal.number(1), - length=list_get(args, 1), + length=seq_get(args, 1), ), "SHIFTLEFT": lambda args: exp.BitwiseLeftShift( - this=list_get(args, 0), - expression=list_get(args, 1), + this=seq_get(args, 0), + expression=seq_get(args, 1), ), "SHIFTRIGHT": lambda args: exp.BitwiseRightShift( - this=list_get(args, 0), - expression=list_get(args, 1), + this=seq_get(args, 0), + expression=seq_get(args, 1), ), "RIGHT": lambda args: exp.Substring( - this=list_get(args, 0), + this=seq_get(args, 0), start=exp.Sub( - this=exp.Length(this=list_get(args, 0)), - expression=exp.Add(this=list_get(args, 1), expression=exp.Literal.number(1)), + this=exp.Length(this=seq_get(args, 0)), + expression=exp.Add(this=seq_get(args, 1), expression=exp.Literal.number(1)), ), - length=list_get(args, 1), + length=seq_get(args, 1), ), "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list, "IIF": exp.If.from_arg_list, } FUNCTION_PARSERS = { - **Parser.FUNCTION_PARSERS, + **parser.Parser.FUNCTION_PARSERS, "BROADCAST": lambda self: self._parse_join_hint("BROADCAST"), "BROADCASTJOIN": lambda self: self._parse_join_hint("BROADCASTJOIN"), "MAPJOIN": lambda self: self._parse_join_hint("MAPJOIN"), @@ -88,14 +89,14 @@ class Spark(Hive): class Generator(Hive.Generator): TYPE_MAPPING = { - **Hive.Generator.TYPE_MAPPING, + **Hive.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.TINYINT: "BYTE", exp.DataType.Type.SMALLINT: "SHORT", exp.DataType.Type.BIGINT: "LONG", } TRANSFORMS = { - **{k: v for k, v in Hive.Generator.TRANSFORMS.items() if k not in {exp.ArraySort, exp.ILike}}, + **Hive.Generator.TRANSFORMS, # type: ignore exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"), exp.FileFormatProperty: lambda self, e: f"USING {e.text('value').upper()}", exp.ArraySum: lambda self, e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)", @@ -114,6 +115,8 @@ class Spark(Hive): exp.VariancePop: rename_func("VAR_POP"), exp.DateFromParts: rename_func("MAKE_DATE"), } + TRANSFORMS.pop(exp.ArraySort) + TRANSFORMS.pop(exp.ILike) WRAP_DERIVED_VALUES = False -- cgit v1.2.3