summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/postgres.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dialects/postgres.py')
-rw-r--r--sqlglot/dialects/postgres.py72
1 files changed, 53 insertions, 19 deletions
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index f276af1..a092cad 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -11,9 +11,19 @@ from sqlglot.dialects.dialect import (
no_trycast_sql,
str_position_sql,
)
+from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
from sqlglot.transforms import delegate, preprocess
+DATE_DIFF_FACTOR = {
+ "MICROSECOND": " * 1000000",
+ "MILLISECOND": " * 1000",
+ "SECOND": "",
+ "MINUTE": " / 60",
+ "HOUR": " / 3600",
+ "DAY": " / 86400",
+}
+
def _date_add_sql(kind):
def func(self, expression):
@@ -34,16 +44,30 @@ def _date_add_sql(kind):
return func
-def _lateral_sql(self, expression):
- this = self.sql(expression, "this")
- if isinstance(expression.this, exp.Subquery):
- return f"LATERAL{self.sep()}{this}"
- alias = expression.args["alias"]
- table = alias.name
- table = f" {table}" if table else table
- columns = self.expressions(alias, key="columns", flat=True)
- columns = f" AS {columns}" if columns else ""
- return f"LATERAL{self.sep()}{this}{table}{columns}"
+def _date_diff_sql(self, expression):
+ unit = expression.text("unit").upper()
+ factor = DATE_DIFF_FACTOR.get(unit)
+
+ end = f"CAST({expression.this} AS TIMESTAMP)"
+ start = f"CAST({expression.expression} AS TIMESTAMP)"
+
+ if factor is not None:
+ return f"CAST(EXTRACT(epoch FROM {end} - {start}){factor} AS BIGINT)"
+
+ age = f"AGE({end}, {start})"
+
+ if unit == "WEEK":
+ extract = f"EXTRACT(year FROM {age}) * 48 + EXTRACT(month FROM {age}) * 4 + EXTRACT(day FROM {age}) / 7"
+ elif unit == "MONTH":
+ extract = f"EXTRACT(year FROM {age}) * 12 + EXTRACT(month FROM {age})"
+ elif unit == "QUARTER":
+ extract = f"EXTRACT(year FROM {age}) * 4 + EXTRACT(month FROM {age}) / 3"
+ elif unit == "YEAR":
+ extract = f"EXTRACT(year FROM {age})"
+ else:
+ self.unsupported(f"Unsupported DATEDIFF unit {unit}")
+
+ return f"CAST({extract} AS BIGINT)"
def _substring_sql(self, expression):
@@ -141,7 +165,7 @@ def _serial_to_generated(expression):
def _to_timestamp(args):
# TO_TIMESTAMP accepts either a single double argument or (text, text)
- if len(args) == 1 and args[0].is_number:
+ if len(args) == 1:
# https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-TABLE
return exp.UnixToTime.from_arg_list(args)
# https://www.postgresql.org/docs/current/functions-formatting.html
@@ -211,11 +235,16 @@ class Postgres(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
+ "~~": TokenType.LIKE,
+ "~~*": TokenType.ILIKE,
+ "~*": TokenType.IRLIKE,
+ "~": TokenType.RLIKE,
"ALWAYS": TokenType.ALWAYS,
"BEGIN": TokenType.COMMAND,
"BEGIN TRANSACTION": TokenType.BEGIN,
"BIGSERIAL": TokenType.BIGSERIAL,
"BY DEFAULT": TokenType.BY_DEFAULT,
+ "CHARACTER VARYING": TokenType.VARCHAR,
"COMMENT ON": TokenType.COMMAND,
"DECLARE": TokenType.COMMAND,
"DO": TokenType.COMMAND,
@@ -233,6 +262,7 @@ class Postgres(Dialect):
"SMALLSERIAL": TokenType.SMALLSERIAL,
"TEMP": TokenType.TEMPORARY,
"UUID": TokenType.UUID,
+ "CSTRING": TokenType.PSEUDO_TYPE,
**{f"CREATE {kind}": TokenType.COMMAND for kind in CREATABLES},
**{f"DROP {kind}": TokenType.COMMAND for kind in CREATABLES},
}
@@ -244,17 +274,16 @@ class Postgres(Dialect):
class Parser(parser.Parser):
STRICT_CAST = False
- LATERAL_FUNCTION_AS_VIEW = True
FUNCTIONS = {
- **parser.Parser.FUNCTIONS,
+ **parser.Parser.FUNCTIONS, # type: ignore
"TO_TIMESTAMP": _to_timestamp,
"TO_CHAR": format_time_lambda(exp.TimeToStr, "postgres"),
}
class Generator(generator.Generator):
TYPE_MAPPING = {
- **generator.Generator.TYPE_MAPPING,
+ **generator.Generator.TYPE_MAPPING, # type: ignore
exp.DataType.Type.TINYINT: "SMALLINT",
exp.DataType.Type.FLOAT: "REAL",
exp.DataType.Type.DOUBLE: "DOUBLE PRECISION",
@@ -264,7 +293,7 @@ class Postgres(Dialect):
}
TRANSFORMS = {
- **generator.Generator.TRANSFORMS,
+ **generator.Generator.TRANSFORMS, # type: ignore
exp.ColumnDef: preprocess(
[
_auto_increment_to_serial,
@@ -274,13 +303,16 @@ class Postgres(Dialect):
),
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
- exp.JSONBExtract: lambda self, e: f"{self.sql(e, 'this')}#>{self.sql(e, 'path')}",
- exp.JSONBExtractScalar: lambda self, e: f"{self.sql(e, 'this')}#>>{self.sql(e, 'path')}",
+ exp.JSONBExtract: lambda self, e: self.binary(e, "#>"),
+ exp.JSONBExtractScalar: lambda self, e: self.binary(e, "#>>"),
+ exp.JSONBContains: lambda self, e: self.binary(e, "?"),
exp.CurrentDate: no_paren_current_date_sql,
exp.CurrentTimestamp: lambda *_: "CURRENT_TIMESTAMP",
exp.DateAdd: _date_add_sql("+"),
exp.DateSub: _date_add_sql("-"),
- exp.Lateral: _lateral_sql,
+ exp.DateDiff: _date_diff_sql,
+ exp.RegexpLike: lambda self, e: self.binary(e, "~"),
+ exp.RegexpILike: lambda self, e: self.binary(e, "~*"),
exp.StrPosition: str_position_sql,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Substring: _substring_sql,
@@ -291,5 +323,7 @@ class Postgres(Dialect):
exp.UnixToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')})",
exp.DataType: _datatype_sql,
exp.GroupConcat: _string_agg_sql,
- exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
+ exp.Array: lambda self, e: f"{self.normalize_func('ARRAY')}({self.sql(e.expressions[0])})"
+ if isinstance(seq_get(e.expressions, 0), exp.Select)
+ else f"{self.normalize_func('ARRAY')}[{self.expressions(e, flat=True)}]",
}