summary | refs | log | tree | commit | diff | stats
path: root/sqlglot/dialects/hive.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dialects/hive.py')
-rw-r--r-- sqlglot/dialects/hive.py 30
1 files changed, 21 insertions, 9 deletions
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index 5762efb..f968f6a 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -17,6 +17,7 @@ from sqlglot.dialects.dialect import (
no_recursive_cte_sql,
no_safe_divide_sql,
no_trycast_sql,
+ regexp_extract_sql,
rename_func,
right_to_substring_sql,
strposition_to_locate_sql,
@@ -230,23 +231,24 @@ class Hive(Dialect):
**parser.Parser.FUNCTIONS,
"BASE64": exp.ToBase64.from_arg_list,
"COLLECT_LIST": exp.ArrayAgg.from_arg_list,
+ "COLLECT_SET": exp.SetAgg.from_arg_list,
"DATE_ADD": lambda args: exp.TsOrDsAdd(
this=seq_get(args, 0), expression=seq_get(args, 1), unit=exp.Literal.string("DAY")
),
- "DATEDIFF": lambda args: exp.DateDiff(
- this=exp.TsOrDsToDate(this=seq_get(args, 0)),
- expression=exp.TsOrDsToDate(this=seq_get(args, 1)),
+ "DATE_FORMAT": lambda args: format_time_lambda(exp.TimeToStr, "hive")(
+ [
+ exp.TimeStrToTime(this=seq_get(args, 0)),
+ seq_get(args, 1),
+ ]
),
"DATE_SUB": lambda args: exp.TsOrDsAdd(
this=seq_get(args, 0),
expression=exp.Mul(this=seq_get(args, 1), expression=exp.Literal.number(-1)),
unit=exp.Literal.string("DAY"),
),
- "DATE_FORMAT": lambda args: format_time_lambda(exp.TimeToStr, "hive")(
- [
- exp.TimeStrToTime(this=seq_get(args, 0)),
- seq_get(args, 1),
- ]
+ "DATEDIFF": lambda args: exp.DateDiff(
+ this=exp.TsOrDsToDate(this=seq_get(args, 0)),
+ expression=exp.TsOrDsToDate(this=seq_get(args, 1)),
),
"DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))),
"FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True),
@@ -256,7 +258,9 @@ class Hive(Dialect):
"MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)),
"PERCENTILE": exp.Quantile.from_arg_list,
"PERCENTILE_APPROX": exp.ApproxQuantile.from_arg_list,
- "COLLECT_SET": exp.SetAgg.from_arg_list,
+ "REGEXP_EXTRACT": lambda args: exp.RegexpExtract(
+ this=seq_get(args, 0), expression=seq_get(args, 1), group=seq_get(args, 2)
+ ),
"SIZE": exp.ArraySize.from_arg_list,
"SPLIT": exp.RegexpSplit.from_arg_list,
"TO_DATE": format_time_lambda(exp.TsOrDsToDate, "hive"),
@@ -363,6 +367,7 @@ class Hive(Dialect):
exp.Create: create_with_partitions_sql,
exp.Quantile: rename_func("PERCENTILE"),
exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"),
+ exp.RegexpExtract: regexp_extract_sql,
exp.RegexpLike: lambda self, e: self.binary(e, "RLIKE"),
exp.RegexpSplit: rename_func("SPLIT"),
exp.Right: right_to_substring_sql,
@@ -422,5 +427,12 @@ class Hive(Dialect):
expression = exp.DataType.build("text")
elif expression.this in exp.DataType.TEMPORAL_TYPES:
expression = exp.DataType.build(expression.this)
+ elif expression.is_type("float"):
+ size_expression = expression.find(exp.DataTypeSize)
+ if size_expression:
+ size = int(size_expression.name)
+ expression = (
+ exp.DataType.build("float") if size <= 32 else exp.DataType.build("double")
+ )
return super().datatype_sql(expression)