From 90150543f9314be683d22a16339effd774192f6d Mon Sep 17 00:00:00 2001 From: Daniel Baumann Date: Thu, 22 Sep 2022 06:31:28 +0200 Subject: Merging upstream version 6.1.1. Signed-off-by: Daniel Baumann --- sqlglot/dialects/postgres.py | 116 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 110 insertions(+), 6 deletions(-) (limited to 'sqlglot/dialects/postgres.py') diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 61dff86..c796839 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -11,6 +11,7 @@ from sqlglot.dialects.dialect import ( from sqlglot.generator import Generator from sqlglot.parser import Parser from sqlglot.tokens import Tokenizer, TokenType +from sqlglot.transforms import delegate, preprocess def _date_add_sql(kind): @@ -32,11 +33,96 @@ def _date_add_sql(kind): return func +def _lateral_sql(self, expression): + this = self.sql(expression, "this") + if isinstance(expression.this, exp.Subquery): + return f"LATERAL{self.sep()}{this}" + alias = expression.args["alias"] + table = alias.name + table = f" {table}" if table else table + columns = self.expressions(alias, key="columns", flat=True) + columns = f" AS {columns}" if columns else "" + return f"LATERAL{self.sep()}{this}{table}{columns}" + + +def _substring_sql(self, expression): + this = self.sql(expression, "this") + start = self.sql(expression, "start") + length = self.sql(expression, "length") + + from_part = f" FROM {start}" if start else "" + for_part = f" FOR {length}" if length else "" + + return f"SUBSTRING({this}{from_part}{for_part})" + + +def _trim_sql(self, expression): + target = self.sql(expression, "this") + trim_type = self.sql(expression, "position") + remove_chars = self.sql(expression, "expression") + collation = self.sql(expression, "collation") + + # Use TRIM/LTRIM/RTRIM syntax if the expression isn't postgres-specific + if not remove_chars and not collation: + return self.trim_sql(expression) + + trim_type = f"{trim_type} " if trim_type else "" + remove_chars = f"{remove_chars} " if remove_chars else "" + from_part = "FROM " if trim_type or remove_chars else "" + collation = f" COLLATE {collation}" if collation else "" + return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})" + + +def _auto_increment_to_serial(expression): + auto = expression.find(exp.AutoIncrementColumnConstraint) + + if auto: + expression = expression.copy() + expression.args["constraints"].remove(auto.parent) + kind = expression.args["kind"] + + if kind.this == exp.DataType.Type.INT: + kind.replace(exp.DataType(this=exp.DataType.Type.SERIAL)) + elif kind.this == exp.DataType.Type.SMALLINT: + kind.replace(exp.DataType(this=exp.DataType.Type.SMALLSERIAL)) + elif kind.this == exp.DataType.Type.BIGINT: + kind.replace(exp.DataType(this=exp.DataType.Type.BIGSERIAL)) + + return expression + + +def _serial_to_generated(expression): + kind = expression.args["kind"] + + if kind.this == exp.DataType.Type.SERIAL: + data_type = exp.DataType(this=exp.DataType.Type.INT) + elif kind.this == exp.DataType.Type.SMALLSERIAL: + data_type = exp.DataType(this=exp.DataType.Type.SMALLINT) + elif kind.this == exp.DataType.Type.BIGSERIAL: + data_type = exp.DataType(this=exp.DataType.Type.BIGINT) + else: + data_type = None + + if data_type: + expression = expression.copy() + expression.args["kind"].replace(data_type) + constraints = expression.args["constraints"] + generated = exp.ColumnConstraint(kind=exp.GeneratedAsIdentityColumnConstraint(this=False)) + notnull = exp.ColumnConstraint(kind=exp.NotNullColumnConstraint()) + if notnull not in constraints: + constraints.insert(0, notnull) + if generated not in constraints: + constraints.insert(0, generated) + + return expression + + class Postgres(Dialect): null_ordering = "nulls_are_large" time_format = "'YYYY-MM-DD HH24:MI:SS'" time_mapping = { - "AM": "%p", # AM or PM + "AM": "%p", + "PM": "%p", "D": "%w", # 1-based day of week "DD": "%d", # day of month "DDD": "%j", # zero padded day of year @@ -65,14 +151,25 @@ class Postgres(Dialect): } class Tokenizer(Tokenizer): + BIT_STRINGS = [("b'", "'"), ("B'", "'")] + HEX_STRINGS = [("x'", "'"), ("X'", "'")] KEYWORDS = { **Tokenizer.KEYWORDS, - "SERIAL": TokenType.AUTO_INCREMENT, + "ALWAYS": TokenType.ALWAYS, + "BY DEFAULT": TokenType.BY_DEFAULT, + "IDENTITY": TokenType.IDENTITY, + "FOR": TokenType.FOR, + "GENERATED": TokenType.GENERATED, + "DOUBLE PRECISION": TokenType.DOUBLE, + "BIGSERIAL": TokenType.BIGSERIAL, + "SERIAL": TokenType.SERIAL, + "SMALLSERIAL": TokenType.SMALLSERIAL, "UUID": TokenType.UUID, } class Parser(Parser): STRICT_CAST = False + FUNCTIONS = { **Parser.FUNCTIONS, "TO_TIMESTAMP": format_time_lambda(exp.StrToTime, "postgres"), @@ -86,14 +183,18 @@ class Postgres(Dialect): exp.DataType.Type.FLOAT: "REAL", exp.DataType.Type.DOUBLE: "DOUBLE PRECISION", exp.DataType.Type.BINARY: "BYTEA", - } - - TOKEN_MAPPING = { - TokenType.AUTO_INCREMENT: "SERIAL", + exp.DataType.Type.DATETIME: "TIMESTAMP", } TRANSFORMS = { **Generator.TRANSFORMS, + exp.ColumnDef: preprocess( + [ + _auto_increment_to_serial, + _serial_to_generated, + ], + delegate("columndef_sql"), + ), exp.JSONExtract: arrow_json_extract_sql, exp.JSONExtractScalar: arrow_json_extract_scalar_sql, exp.JSONBExtract: lambda self, e: f"{self.sql(e, 'this')}#>{self.sql(e, 'path')}", @@ -102,8 +203,11 @@ class Postgres(Dialect): exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP", exp.DateAdd: _date_add_sql("+"), exp.DateSub: _date_add_sql("-"), + exp.Lateral: _lateral_sql, exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})", + exp.Substring: _substring_sql, exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})", exp.TableSample: no_tablesample_sql, + exp.Trim: _trim_sql, exp.TryCast: no_trycast_sql, } -- cgit v1.2.3