summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/postgres.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dialects/postgres.py')
-rw-r--r--sqlglot/dialects/postgres.py55
1 files changed, 36 insertions, 19 deletions
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 126261e..c78f8a3 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -6,10 +6,12 @@ from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
DATE_ADD_OR_SUB,
Dialect,
+ JSON_EXTRACT_TYPE,
any_value_to_max_sql,
bool_xor_sql,
datestrtodate_sql,
- format_time_lambda,
+ build_formatted_time,
+ filter_array_using_unnest,
json_extract_segments,
json_path_key_only_name,
max_or_greatest,
@@ -20,8 +22,8 @@ from sqlglot.dialects.dialect import (
no_paren_current_date_sql,
no_pivot_sql,
no_trycast_sql,
- parse_json_extract_path,
- parse_timestamp_trunc,
+ build_json_extract_path,
+ build_timestamp_trunc,
rename_func,
str_position_sql,
struct_extract_sql,
@@ -163,7 +165,7 @@ def _serial_to_generated(expression: exp.Expression) -> exp.Expression:
return expression
-def _generate_series(args: t.List) -> exp.Expression:
+def _build_generate_series(args: t.List) -> exp.GenerateSeries:
# The goal is to convert step values like '1 day' or INTERVAL '1 day' into INTERVAL '1' day
step = seq_get(args, 2)
@@ -179,14 +181,25 @@ def _generate_series(args: t.List) -> exp.Expression:
return exp.GenerateSeries.from_arg_list(args)
-def _to_timestamp(args: t.List) -> exp.Expression:
+def _build_to_timestamp(args: t.List) -> exp.UnixToTime | exp.StrToTime:
# TO_TIMESTAMP accepts either a single double argument or (text, text)
if len(args) == 1:
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TABLE
return exp.UnixToTime.from_arg_list(args)
# https://www.postgresql.org/docs/current/functions-formatting.html
- return format_time_lambda(exp.StrToTime, "postgres")(args)
+ return build_formatted_time(exp.StrToTime, "postgres")(args)
+
+
+def _json_extract_sql(
+ name: str, op: str
+) -> t.Callable[[Postgres.Generator, JSON_EXTRACT_TYPE], str]:
+ def _generate(self: Postgres.Generator, expression: JSON_EXTRACT_TYPE) -> str:
+ if expression.args.get("only_json_types"):
+ return json_extract_segments(name, quoted_index=False, op=op)(self, expression)
+ return json_extract_segments(name)(self, expression)
+
+ return _generate
class Postgres(Dialect):
@@ -292,19 +305,19 @@ class Postgres(Dialect):
**parser.Parser.PROPERTY_PARSERS,
"SET": lambda self: self.expression(exp.SetConfigProperty, this=self._parse_set()),
}
- PROPERTY_PARSERS.pop("INPUT", None)
+ PROPERTY_PARSERS.pop("INPUT")
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
- "DATE_TRUNC": parse_timestamp_trunc,
- "GENERATE_SERIES": _generate_series,
- "JSON_EXTRACT_PATH": parse_json_extract_path(exp.JSONExtract),
- "JSON_EXTRACT_PATH_TEXT": parse_json_extract_path(exp.JSONExtractScalar),
+ "DATE_TRUNC": build_timestamp_trunc,
+ "GENERATE_SERIES": _build_generate_series,
+ "JSON_EXTRACT_PATH": build_json_extract_path(exp.JSONExtract),
+ "JSON_EXTRACT_PATH_TEXT": build_json_extract_path(exp.JSONExtractScalar),
"MAKE_TIME": exp.TimeFromParts.from_arg_list,
"MAKE_TIMESTAMP": exp.TimestampFromParts.from_arg_list,
"NOW": exp.CurrentTimestamp.from_arg_list,
- "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
- "TO_TIMESTAMP": _to_timestamp,
+ "TO_CHAR": build_formatted_time(exp.TimeToStr, "postgres"),
+ "TO_TIMESTAMP": _build_to_timestamp,
"UNNEST": exp.Explode.from_arg_list,
}
@@ -338,6 +351,8 @@ class Postgres(Dialect):
TokenType.END: lambda self: self._parse_commit_or_rollback(),
}
+ JSON_ARROWS_REQUIRE_JSON_TYPE = True
+
def _parse_operator(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
while True:
if not self._match(TokenType.L_PAREN):
@@ -387,6 +402,7 @@ class Postgres(Dialect):
SUPPORTS_UNLOGGED_TABLES = True
LIKE_PROPERTY_INSIDE_SCHEMA = True
MULTI_ARG_DISTINCT = False
+ CAN_IMPLEMENT_ARRAY_ANY = True
SUPPORTED_JSON_PATH_PARTS = {
exp.JSONPathKey,
@@ -416,6 +432,8 @@ class Postgres(Dialect):
exp.ArrayContained: lambda self, e: self.binary(e, "<@"),
exp.ArrayContains: lambda self, e: self.binary(e, "@>"),
exp.ArrayOverlaps: lambda self, e: self.binary(e, "&&"),
+ exp.ArrayFilter: filter_array_using_unnest,
+ exp.ArraySize: lambda self, e: self.func("ARRAY_LENGTH", e.this, e.expression or "1"),
exp.BitwiseXor: lambda self, e: self.binary(e, "#"),
exp.ColumnDef: transforms.preprocess([_auto_increment_to_serial, _serial_to_generated]),
exp.CurrentDate: no_paren_current_date_sql,
@@ -428,8 +446,8 @@ class Postgres(Dialect):
exp.DateSub: _date_add_sql("-"),
exp.Explode: rename_func("UNNEST"),
exp.GroupConcat: _string_agg_sql,
- exp.JSONExtract: json_extract_segments("JSON_EXTRACT_PATH"),
- exp.JSONExtractScalar: json_extract_segments("JSON_EXTRACT_PATH_TEXT"),
+ exp.JSONExtract: _json_extract_sql("JSON_EXTRACT_PATH", "->"),
+ exp.JSONExtractScalar: _json_extract_sql("JSON_EXTRACT_PATH_TEXT", "->>"),
exp.JSONBExtract: lambda self, e: self.binary(e, "#>"),
exp.JSONBExtractScalar: lambda self, e: self.binary(e, "#>>"),
exp.JSONBContains: lambda self, e: self.binary(e, "?"),
@@ -462,21 +480,20 @@ class Postgres(Dialect):
]
),
exp.StrPosition: str_position_sql,
- exp.StrToTime: lambda self,
- e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
exp.StructExtract: struct_extract_sql,
exp.Substring: _substring_sql,
exp.TimeFromParts: rename_func("MAKE_TIME"),
exp.TimestampFromParts: rename_func("MAKE_TIMESTAMP"),
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToTime: timestrtotime_sql,
- exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.TimeToStr: lambda self, e: self.func("TO_CHAR", e.this, self.format_time(e)),
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.Trim: trim_sql,
exp.TryCast: no_trycast_sql,
exp.TsOrDsAdd: _date_add_sql("+"),
exp.TsOrDsDiff: _date_diff_sql,
- exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})",
+ exp.UnixToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this),
exp.VariancePop: rename_func("VAR_POP"),
exp.Variance: rename_func("VAR_SAMP"),
exp.Xor: bool_xor_sql,