diff options
Diffstat (limited to 'sqlglot/dialects/postgres.py')
-rw-r--r-- | sqlglot/dialects/postgres.py | 45 |
1 files changed, 22 insertions, 23 deletions
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 2132778..ab61880 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -12,6 +12,7 @@ from sqlglot.dialects.dialect import ( max_or_greatest, min_or_least, no_paren_current_date_sql, + no_pivot_sql, no_tablesample_sql, no_trycast_sql, rename_func, @@ -33,8 +34,8 @@ DATE_DIFF_FACTOR = { } -def _date_add_sql(kind): - def func(self, expression): +def _date_add_sql(kind: str) -> t.Callable[[generator.Generator, exp.DateAdd | exp.DateSub], str]: + def func(self: generator.Generator, expression: exp.DateAdd | exp.DateSub) -> str: from sqlglot.optimizer.simplify import simplify this = self.sql(expression, "this") @@ -51,7 +52,7 @@ def _date_add_sql(kind): return func -def _date_diff_sql(self, expression): +def _date_diff_sql(self: generator.Generator, expression: exp.DateDiff) -> str: unit = expression.text("unit").upper() factor = DATE_DIFF_FACTOR.get(unit) @@ -77,7 +78,7 @@ def _date_diff_sql(self, expression): return f"CAST({unit} AS BIGINT)" -def _substring_sql(self, expression): +def _substring_sql(self: generator.Generator, expression: exp.Substring) -> str: this = self.sql(expression, "this") start = self.sql(expression, "start") length = self.sql(expression, "length") @@ -88,7 +89,7 @@ def _substring_sql(self, expression): return f"SUBSTRING({this}{from_part}{for_part})" -def _string_agg_sql(self, expression): +def _string_agg_sql(self: generator.Generator, expression: exp.GroupConcat) -> str: expression = expression.copy() separator = expression.args.get("separator") or exp.Literal.string(",") @@ -102,13 +103,13 @@ def _string_agg_sql(self, expression): return f"STRING_AGG({self.format_args(this, separator)}{order})" -def _datatype_sql(self, expression): +def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str: if expression.this == exp.DataType.Type.ARRAY: return f"{self.expressions(expression, flat=True)}[]" return self.datatype_sql(expression) -def _auto_increment_to_serial(expression): +def _auto_increment_to_serial(expression: exp.Expression) -> exp.Expression: auto = expression.find(exp.AutoIncrementColumnConstraint) if auto: @@ -126,7 +127,7 @@ def _auto_increment_to_serial(expression): return expression -def _serial_to_generated(expression): +def _serial_to_generated(expression: exp.Expression) -> exp.Expression: kind = expression.args["kind"] if kind.this == exp.DataType.Type.SERIAL: @@ -144,6 +145,7 @@ def _serial_to_generated(expression): constraints = expression.args["constraints"] generated = exp.ColumnConstraint(kind=exp.GeneratedAsIdentityColumnConstraint(this=False)) notnull = exp.ColumnConstraint(kind=exp.NotNullColumnConstraint()) + if notnull not in constraints: constraints.insert(0, notnull) if generated not in constraints: @@ -152,7 +154,7 @@ def _serial_to_generated(expression): return expression -def _generate_series(args): +def _generate_series(args: t.List) -> exp.Expression: # The goal is to convert step values like '1 day' or INTERVAL '1 day' into INTERVAL '1' day step = seq_get(args, 2) @@ -168,11 +170,12 @@ def _generate_series(args): return exp.GenerateSeries.from_arg_list(args) -def _to_timestamp(args): +def _to_timestamp(args: t.List) -> exp.Expression: # TO_TIMESTAMP accepts either a single double argument or (text, text) if len(args) == 1: # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TABLE return exp.UnixToTime.from_arg_list(args) + # https://www.postgresql.org/docs/current/functions-formatting.html return format_time_lambda(exp.StrToTime, "postgres")(args) @@ -255,7 +258,7 @@ class Postgres(Dialect): STRICT_CAST = False FUNCTIONS = { - **parser.Parser.FUNCTIONS, # type: ignore + **parser.Parser.FUNCTIONS, "DATE_TRUNC": lambda args: exp.TimestampTrunc( this=seq_get(args, 1), unit=seq_get(args, 0) ), @@ -271,7 +274,7 @@ class Postgres(Dialect): } BITWISE = { - **parser.Parser.BITWISE, # type: ignore + **parser.Parser.BITWISE, TokenType.HASH: exp.BitwiseXor, } @@ -280,7 +283,7 @@ class Postgres(Dialect): } RANGE_PARSERS = { - **parser.Parser.RANGE_PARSERS, # type: ignore + **parser.Parser.RANGE_PARSERS, TokenType.DAMP: binary_range_parser(exp.ArrayOverlaps), TokenType.AT_GT: binary_range_parser(exp.ArrayContains), TokenType.LT_AT: binary_range_parser(exp.ArrayContained), @@ -303,14 +306,14 @@ class Postgres(Dialect): return self.expression(exp.Extract, this=part, expression=value) class Generator(generator.Generator): - INTERVAL_ALLOWS_PLURAL_FORM = False + SINGLE_STRING_INTERVAL = True LOCKING_READS_SUPPORTED = True JOIN_HINTS = False TABLE_HINTS = False PARAMETER_TOKEN = "$" TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, # type: ignore + **generator.Generator.TYPE_MAPPING, exp.DataType.Type.TINYINT: "SMALLINT", exp.DataType.Type.FLOAT: "REAL", exp.DataType.Type.DOUBLE: "DOUBLE PRECISION", @@ -320,14 +323,9 @@ class Postgres(Dialect): } TRANSFORMS = { - **generator.Generator.TRANSFORMS, # type: ignore + **generator.Generator.TRANSFORMS, exp.BitwiseXor: lambda self, e: self.binary(e, "#"), - exp.ColumnDef: transforms.preprocess( - [ - _auto_increment_to_serial, - _serial_to_generated, - ], - ), + exp.ColumnDef: transforms.preprocess([_auto_increment_to_serial, _serial_to_generated]), exp.JSONExtract: arrow_json_extract_sql, exp.JSONExtractScalar: arrow_json_extract_scalar_sql, exp.JSONBExtract: lambda self, e: self.binary(e, "#>"), @@ -348,6 +346,7 @@ class Postgres(Dialect): exp.ArrayContains: lambda self, e: self.binary(e, "@>"), exp.ArrayContained: lambda self, e: self.binary(e, "<@"), exp.Merge: transforms.preprocess([transforms.remove_target_from_merge]), + exp.Pivot: no_pivot_sql, exp.RegexpLike: lambda self, e: self.binary(e, "~"), exp.RegexpILike: lambda self, e: self.binary(e, "~*"), exp.StrPosition: str_position_sql, @@ -369,7 +368,7 @@ class Postgres(Dialect): } PROPERTIES_LOCATION = { - **generator.Generator.PROPERTIES_LOCATION, # type: ignore + **generator.Generator.PROPERTIES_LOCATION, exp.TransientProperty: exp.Properties.Location.UNSUPPORTED, exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } |