summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/postgres.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dialects/postgres.py')
-rw-r--r--sqlglot/dialects/postgres.py116
1 file changed, 110 insertions, 6 deletions
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index 61dff86..c796839 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -11,6 +11,7 @@ from sqlglot.dialects.dialect import (
from sqlglot.generator import Generator
from sqlglot.parser import Parser
from sqlglot.tokens import Tokenizer, TokenType
+from sqlglot.transforms import delegate, preprocess
def _date_add_sql(kind):
@@ -32,11 +33,96 @@ def _date_add_sql(kind):
return func
def _lateral_sql(self, expression):
    """Render a Postgres LATERAL clause, with optional alias and column list."""
    body = self.sql(expression, "this")

    # A LATERAL subquery carries its own alias inside the subquery SQL,
    # so it is emitted bare.
    if isinstance(expression.this, exp.Subquery):
        return f"LATERAL{self.sep()}{body}"

    alias = expression.args["alias"]
    alias_name = alias.name
    alias_part = f" {alias_name}" if alias_name else ""
    column_list = self.expressions(alias, key="columns", flat=True)
    column_part = f" AS {column_list}" if column_list else ""
    return f"LATERAL{self.sep()}{body}{alias_part}{column_part}"
+
+
+def _substring_sql(self, expression):
+ this = self.sql(expression, "this")
+ start = self.sql(expression, "start")
+ length = self.sql(expression, "length")
+
+ from_part = f" FROM {start}" if start else ""
+ for_part = f" FOR {length}" if length else ""
+
+ return f"SUBSTRING({this}{from_part}{for_part})"
+
+
+def _trim_sql(self, expression):
+ target = self.sql(expression, "this")
+ trim_type = self.sql(expression, "position")
+ remove_chars = self.sql(expression, "expression")
+ collation = self.sql(expression, "collation")
+
+ # Use TRIM/LTRIM/RTRIM syntax if the expression isn't postgres-specific
+ if not remove_chars and not collation:
+ return self.trim_sql(expression)
+
+ trim_type = f"{trim_type} " if trim_type else ""
+ remove_chars = f"{remove_chars} " if remove_chars else ""
+ from_part = "FROM " if trim_type or remove_chars else ""
+ collation = f" COLLATE {collation}" if collation else ""
+ return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
+
+
def _auto_increment_to_serial(expression):
    """Rewrite an AUTO_INCREMENT column into the matching Postgres SERIAL type.

    INT -> SERIAL, SMALLINT -> SMALLSERIAL, BIGINT -> BIGSERIAL; the
    AUTO_INCREMENT constraint itself is dropped. Other types are left alone.
    """
    constraint = expression.find(exp.AutoIncrementColumnConstraint)
    if not constraint:
        return expression

    # Copy before mutating so the caller's tree is untouched.
    expression = expression.copy()
    expression.args["constraints"].remove(constraint.parent)

    serial_for = {
        exp.DataType.Type.INT: exp.DataType.Type.SERIAL,
        exp.DataType.Type.SMALLINT: exp.DataType.Type.SMALLSERIAL,
        exp.DataType.Type.BIGINT: exp.DataType.Type.BIGSERIAL,
    }
    kind = expression.args["kind"]
    serial_type = serial_for.get(kind.this)
    if serial_type:
        kind.replace(exp.DataType(this=serial_type))

    return expression
+
+
def _serial_to_generated(expression):
    """Rewrite a SERIAL-family column into GENERATED ... AS IDENTITY + NOT NULL.

    SERIAL -> INT, SMALLSERIAL -> SMALLINT, BIGSERIAL -> BIGINT, each gaining
    a GENERATED BY DEFAULT AS IDENTITY constraint and a NOT NULL constraint.
    Non-serial columns pass through unchanged.
    """
    base_type_for = {
        exp.DataType.Type.SERIAL: exp.DataType.Type.INT,
        exp.DataType.Type.SMALLSERIAL: exp.DataType.Type.SMALLINT,
        exp.DataType.Type.BIGSERIAL: exp.DataType.Type.BIGINT,
    }
    kind = expression.args["kind"]
    base_type = base_type_for.get(kind.this)
    if base_type is None:
        return expression

    # Copy before mutating so the caller's tree is untouched.
    expression = expression.copy()
    expression.args["kind"].replace(exp.DataType(this=base_type))

    constraints = expression.args["constraints"]
    generated = exp.ColumnConstraint(
        kind=exp.GeneratedAsIdentityColumnConstraint(this=False)
    )
    not_null = exp.ColumnConstraint(kind=exp.NotNullColumnConstraint())

    # Prepend, but avoid duplicating constraints already on the column.
    if not_null not in constraints:
        constraints.insert(0, not_null)
    if generated not in constraints:
        constraints.insert(0, generated)

    return expression
+
+
class Postgres(Dialect):
null_ordering = "nulls_are_large"
time_format = "'YYYY-MM-DD HH24:MI:SS'"
time_mapping = {
- "AM": "%p", # AM or PM
+ "AM": "%p",
+ "PM": "%p",
"D": "%w", # 1-based day of week
"DD": "%d", # day of month
"DDD": "%j", # zero padded day of year
@@ -65,14 +151,25 @@ class Postgres(Dialect):
}
class Tokenizer(Tokenizer):
+ BIT_STRINGS = [("b'", "'"), ("B'", "'")]
+ HEX_STRINGS = [("x'", "'"), ("X'", "'")]
KEYWORDS = {
**Tokenizer.KEYWORDS,
- "SERIAL": TokenType.AUTO_INCREMENT,
+ "ALWAYS": TokenType.ALWAYS,
+ "BY DEFAULT": TokenType.BY_DEFAULT,
+ "IDENTITY": TokenType.IDENTITY,
+ "FOR": TokenType.FOR,
+ "GENERATED": TokenType.GENERATED,
+ "DOUBLE PRECISION": TokenType.DOUBLE,
+ "BIGSERIAL": TokenType.BIGSERIAL,
+ "SERIAL": TokenType.SERIAL,
+ "SMALLSERIAL": TokenType.SMALLSERIAL,
"UUID": TokenType.UUID,
}
class Parser(Parser):
STRICT_CAST = False
+
FUNCTIONS = {
**Parser.FUNCTIONS,
"TO_TIMESTAMP": format_time_lambda(exp.StrToTime, "postgres"),
@@ -86,14 +183,18 @@ class Postgres(Dialect):
exp.DataType.Type.FLOAT: "REAL",
exp.DataType.Type.DOUBLE: "DOUBLE PRECISION",
exp.DataType.Type.BINARY: "BYTEA",
- }
-
- TOKEN_MAPPING = {
- TokenType.AUTO_INCREMENT: "SERIAL",
+ exp.DataType.Type.DATETIME: "TIMESTAMP",
}
TRANSFORMS = {
**Generator.TRANSFORMS,
+ exp.ColumnDef: preprocess(
+ [
+ _auto_increment_to_serial,
+ _serial_to_generated,
+ ],
+ delegate("columndef_sql"),
+ ),
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONBExtract: lambda self, e: f"{self.sql(e, 'this')}#>{self.sql(e, 'path')}",
@@ -102,8 +203,11 @@ class Postgres(Dialect):
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.DateAdd: _date_add_sql("+"),
exp.DateSub: _date_add_sql("-"),
+ exp.Lateral: _lateral_sql,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
+ exp.Substring: _substring_sql,
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TableSample: no_tablesample_sql,
+ exp.Trim: _trim_sql,
exp.TryCast: no_trycast_sql,
}