diff options
Diffstat (limited to 'sqlglot/dialects/hive.py')
-rw-r--r-- | sqlglot/dialects/hive.py | 64 |
1 files changed, 30 insertions, 34 deletions
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index c39656e..6746fcf 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -45,16 +45,23 @@ TIME_DIFF_FACTOR = { DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH") -def _add_date_sql(self: generator.Generator, expression: exp.DateAdd) -> str: +def _add_date_sql(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str: unit = expression.text("unit").upper() func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1)) - modified_increment = ( - int(expression.text("expression")) * multiplier - if expression.expression.is_number - else expression.expression - ) - modified_increment = exp.Literal.number(modified_increment) - return self.func(func, expression.this, modified_increment.this) + + if isinstance(expression, exp.DateSub): + multiplier *= -1 + + if expression.expression.is_number: + modified_increment = exp.Literal.number(int(expression.text("expression")) * multiplier) + else: + modified_increment = expression.expression + if multiplier != 1: + modified_increment = exp.Mul( # type: ignore + this=modified_increment, expression=exp.Literal.number(multiplier) + ) + + return self.func(func, expression.this, modified_increment) def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str: @@ -127,24 +134,6 @@ def _to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str return f"TO_DATE({this})" -def _unnest_to_explode_sql(self: generator.Generator, expression: exp.Join) -> str: - unnest = expression.this - if isinstance(unnest, exp.Unnest): - alias = unnest.args.get("alias") - udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode - return "".join( - self.sql( - exp.Lateral( - this=udtf(this=expression), - view=True, - alias=exp.TableAlias(this=alias.this, columns=[column]), # type: ignore - ) - ) - for expression, column in zip(unnest.expressions, alias.columns if alias else []) - ) - return self.join_sql(expression) - - def _index_sql(self: generator.Generator, expression: exp.Index) -> str: this = self.sql(expression, "this") table = self.sql(expression, "table") @@ -195,6 +184,7 @@ class Hive(Dialect): IDENTIFIERS = ["`"] STRING_ESCAPES = ["\\"] ENCODE = "utf-8" + IDENTIFIER_CAN_START_WITH_DIGIT = True KEYWORDS = { **tokens.Tokenizer.KEYWORDS, @@ -217,9 +207,8 @@ class Hive(Dialect): "BD": "DECIMAL", } - IDENTIFIER_CAN_START_WITH_DIGIT = True - class Parser(parser.Parser): + LOG_DEFAULTS_TO_LN = True STRICT_CAST = False FUNCTIONS = { @@ -273,9 +262,13 @@ class Hive(Dialect): ), } - LOG_DEFAULTS_TO_LN = True - class Generator(generator.Generator): + LIMIT_FETCH = "LIMIT" + TABLESAMPLE_WITH_METHOD = False + TABLESAMPLE_SIZE_IS_PERCENT = True + JOIN_HINTS = False + TABLE_HINTS = False + TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.TEXT: "STRING", @@ -289,6 +282,9 @@ class Hive(Dialect): **generator.Generator.TRANSFORMS, # type: ignore **transforms.UNALIAS_GROUP, # type: ignore **transforms.ELIMINATE_QUALIFY, # type: ignore + exp.Select: transforms.preprocess( + [transforms.eliminate_qualify, transforms.unnest_to_explode] + ), exp.Property: _property_sql, exp.ApproxDistinct: approx_count_distinct_sql, exp.ArrayConcat: rename_func("CONCAT"), @@ -298,13 +294,13 @@ class Hive(Dialect): exp.DateAdd: _add_date_sql, exp.DateDiff: _date_diff_sql, exp.DateStrToDate: rename_func("TO_DATE"), + exp.DateSub: _add_date_sql, exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.dateint_format}) AS INT)", exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.dateint_format})", - exp.FileFormatProperty: lambda self, e: f"STORED AS {e.name.upper()}", + exp.FileFormatProperty: lambda self, e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}", exp.If: if_sql, exp.Index: _index_sql, exp.ILike: no_ilike_sql, - exp.Join: _unnest_to_explode_sql, exp.JSONExtract: rename_func("GET_JSON_OBJECT"), exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"), exp.JSONFormat: rename_func("TO_JSON"), @@ -354,10 +350,9 @@ class Hive(Dialect): exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA, exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA, exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA, + exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } - LIMIT_FETCH = "LIMIT" - def arrayagg_sql(self, expression: exp.ArrayAgg) -> str: return self.func( "COLLECT_LIST", @@ -378,4 +373,5 @@ class Hive(Dialect): expression = exp.DataType.build("text") elif expression.this in exp.DataType.TEMPORAL_TYPES: expression = exp.DataType.build(expression.this) + return super().datatype_sql(expression) |