summary | refs | log | tree | commit | diff | stats
path: root/sqlglot/dialects/hive.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dialects/hive.py')
-rw-r--r-- sqlglot/dialects/hive.py 64
1 files changed, 30 insertions, 34 deletions
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index c39656e..6746fcf 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -45,16 +45,23 @@ TIME_DIFF_FACTOR = {
DIFF_MONTH_SWITCH = ("YEAR", "QUARTER", "MONTH")
-def _add_date_sql(self: generator.Generator, expression: exp.DateAdd) -> str:
+def _add_date_sql(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
unit = expression.text("unit").upper()
func, multiplier = DATE_DELTA_INTERVAL.get(unit, ("DATE_ADD", 1))
- modified_increment = (
- int(expression.text("expression")) * multiplier
- if expression.expression.is_number
- else expression.expression
- )
- modified_increment = exp.Literal.number(modified_increment)
- return self.func(func, expression.this, modified_increment.this)
+
+ if isinstance(expression, exp.DateSub):
+ multiplier *= -1
+
+ if expression.expression.is_number:
+ modified_increment = exp.Literal.number(int(expression.text("expression")) * multiplier)
+ else:
+ modified_increment = expression.expression
+ if multiplier != 1:
+ modified_increment = exp.Mul( # type: ignore
+ this=modified_increment, expression=exp.Literal.number(multiplier)
+ )
+
+ return self.func(func, expression.this, modified_increment)
def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str:
@@ -127,24 +134,6 @@ def _to_date_sql(self: generator.Generator, expression: exp.TsOrDsToDate) -> str
return f"TO_DATE({this})"
-def _unnest_to_explode_sql(self: generator.Generator, expression: exp.Join) -> str:
- unnest = expression.this
- if isinstance(unnest, exp.Unnest):
- alias = unnest.args.get("alias")
- udtf = exp.Posexplode if unnest.args.get("ordinality") else exp.Explode
- return "".join(
- self.sql(
- exp.Lateral(
- this=udtf(this=expression),
- view=True,
- alias=exp.TableAlias(this=alias.this, columns=[column]), # type: ignore
- )
- )
- for expression, column in zip(unnest.expressions, alias.columns if alias else [])
- )
- return self.join_sql(expression)
-
-
def _index_sql(self: generator.Generator, expression: exp.Index) -> str:
this = self.sql(expression, "this")
table = self.sql(expression, "table")
@@ -195,6 +184,7 @@ class Hive(Dialect):
IDENTIFIERS = ["`"]
STRING_ESCAPES = ["\\"]
ENCODE = "utf-8"
+ IDENTIFIER_CAN_START_WITH_DIGIT = True
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
@@ -217,9 +207,8 @@ class Hive(Dialect):
"BD": "DECIMAL",
}
- IDENTIFIER_CAN_START_WITH_DIGIT = True
-
class Parser(parser.Parser):
+ LOG_DEFAULTS_TO_LN = True
STRICT_CAST = False
FUNCTIONS = {
@@ -273,9 +262,13 @@ class Hive(Dialect):
),
}
- LOG_DEFAULTS_TO_LN = True
-
class Generator(generator.Generator):
+ LIMIT_FETCH = "LIMIT"
+ TABLESAMPLE_WITH_METHOD = False
+ TABLESAMPLE_SIZE_IS_PERCENT = True
+ JOIN_HINTS = False
+ TABLE_HINTS = False
+
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TEXT: "STRING",
@@ -289,6 +282,9 @@ class Hive(Dialect):
**generator.Generator.TRANSFORMS, # type: ignore
**transforms.UNALIAS_GROUP, # type: ignore
**transforms.ELIMINATE_QUALIFY, # type: ignore
+ exp.Select: transforms.preprocess(
+ [transforms.eliminate_qualify, transforms.unnest_to_explode]
+ ),
exp.Property: _property_sql,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.ArrayConcat: rename_func("CONCAT"),
@@ -298,13 +294,13 @@ class Hive(Dialect):
exp.DateAdd: _add_date_sql,
exp.DateDiff: _date_diff_sql,
exp.DateStrToDate: rename_func("TO_DATE"),
+ exp.DateSub: _add_date_sql,
exp.DateToDi: lambda self, e: f"CAST(DATE_FORMAT({self.sql(e, 'this')}, {Hive.dateint_format}) AS INT)",
exp.DiToDate: lambda self, e: f"TO_DATE(CAST({self.sql(e, 'this')} AS STRING), {Hive.dateint_format})",
- exp.FileFormatProperty: lambda self, e: f"STORED AS {e.name.upper()}",
+ exp.FileFormatProperty: lambda self, e: f"STORED AS {self.sql(e, 'this') if isinstance(e.this, exp.InputOutputFormat) else e.name.upper()}",
exp.If: if_sql,
exp.Index: _index_sql,
exp.ILike: no_ilike_sql,
- exp.Join: _unnest_to_explode_sql,
exp.JSONExtract: rename_func("GET_JSON_OBJECT"),
exp.JSONExtractScalar: rename_func("GET_JSON_OBJECT"),
exp.JSONFormat: rename_func("TO_JSON"),
@@ -354,10 +350,9 @@ class Hive(Dialect):
exp.FileFormatProperty: exp.Properties.Location.POST_SCHEMA,
exp.PartitionedByProperty: exp.Properties.Location.POST_SCHEMA,
exp.TableFormatProperty: exp.Properties.Location.POST_SCHEMA,
+ exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
- LIMIT_FETCH = "LIMIT"
-
def arrayagg_sql(self, expression: exp.ArrayAgg) -> str:
return self.func(
"COLLECT_LIST",
@@ -378,4 +373,5 @@ class Hive(Dialect):
expression = exp.DataType.build("text")
elif expression.this in exp.DataType.TEMPORAL_TYPES:
expression = exp.DataType.build(expression.this)
+
return super().datatype_sql(expression)