diff options
author | Daniel Baumann <mail@daniel-baumann.ch> | 2023-12-10 10:46:01 +0000 |
---|---|---|
committer | Daniel Baumann <mail@daniel-baumann.ch> | 2023-12-10 10:46:01 +0000 |
commit | 8fe30fd23dc37ec3516e530a86d1c4b604e71241 (patch) | |
tree | 6e2ebbf565b0351fd0f003f488a8339e771ad90c /sqlglot/dialects/postgres.py | |
parent | Releasing debian version 19.0.1-1. (diff) | |
download | sqlglot-8fe30fd23dc37ec3516e530a86d1c4b604e71241.tar.xz sqlglot-8fe30fd23dc37ec3516e530a86d1c4b604e71241.zip |
Merging upstream version 20.1.0.
Signed-off-by: Daniel Baumann <mail@daniel-baumann.ch>
Diffstat (limited to 'sqlglot/dialects/postgres.py')
-rw-r--r-- | sqlglot/dialects/postgres.py | 97 |
1 files changed, 64 insertions, 33 deletions
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 27c6851..fefddee 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -4,6 +4,7 @@ import typing as t from sqlglot import exp, generator, parser, tokens, transforms from sqlglot.dialects.dialect import ( + DATE_ADD_OR_SUB, Dialect, any_value_to_max_sql, arrow_json_extract_scalar_sql, @@ -25,6 +26,7 @@ from sqlglot.dialects.dialect import ( timestamptrunc_sql, timestrtotime_sql, trim_sql, + ts_or_ds_add_cast, ts_or_ds_to_date_sql, ) from sqlglot.helper import seq_get @@ -41,8 +43,11 @@ DATE_DIFF_FACTOR = { } -def _date_add_sql(kind: str) -> t.Callable[[Postgres.Generator, exp.DateAdd | exp.DateSub], str]: - def func(self: Postgres.Generator, expression: exp.DateAdd | exp.DateSub) -> str: +def _date_add_sql(kind: str) -> t.Callable[[Postgres.Generator, DATE_ADD_OR_SUB], str]: + def func(self: Postgres.Generator, expression: DATE_ADD_OR_SUB) -> str: + if isinstance(expression, exp.TsOrDsAdd): + expression = ts_or_ds_add_cast(expression) + this = self.sql(expression, "this") unit = expression.args.get("unit") @@ -60,8 +65,8 @@ def _date_diff_sql(self: Postgres.Generator, expression: exp.DateDiff) -> str: unit = expression.text("unit").upper() factor = DATE_DIFF_FACTOR.get(unit) - end = f"CAST({expression.this} AS TIMESTAMP)" - start = f"CAST({expression.expression} AS TIMESTAMP)" + end = f"CAST({self.sql(expression, 'this')} AS TIMESTAMP)" + start = f"CAST({self.sql(expression, 'expression')} AS TIMESTAMP)" if factor is not None: return f"CAST(EXTRACT(epoch FROM {end} - {start}){factor} AS BIGINT)" @@ -69,7 +74,7 @@ def _date_diff_sql(self: Postgres.Generator, expression: exp.DateDiff) -> str: age = f"AGE({end}, {start})" if unit == "WEEK": - unit = f"EXTRACT(year FROM {age}) * 48 + EXTRACT(month FROM {age}) * 4 + EXTRACT(day FROM {age}) / 7" + unit = f"EXTRACT(days FROM ({end} - {start})) / 7" elif unit == "MONTH": unit = f"EXTRACT(year FROM {age}) * 12 + EXTRACT(month FROM {age})" elif unit == "QUARTER": @@ -183,37 +188,43 @@ def _to_timestamp(args: t.List) -> exp.Expression: return format_time_lambda(exp.StrToTime, "postgres")(args) -def _remove_target_from_merge(expression: exp.Expression) -> exp.Expression: - """Remove table refs from columns in when statements.""" - if isinstance(expression, exp.Merge): - alias = expression.this.args.get("alias") +def _merge_sql(self: Postgres.Generator, expression: exp.Merge) -> str: + def _remove_target_from_merge(expression: exp.Expression) -> exp.Expression: + """Remove table refs from columns in when statements.""" + if isinstance(expression, exp.Merge): + alias = expression.this.args.get("alias") - normalize = ( - lambda identifier: Postgres.normalize_identifier(identifier).name - if identifier - else None - ) + normalize = ( + lambda identifier: self.dialect.normalize_identifier(identifier).name + if identifier + else None + ) - targets = {normalize(expression.this.this)} + targets = {normalize(expression.this.this)} - if alias: - targets.add(normalize(alias.this)) + if alias: + targets.add(normalize(alias.this)) - for when in expression.expressions: - when.transform( - lambda node: exp.column(node.this) - if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets - else node, - copy=False, - ) + for when in expression.expressions: + when.transform( + lambda node: exp.column(node.this) + if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets + else node, + copy=False, + ) - return expression + return expression + + return transforms.preprocess([_remove_target_from_merge])(self, expression) class Postgres(Dialect): INDEX_OFFSET = 1 + TYPED_DIVISION = True + CONCAT_COALESCE = True NULL_ORDERING = "nulls_are_large" TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'" + TIME_MAPPING = { "AM": "%p", "PM": "%p", @@ -263,6 +274,7 @@ class Postgres(Dialect): "BEGIN TRANSACTION": TokenType.BEGIN, "BIGSERIAL": TokenType.BIGSERIAL, "CHARACTER VARYING": TokenType.VARCHAR, + "CONSTRAINT TRIGGER": TokenType.COMMAND, "DECLARE": TokenType.COMMAND, "DO": TokenType.COMMAND, "HSTORE": TokenType.HSTORE, @@ -277,6 +289,7 @@ class Postgres(Dialect): "TEMP": TokenType.TEMPORARY, "CSTRING": TokenType.PSEUDO_TYPE, "OID": TokenType.OBJECT_IDENTIFIER, + "OPERATOR": TokenType.OPERATOR, "REGCLASS": TokenType.OBJECT_IDENTIFIER, "REGCOLLATION": TokenType.OBJECT_IDENTIFIER, "REGCONFIG": TokenType.OBJECT_IDENTIFIER, @@ -298,8 +311,6 @@ class Postgres(Dialect): VAR_SINGLE_TOKENS = {"$"} class Parser(parser.Parser): - CONCAT_NULL_OUTPUTS_STRING = True - FUNCTIONS = { **parser.Parser.FUNCTIONS, "DATE_TRUNC": parse_timestamp_trunc, @@ -326,12 +337,13 @@ class Postgres(Dialect): RANGE_PARSERS = { **parser.Parser.RANGE_PARSERS, + TokenType.AT_GT: binary_range_parser(exp.ArrayContains), TokenType.DAMP: binary_range_parser(exp.ArrayOverlaps), TokenType.DAT: lambda self, this: self.expression( exp.MatchAgainst, this=self._parse_bitwise(), expressions=[this] ), - TokenType.AT_GT: binary_range_parser(exp.ArrayContains), TokenType.LT_AT: binary_range_parser(exp.ArrayContained), + TokenType.OPERATOR: lambda self, this: self._parse_operator(this), } STATEMENT_PARSERS = { @@ -339,11 +351,28 @@ class Postgres(Dialect): TokenType.END: lambda self: self._parse_commit_or_rollback(), } - def _parse_factor(self) -> t.Optional[exp.Expression]: - return self._parse_tokens(self._parse_exponent, self.FACTOR) + def _parse_operator(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: + while True: + if not self._match(TokenType.L_PAREN): + break + + op = "" + while self._curr and not self._match(TokenType.R_PAREN): + op += self._curr.text + self._advance() + + this = self.expression( + exp.Operator, + comments=self._prev_comments, + this=this, + operator=op, + expression=self._parse_bitwise(), + ) + + if not self._match(TokenType.OPERATOR): + break - def _parse_exponent(self) -> t.Optional[exp.Expression]: - return self._parse_tokens(self._parse_unary, self.EXPONENT) + return this def _parse_date_part(self) -> exp.Expression: part = self._parse_type() @@ -405,7 +434,7 @@ class Postgres(Dialect): exp.Max: max_or_greatest, exp.MapFromEntries: no_map_from_entries_sql, exp.Min: min_or_least, - exp.Merge: transforms.preprocess([_remove_target_from_merge]), + exp.Merge: _merge_sql, exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}", exp.PercentileCont: transforms.preprocess( [transforms.add_within_group_for_percentiles] @@ -434,6 +463,8 @@ class Postgres(Dialect): exp.ToChar: lambda self, e: self.function_fallback_sql(e), exp.Trim: trim_sql, exp.TryCast: no_trycast_sql, + exp.TsOrDsAdd: _date_add_sql("+"), + exp.TsOrDsDiff: _date_diff_sql, exp.TsOrDsToDate: ts_or_ds_to_date_sql("postgres"), exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})", exp.VariancePop: rename_func("VAR_POP"), |