diff options
Diffstat (limited to 'sqlglot/dialects/postgres.py')
-rw-r--r-- | sqlglot/dialects/postgres.py | 72 |
1 files changed, 53 insertions, 19 deletions
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index f276af1..a092cad 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -11,9 +11,19 @@ from sqlglot.dialects.dialect import ( no_trycast_sql, str_position_sql, ) +from sqlglot.helper import seq_get from sqlglot.tokens import TokenType from sqlglot.transforms import delegate, preprocess +DATE_DIFF_FACTOR = { + "MICROSECOND": " * 1000000", + "MILLISECOND": " * 1000", + "SECOND": "", + "MINUTE": " / 60", + "HOUR": " / 3600", + "DAY": " / 86400", +} + def _date_add_sql(kind): def func(self, expression): @@ -34,16 +44,30 @@ def _date_add_sql(kind): return func -def _lateral_sql(self, expression): - this = self.sql(expression, "this") - if isinstance(expression.this, exp.Subquery): - return f"LATERAL{self.sep()}{this}" - alias = expression.args["alias"] - table = alias.name - table = f" {table}" if table else table - columns = self.expressions(alias, key="columns", flat=True) - columns = f" AS {columns}" if columns else "" - return f"LATERAL{self.sep()}{this}{table}{columns}" +def _date_diff_sql(self, expression): + unit = expression.text("unit").upper() + factor = DATE_DIFF_FACTOR.get(unit) + + end = f"CAST({expression.this} AS TIMESTAMP)" + start = f"CAST({expression.expression} AS TIMESTAMP)" + + if factor is not None: + return f"CAST(EXTRACT(epoch FROM {end} - {start}){factor} AS BIGINT)" + + age = f"AGE({end}, {start})" + + if unit == "WEEK": + extract = f"EXTRACT(year FROM {age}) * 48 + EXTRACT(month FROM {age}) * 4 + EXTRACT(day FROM {age}) / 7" + elif unit == "MONTH": + extract = f"EXTRACT(year FROM {age}) * 12 + EXTRACT(month FROM {age})" + elif unit == "QUARTER": + extract = f"EXTRACT(year FROM {age}) * 4 + EXTRACT(month FROM {age}) / 3" + elif unit == "YEAR": + extract = f"EXTRACT(year FROM {age})" + else: + self.unsupported(f"Unsupported DATEDIFF unit {unit}") + + return f"CAST({extract} AS BIGINT)" def _substring_sql(self, expression): @@ -141,7 +165,7 @@ def _serial_to_generated(expression): def _to_timestamp(args): # TO_TIMESTAMP accepts either a single double argument or (text, text) - if len(args) == 1 and args[0].is_number: + if len(args) == 1: # https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TABLE return exp.UnixToTime.from_arg_list(args) # https://www.postgresql.org/docs/current/functions-formatting.html @@ -211,11 +235,16 @@ class Postgres(Dialect): KEYWORDS = { **tokens.Tokenizer.KEYWORDS, + "~~": TokenType.LIKE, + "~~*": TokenType.ILIKE, + "~*": TokenType.IRLIKE, + "~": TokenType.RLIKE, "ALWAYS": TokenType.ALWAYS, "BEGIN": TokenType.COMMAND, "BEGIN TRANSACTION": TokenType.BEGIN, "BIGSERIAL": TokenType.BIGSERIAL, "BY DEFAULT": TokenType.BY_DEFAULT, + "CHARACTER VARYING": TokenType.VARCHAR, "COMMENT ON": TokenType.COMMAND, "DECLARE": TokenType.COMMAND, "DO": TokenType.COMMAND, @@ -233,6 +262,7 @@ class Postgres(Dialect): "SMALLSERIAL": TokenType.SMALLSERIAL, "TEMP": TokenType.TEMPORARY, "UUID": TokenType.UUID, + "CSTRING": TokenType.PSEUDO_TYPE, **{f"CREATE {kind}": TokenType.COMMAND for kind in CREATABLES}, **{f"DROP {kind}": TokenType.COMMAND for kind in CREATABLES}, } @@ -244,17 +274,16 @@ class Postgres(Dialect): class Parser(parser.Parser): STRICT_CAST = False - LATERAL_FUNCTION_AS_VIEW = True FUNCTIONS = { - **parser.Parser.FUNCTIONS, + **parser.Parser.FUNCTIONS, # type: ignore "TO_TIMESTAMP": _to_timestamp, "TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"), } class Generator(generator.Generator): TYPE_MAPPING = { - **generator.Generator.TYPE_MAPPING, + **generator.Generator.TYPE_MAPPING, # type: ignore exp.DataType.Type.TINYINT: "SMALLINT", exp.DataType.Type.FLOAT: "REAL", exp.DataType.Type.DOUBLE: "DOUBLE PRECISION", @@ -264,7 +293,7 @@ class Postgres(Dialect): } TRANSFORMS = { - **generator.Generator.TRANSFORMS, + **generator.Generator.TRANSFORMS, # type: ignore exp.ColumnDef: preprocess( [ _auto_increment_to_serial, @@ -274,13 +303,16 @@ class Postgres(Dialect): ), exp.JSONExtract: arrow_json_extract_sql, exp.JSONExtractScalar: arrow_json_extract_scalar_sql, - exp.JSONBExtract: lambda self, e: f"{self.sql(e, 'this')}#>{self.sql(e, 'path')}", - exp.JSONBExtractScalar: lambda self, e: f"{self.sql(e, 'this')}#>>{self.sql(e, 'path')}", + exp.JSONBExtract: lambda self, e: self.binary(e, "#>"), + exp.JSONBExtractScalar: lambda self, e: self.binary(e, "#>>"), + exp.JSONBContains: lambda self, e: self.binary(e, "?"), exp.CurrentDate: no_paren_current_date_sql, exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", exp.DateAdd: _date_add_sql("+"), exp.DateSub: _date_add_sql("-"), - exp.Lateral: _lateral_sql, + exp.DateDiff: _date_diff_sql, + exp.RegexpLike: lambda self, e: self.binary(e, "~"), + exp.RegexpILike: lambda self, e: self.binary(e, "~*"), exp.StrPosition: str_position_sql, exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", exp.Substring: _substring_sql, @@ -291,5 +323,7 @@ class Postgres(Dialect): exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})", exp.DataType: _datatype_sql, exp.GroupConcat: _string_agg_sql, - exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]", + exp.Array: lambda self, e: f"{self.normalize_func('ARRAY')}({self.sql(e.expressions[0])})" + if isinstance(seq_get(e.expressions, 0), exp.Select) + else f"{self.normalize_func('ARRAY')}[{self.expressions(e, flat=True)}]", } |