From d1f00706bff58b863b0a1c5bf4adf39d36049d4c Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Fri, 11 Nov 2022 09:54:35 +0100 Subject: Merging upstream version 10.0.1. Signed-off-by: Daniel Baumann --- sqlglot/dialects/hive.py | 57 +++++++++++++++++++++++++++--------------------- 1 file changed, 32 insertions(+), 25 deletions(-) (limited to 'sqlglot/dialects/hive.py') diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 03049ff..ed7357c 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -1,4 +1,6 @@ -from sqlglot import exp, transforms +from __future__ import annotations + +from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( Dialect, approx_count_distinct_sql, @@ -13,10 +15,8 @@ from sqlglot.dialects.dialect import ( struct_extract_sql, var_map_sql, ) -from sqlglot.generator import Generator -from sqlglot.helper import list_get -from sqlglot.parser import Parser, parse_var_map -from sqlglot.tokens import Tokenizer +from sqlglot.helper import seq_get +from sqlglot.parser import parse_var_map # (FuncType, Multiplier) DATE_DELTA_INTERVAL = { @@ -34,7 +34,9 @@ def _add_date_sql(self, expression): unit = expression.text("unit").upper() func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1)) modified_increment = ( - int(expression.text("expression")) * multiplier if expression.expression.is_number else expression.expression + int(expression.text("expression")) * multiplier + if expression.expression.is_number + else expression.expression ) modified_increment = exp.Literal.number(modified_increment) return f"{func}({self.format_args(expression.this, modified_increment.this)})" @@ -165,10 +167,10 @@ class Hive(Dialect): dateint_format = "'yyyyMMdd'" time_format = "'yyyy-MM-dd HH:mm:ss'" - class Tokenizer(Tokenizer): + class Tokenizer(tokens.Tokenizer): QUOTES = ["'", '"'] IDENTIFIERS = ["`"] - ESCAPE = "\\" + ESCAPES = ["\\"] ENCODE = "utf-8" NUMERIC_LITERALS = { @@ -180,40 +182,44 @@ class Hive(Dialect): "BD": "DECIMAL", } - class Parser(Parser): + class Parser(parser.Parser): STRICT_CAST = False FUNCTIONS = { - **Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, "APPROX_COUNT_DISTINCT": exp.ApproxDistinct.from_arg_list, "COLLECT_LIST": exp.ArrayAgg.from_arg_list, "DATE_ADD": lambda args: exp.TsOrDsAdd( - this=list_get(args, 0), - expression=list_get(args, 1), + this=seq_get(args, 0), + expression=seq_get(args, 1), unit=exp.Literal.string("DAY"), ), "DATEDIFF": lambda args: exp.DateDiff( - this=exp.TsOrDsToDate(this=list_get(args, 0)), - expression=exp.TsOrDsToDate(this=list_get(args, 1)), + this=exp.TsOrDsToDate(this=seq_get(args, 0)), + expression=exp.TsOrDsToDate(this=seq_get(args, 1)), ), "DATE_SUB": lambda args: exp.TsOrDsAdd( - this=list_get(args, 0), + this=seq_get(args, 0), expression=exp.Mul( - this=list_get(args, 1), + this=seq_get(args, 1), expression=exp.Literal.number(-1), ), unit=exp.Literal.string("DAY"), ), "DATE_FORMAT": format_time_lambda(exp.TimeToStr, "hive"), - "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=list_get(args, 0))), + "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))), "FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True), "GET_JSON_OBJECT": exp.JSONExtractScalar.from_arg_list, "LOCATE": lambda args: exp.StrPosition( - this=list_get(args, 1), - substr=list_get(args, 0), - position=list_get(args, 2), + this=seq_get(args, 1), + substr=seq_get(args, 0), + position=seq_get(args, 2), + ), + "LOG": ( + lambda args: exp.Log.from_arg_list(args) + if len(args) > 1 + else exp.Ln.from_arg_list(args) ), - "LOG": (lambda args: exp.Log.from_arg_list(args) if len(args) > 1 else exp.Ln.from_arg_list(args)), "MAP": parse_var_map, "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)), "PERCENTILE": exp.Quantile.from_arg_list, @@ -226,15 +232,16 @@ class Hive(Dialect): "YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)), } - class Generator(Generator): + class Generator(generator.Generator): TYPE_MAPPING = { - **Generator.TYPE_MAPPING, + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.TEXT: "STRING", + exp.DataType.Type.VARBINARY: "BINARY", } TRANSFORMS = { - **Generator.TRANSFORMS, - **transforms.UNALIAS_GROUP, + **generator.Generator.TRANSFORMS, + **transforms.UNALIAS_GROUP, # type: ignore exp.AnonymousProperty: _property_sql, exp.ApproxDistinct: approx_count_distinct_sql, exp.ArrayAgg: rename_func("COLLECT_LIST"), -- cgit v1.2.3