summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/postgres.py
diff options
context:
space:
mode:
authorDaniel Baumann <mail@daniel-baumann.ch>2023-12-10 10:46:01 +0000
committerDaniel Baumann <mail@daniel-baumann.ch>2023-12-10 10:46:01 +0000
commit8fe30fd23dc37ec3516e530a86d1c4b604e71241 (patch)
tree6e2ebbf565b0351fd0f003f488a8339e771ad90c /sqlglot/dialects/postgres.py
parentReleasing debian version 19.0.1-1. (diff)
downloadsqlglot-8fe30fd23dc37ec3516e530a86d1c4b604e71241.tar.xz
sqlglot-8fe30fd23dc37ec3516e530a86d1c4b604e71241.zip
Merging upstream version 20.1.0.
Signed-off-by: Daniel Baumann <mail@daniel-baumann.ch>
Diffstat (limited to 'sqlglot/dialects/postgres.py')
-rw-r--r--sqlglot/dialects/postgres.py97
1 files changed, 64 insertions, 33 deletions
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 27c6851..fefddee 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -4,6 +4,7 @@ import typing as t
from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import (
+ DATE_ADD_OR_SUB,
Dialect,
any_value_to_max_sql,
arrow_json_extract_scalar_sql,
@@ -25,6 +26,7 @@ from sqlglot.dialects.dialect import (
timestamptrunc_sql,
timestrtotime_sql,
trim_sql,
+ ts_or_ds_add_cast,
ts_or_ds_to_date_sql,
)
from sqlglot.helper import seq_get
@@ -41,8 +43,11 @@ DATE_DIFF_FACTOR = {
}
-def _date_add_sql(kind: str) -> t.Callable[[Postgres.Generator, exp.DateAdd | exp.DateSub], str]:
- def func(self: Postgres.Generator, expression: exp.DateAdd | exp.DateSub) -> str:
+def _date_add_sql(kind: str) -> t.Callable[[Postgres.Generator, DATE_ADD_OR_SUB], str]:
+ def func(self: Postgres.Generator, expression: DATE_ADD_OR_SUB) -> str:
+ if isinstance(expression, exp.TsOrDsAdd):
+ expression = ts_or_ds_add_cast(expression)
+
this = self.sql(expression, "this")
unit = expression.args.get("unit")
@@ -60,8 +65,8 @@ def _date_diff_sql(self: Postgres.Generator, expression: exp.DateDiff) -> str:
unit = expression.text("unit").upper()
factor = DATE_DIFF_FACTOR.get(unit)
- end = f"CAST({expression.this} AS TIMESTAMP)"
- start = f"CAST({expression.expression} AS TIMESTAMP)"
+ end = f"CAST({self.sql(expression, 'this')} AS TIMESTAMP)"
+ start = f"CAST({self.sql(expression, 'expression')} AS TIMESTAMP)"
if factor is not None:
return f"CAST(EXTRACT(epoch FROM {end} - {start}){factor} AS BIGINT)"
@@ -69,7 +74,7 @@ def _date_diff_sql(self: Postgres.Generator, expression: exp.DateDiff) -> str:
age = f"AGE({end}, {start})"
if unit == "WEEK":
- unit = f"EXTRACT(year FROM {age}) * 48 + EXTRACT(month FROM {age}) * 4 + EXTRACT(day FROM {age}) / 7"
+ unit = f"EXTRACT(days FROM ({end} - {start})) / 7"
elif unit == "MONTH":
unit = f"EXTRACT(year FROM {age}) * 12 + EXTRACT(month FROM {age})"
elif unit == "QUARTER":
@@ -183,37 +188,43 @@ def _to_timestamp(args: t.List) -> exp.Expression:
return format_time_lambda(exp.StrToTime, "postgres")(args)
-def _remove_target_from_merge(expression: exp.Expression) -> exp.Expression:
- """Remove table refs from columns in when statements."""
- if isinstance(expression, exp.Merge):
- alias = expression.this.args.get("alias")
+def _merge_sql(self: Postgres.Generator, expression: exp.Merge) -> str:
+ def _remove_target_from_merge(expression: exp.Expression) -> exp.Expression:
+ """Remove table refs from columns in when statements."""
+ if isinstance(expression, exp.Merge):
+ alias = expression.this.args.get("alias")
- normalize = (
- lambda identifier: Postgres.normalize_identifier(identifier).name
- if identifier
- else None
- )
+ normalize = (
+ lambda identifier: self.dialect.normalize_identifier(identifier).name
+ if identifier
+ else None
+ )
- targets = {normalize(expression.this.this)}
+ targets = {normalize(expression.this.this)}
- if alias:
- targets.add(normalize(alias.this))
+ if alias:
+ targets.add(normalize(alias.this))
- for when in expression.expressions:
- when.transform(
- lambda node: exp.column(node.this)
- if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
- else node,
- copy=False,
- )
+ for when in expression.expressions:
+ when.transform(
+ lambda node: exp.column(node.this)
+ if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
+ else node,
+ copy=False,
+ )
- return expression
+ return expression
+
+ return transforms.preprocess([_remove_target_from_merge])(self, expression)
class Postgres(Dialect):
INDEX_OFFSET = 1
+ TYPED_DIVISION = True
+ CONCAT_COALESCE = True
NULL_ORDERING = "nulls_are_large"
TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'"
+
TIME_MAPPING = {
"AM": "%p",
"PM": "%p",
@@ -263,6 +274,7 @@ class Postgres(Dialect):
"BEGIN TRANSACTION": TokenType.BEGIN,
"BIGSERIAL": TokenType.BIGSERIAL,
"CHARACTER VARYING": TokenType.VARCHAR,
+ "CONSTRAINT TRIGGER": TokenType.COMMAND,
"DECLARE": TokenType.COMMAND,
"DO": TokenType.COMMAND,
"HSTORE": TokenType.HSTORE,
@@ -277,6 +289,7 @@ class Postgres(Dialect):
"TEMP": TokenType.TEMPORARY,
"CSTRING": TokenType.PSEUDO_TYPE,
"OID": TokenType.OBJECT_IDENTIFIER,
+ "OPERATOR": TokenType.OPERATOR,
"REGCLASS": TokenType.OBJECT_IDENTIFIER,
"REGCOLLATION": TokenType.OBJECT_IDENTIFIER,
"REGCONFIG": TokenType.OBJECT_IDENTIFIER,
@@ -298,8 +311,6 @@ class Postgres(Dialect):
VAR_SINGLE_TOKENS = {"$"}
class Parser(parser.Parser):
- CONCAT_NULL_OUTPUTS_STRING = True
-
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"DATE_TRUNC": parse_timestamp_trunc,
@@ -326,12 +337,13 @@ class Postgres(Dialect):
RANGE_PARSERS = {
**parser.Parser.RANGE_PARSERS,
+ TokenType.AT_GT: binary_range_parser(exp.ArrayContains),
TokenType.DAMP: binary_range_parser(exp.ArrayOverlaps),
TokenType.DAT: lambda self, this: self.expression(
exp.MatchAgainst, this=self._parse_bitwise(), expressions=[this]
),
- TokenType.AT_GT: binary_range_parser(exp.ArrayContains),
TokenType.LT_AT: binary_range_parser(exp.ArrayContained),
+ TokenType.OPERATOR: lambda self, this: self._parse_operator(this),
}
STATEMENT_PARSERS = {
@@ -339,11 +351,28 @@ class Postgres(Dialect):
TokenType.END: lambda self: self._parse_commit_or_rollback(),
}
- def _parse_factor(self) -> t.Optional[exp.Expression]:
- return self._parse_tokens(self._parse_exponent, self.FACTOR)
+ def _parse_operator(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
+ while True:
+ if not self._match(TokenType.L_PAREN):
+ break
+
+ op = ""
+ while self._curr and not self._match(TokenType.R_PAREN):
+ op += self._curr.text
+ self._advance()
+
+ this = self.expression(
+ exp.Operator,
+ comments=self._prev_comments,
+ this=this,
+ operator=op,
+ expression=self._parse_bitwise(),
+ )
+
+ if not self._match(TokenType.OPERATOR):
+ break
- def _parse_exponent(self) -> t.Optional[exp.Expression]:
- return self._parse_tokens(self._parse_unary, self.EXPONENT)
+ return this
def _parse_date_part(self) -> exp.Expression:
part = self._parse_type()
@@ -405,7 +434,7 @@ class Postgres(Dialect):
exp.Max: max_or_greatest,
exp.MapFromEntries: no_map_from_entries_sql,
exp.Min: min_or_least,
- exp.Merge: transforms.preprocess([_remove_target_from_merge]),
+ exp.Merge: _merge_sql,
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.PercentileCont: transforms.preprocess(
[transforms.add_within_group_for_percentiles]
@@ -434,6 +463,8 @@ class Postgres(Dialect):
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.Trim: trim_sql,
exp.TryCast: no_trycast_sql,
+ exp.TsOrDsAdd: _date_add_sql("+"),
+ exp.TsOrDsDiff: _date_diff_sql,
exp.TsOrDsToDate: ts_or_ds_to_date_sql("postgres"),
exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})",
exp.VariancePop: rename_func("VAR_POP"),