summaryrefslogtreecommitdiffstats
path: root/sqlglot
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2023-09-12 08:28:54 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2023-09-12 08:28:54 +0000
commit7db33518a4264e422294a1e20fbd1c1505d9a62d (patch)
treeaeb9ae54563b1f8f9c26fd54d0c207b082b89cd4 /sqlglot
parentReleasing debian version 18.2.0-1. (diff)
downloadsqlglot-7db33518a4264e422294a1e20fbd1c1505d9a62d.tar.xz
sqlglot-7db33518a4264e422294a1e20fbd1c1505d9a62d.zip
Merging upstream version 18.3.0.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'sqlglot')
-rw-r--r--sqlglot/dialects/doris.py2
-rw-r--r--sqlglot/dialects/mysql.py36
-rw-r--r--sqlglot/dialects/postgres.py6
-rw-r--r--sqlglot/dialects/spark.py7
-rw-r--r--sqlglot/dialects/teradata.py8
-rw-r--r--sqlglot/dialects/tsql.py14
-rw-r--r--sqlglot/expressions.py15
-rw-r--r--sqlglot/generator.py8
-rw-r--r--sqlglot/parser.py43
-rw-r--r--sqlglot/tokens.py1
-rw-r--r--sqlglot/transforms.py2
11 files changed, 117 insertions, 25 deletions
diff --git a/sqlglot/dialects/doris.py b/sqlglot/dialects/doris.py
index 4b8919c..bd7e0f2 100644
--- a/sqlglot/dialects/doris.py
+++ b/sqlglot/dialects/doris.py
@@ -33,6 +33,8 @@ class Doris(MySQL):
exp.DataType.Type.TIMESTAMPTZ: "DATETIME",
}
+ TIMESTAMP_FUNC_TYPES = set()
+
TRANSFORMS = {
**MySQL.Generator.TRANSFORMS,
exp.ApproxDistinct: approx_count_distinct_sql,
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index f9249eb..6327796 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -555,7 +555,26 @@ class MySQL(Dialect):
exp.WeekOfYear: rename_func("WEEKOFYEAR"),
}
- TYPE_MAPPING = generator.Generator.TYPE_MAPPING.copy()
+ UNSIGNED_TYPE_MAPPING = {
+ exp.DataType.Type.UBIGINT: "BIGINT",
+ exp.DataType.Type.UINT: "INT",
+ exp.DataType.Type.UMEDIUMINT: "MEDIUMINT",
+ exp.DataType.Type.USMALLINT: "SMALLINT",
+ exp.DataType.Type.UTINYINT: "TINYINT",
+ }
+
+ TIMESTAMP_TYPE_MAPPING = {
+ exp.DataType.Type.TIMESTAMP: "DATETIME",
+ exp.DataType.Type.TIMESTAMPTZ: "TIMESTAMP",
+ exp.DataType.Type.TIMESTAMPLTZ: "TIMESTAMP",
+ }
+
+ TYPE_MAPPING = {
+ **generator.Generator.TYPE_MAPPING,
+ **UNSIGNED_TYPE_MAPPING,
+ **TIMESTAMP_TYPE_MAPPING,
+ }
+
TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMTEXT)
TYPE_MAPPING.pop(exp.DataType.Type.LONGTEXT)
TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMBLOB)
@@ -580,6 +599,18 @@ class MySQL(Dialect):
exp.DataType.Type.VARCHAR: "CHAR",
}
+ TIMESTAMP_FUNC_TYPES = {
+ exp.DataType.Type.TIMESTAMPTZ,
+ exp.DataType.Type.TIMESTAMPLTZ,
+ }
+
+ def datatype_sql(self, expression: exp.DataType) -> str:
+ # https://dev.mysql.com/doc/refman/8.0/en/numeric-type-syntax.html
+ result = super().datatype_sql(expression)
+ if expression.this in self.UNSIGNED_TYPE_MAPPING:
+ result = f"{result} UNSIGNED"
+ return result
+
def limit_sql(self, expression: exp.Limit, top: bool = False) -> str:
# MySQL requires simple literal values for its LIMIT clause.
expression = simplify_literal(expression.copy())
@@ -599,6 +630,9 @@ class MySQL(Dialect):
return f"{self.sql(expression, 'this')} MEMBER OF({self.sql(expression, 'expression')})"
def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
+ if expression.to.this in self.TIMESTAMP_FUNC_TYPES:
+ return self.func("TIMESTAMP", expression.this)
+
to = self.CAST_MAPPING.get(expression.to.this)
if to:
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index c26e121..5027013 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -190,7 +190,11 @@ def _remove_target_from_merge(expression: exp.Expression) -> exp.Expression:
if isinstance(expression, exp.Merge):
alias = expression.this.args.get("alias")
- normalize = lambda identifier: Postgres.normalize_identifier(identifier).name
+ normalize = (
+ lambda identifier: Postgres.normalize_identifier(identifier).name
+ if identifier
+ else None
+ )
targets = {normalize(expression.this.this)}
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index a4435f6..9d4a1ab 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -35,6 +35,13 @@ def _parse_datediff(args: t.List) -> exp.Expression:
class Spark(Spark2):
+ class Tokenizer(Spark2.Tokenizer):
+ RAW_STRINGS = [
+ (prefix + q, q)
+ for q in t.cast(t.List[str], Spark2.Tokenizer.QUOTES)
+ for prefix in ("r", "R")
+ ]
+
class Parser(Spark2.Parser):
FUNCTIONS = {
**Spark2.Parser.FUNCTIONS,
diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py
index 163cc13..d9de968 100644
--- a/sqlglot/dialects/teradata.py
+++ b/sqlglot/dialects/teradata.py
@@ -45,6 +45,7 @@ class Teradata(Dialect):
"MOD": TokenType.MOD,
"NE": TokenType.NEQ,
"NOT=": TokenType.NEQ,
+ "SAMPLE": TokenType.TABLE_SAMPLE,
"SEL": TokenType.SELECT,
"ST_GEOMETRY": TokenType.GEOMETRY,
"TOP": TokenType.TOP,
@@ -55,6 +56,8 @@ class Teradata(Dialect):
SINGLE_TOKENS.pop("%")
class Parser(parser.Parser):
+ TABLESAMPLE_CSV = True
+
CHARSET_TRANSLATORS = {
"GRAPHIC_TO_KANJISJIS",
"GRAPHIC_TO_LATIN",
@@ -171,6 +174,11 @@ class Teradata(Dialect):
exp.Use: lambda self, e: f"DATABASE {self.sql(e, 'this')}",
}
+ def tablesample_sql(
+ self, expression: exp.TableSample, seed_prefix: str = "SEED", sep=" AS "
+ ) -> str:
+ return f"{self.sql(expression, 'this')} SAMPLE {self.expressions(expression)}"
+
def partitionedbyproperty_sql(self, expression: exp.PartitionedByProperty) -> str:
return f"PARTITION BY {self.sql(expression, 'this')}"
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index b26f499..19c586e 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -57,6 +57,8 @@ TRANSPILE_SAFE_NUMBER_FMT = {"N", "C"}
DEFAULT_START_DATE = datetime.date(1900, 1, 1)
+BIT_TYPES = {exp.EQ, exp.NEQ, exp.Is, exp.In, exp.Select, exp.Alias}
+
def _format_time_lambda(
exp_class: t.Type[E], full_format_mapping: t.Optional[bool] = None
@@ -584,6 +586,7 @@ class TSQL(Dialect):
RETURNING_END = False
NVL2_SUPPORTED = False
ALTER_TABLE_ADD_COLUMN_KEYWORD = False
+ LIMIT_FETCH = "FETCH"
TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
@@ -630,7 +633,16 @@ class TSQL(Dialect):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}
- LIMIT_FETCH = "FETCH"
+ def boolean_sql(self, expression: exp.Boolean) -> str:
+ if type(expression.parent) in BIT_TYPES:
+ return "1" if expression.this else "0"
+
+ return "(1 = 1)" if expression.this else "(1 = 0)"
+
+ def is_sql(self, expression: exp.Is) -> str:
+ if isinstance(expression.expression, exp.Boolean):
+ return self.binary(expression, "=")
+ return self.binary(expression, "IS")
def createable_sql(self, expression: exp.Create, locations: t.DefaultDict) -> str:
sql = self.sql(expression, "this")
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index 0479da0..877e9fd 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -3350,6 +3350,7 @@ class Subquery(DerivedTable, Unionable):
class TableSample(Expression):
arg_types = {
"this": False,
+ "expressions": False,
"method": False,
"bucket_numerator": False,
"bucket_denominator": False,
@@ -3542,6 +3543,7 @@ class DataType(Expression):
UINT = auto()
UINT128 = auto()
UINT256 = auto()
+ UMEDIUMINT = auto()
UNIQUEIDENTIFIER = auto()
UNKNOWN = auto() # Sentinel value, useful for type annotation
USERDEFINED = "USER-DEFINED"
@@ -3708,7 +3710,7 @@ class Rollback(Expression):
class AlterTable(Expression):
- arg_types = {"this": True, "actions": True, "exists": False}
+ arg_types = {"this": True, "actions": True, "exists": False, "only": False}
class AddConstraint(Expression):
@@ -3993,15 +3995,10 @@ class TimeUnit(Expression):
# https://www.oracletutorial.com/oracle-basics/oracle-interval/
-# https://trino.io/docs/current/language/types.html#interval-year-to-month
-class IntervalYearToMonthSpan(Expression):
- arg_types = {}
-
-
-# https://www.oracletutorial.com/oracle-basics/oracle-interval/
# https://trino.io/docs/current/language/types.html#interval-day-to-second
-class IntervalDayToSecondSpan(Expression):
- arg_types = {}
+# https://docs.databricks.com/en/sql/language-manual/data-types/interval-type.html
+class IntervalSpan(Expression):
+ arg_types = {"this": True, "expression": True}
class Interval(TimeUnit):
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index 306df81..1074e9a 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -72,8 +72,7 @@ class Generator:
exp.ExternalProperty: lambda self, e: "EXTERNAL",
exp.HeapProperty: lambda self, e: "HEAP",
exp.InlineLengthColumnConstraint: lambda self, e: f"INLINE LENGTH {self.sql(e, 'this')}",
- exp.IntervalDayToSecondSpan: "DAY TO SECOND",
- exp.IntervalYearToMonthSpan: "YEAR TO MONTH",
+ exp.IntervalSpan: lambda self, e: f"{self.sql(e, 'this')} TO {self.sql(e, 'expression')}",
exp.LanguageProperty: lambda self, e: self.naked_property(e),
exp.LocationProperty: lambda self, e: self.naked_property(e),
exp.LogProperty: lambda self, e: f"{'NO ' if e.args.get('no') else ''}LOG",
@@ -953,7 +952,7 @@ class Generator:
def filter_sql(self, expression: exp.Filter) -> str:
this = self.sql(expression, "this")
- where = self.sql(expression, "expression")[1:] # where has a leading space
+ where = self.sql(expression, "expression").strip()
return f"{this} FILTER({where})"
def hint_sql(self, expression: exp.Hint) -> str:
@@ -2290,7 +2289,8 @@ class Generator:
actions = self.expressions(expression, key="actions")
exists = " IF EXISTS" if expression.args.get("exists") else ""
- return f"ALTER TABLE{exists} {self.sql(expression, 'this')} {actions}"
+ only = " ONLY" if expression.args.get("only") else ""
+ return f"ALTER TABLE{exists}{only} {self.sql(expression, 'this')} {actions}"
def droppartition_sql(self, expression: exp.DropPartition) -> str:
expressions = self.expressions(expression)
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index f8690d5..939303f 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -137,6 +137,7 @@ class Parser(metaclass=_Parser):
TokenType.INT256,
TokenType.UINT256,
TokenType.MEDIUMINT,
+ TokenType.UMEDIUMINT,
TokenType.FIXEDSTRING,
TokenType.FLOAT,
TokenType.DOUBLE,
@@ -206,6 +207,14 @@ class Parser(metaclass=_Parser):
*NESTED_TYPE_TOKENS,
}
+ SIGNED_TO_UNSIGNED_TYPE_TOKEN = {
+ TokenType.BIGINT: TokenType.UBIGINT,
+ TokenType.INT: TokenType.UINT,
+ TokenType.MEDIUMINT: TokenType.UMEDIUMINT,
+ TokenType.SMALLINT: TokenType.USMALLINT,
+ TokenType.TINYINT: TokenType.UTINYINT,
+ }
+
SUBQUERY_PREDICATES = {
TokenType.ANY: exp.Any,
TokenType.ALL: exp.All,
@@ -856,6 +865,9 @@ class Parser(metaclass=_Parser):
# Whether or not ADD is present for each column added by ALTER TABLE
ALTER_TABLE_ADD_COLUMN_KEYWORD = True
+ # Whether or not the table sample clause expects CSV syntax
+ TABLESAMPLE_CSV = False
+
__slots__ = (
"error_level",
"error_message_context",
@@ -2672,7 +2684,12 @@ class Parser(metaclass=_Parser):
self._match(TokenType.L_PAREN)
- num = self._parse_number()
+ if self.TABLESAMPLE_CSV:
+ num = None
+ expressions = self._parse_csv(self._parse_primary)
+ else:
+ expressions = None
+ num = self._parse_number()
if self._match_text_seq("BUCKET"):
bucket_numerator = self._parse_number()
@@ -2684,7 +2701,7 @@ class Parser(metaclass=_Parser):
percent = num
elif self._match(TokenType.ROWS):
rows = num
- else:
+ elif num:
size = num
self._match(TokenType.R_PAREN)
@@ -2698,6 +2715,7 @@ class Parser(metaclass=_Parser):
return self.expression(
exp.TableSample,
+ expressions=expressions,
method=method,
bucket_numerator=bucket_numerator,
bucket_denominator=bucket_denominator,
@@ -3325,15 +3343,14 @@ class Parser(metaclass=_Parser):
elif self._match_text_seq("WITHOUT", "TIME", "ZONE"):
maybe_func = False
elif type_token == TokenType.INTERVAL:
- if self._match_text_seq("YEAR", "TO", "MONTH"):
- span: t.Optional[t.List[exp.Expression]] = [exp.IntervalYearToMonthSpan()]
- elif self._match_text_seq("DAY", "TO", "SECOND"):
- span = [exp.IntervalDayToSecondSpan()]
+ unit = self._parse_var()
+
+ if self._match_text_seq("TO"):
+ span = [exp.IntervalSpan(this=unit, expression=self._parse_var())]
else:
span = None
- unit = not span and self._parse_var()
- if not unit:
+ if span or not unit:
this = self.expression(
exp.DataType, this=exp.DataType.Type.INTERVAL, expressions=span
)
@@ -3351,6 +3368,13 @@ class Parser(metaclass=_Parser):
self._retreat(index2)
if not this:
+ if self._match_text_seq("UNSIGNED"):
+ unsigned_type_token = self.SIGNED_TO_UNSIGNED_TYPE_TOKEN.get(type_token)
+ if not unsigned_type_token:
+ self.raise_error(f"Cannot convert {type_token.value} to unsigned.")
+
+ type_token = unsigned_type_token or type_token
+
this = exp.DataType(
this=exp.DataType.Type[type_token.value],
expressions=expressions,
@@ -4761,6 +4785,7 @@ class Parser(metaclass=_Parser):
return self._parse_as_command(start)
exists = self._parse_exists()
+ only = self._match_text_seq("ONLY")
this = self._parse_table(schema=True)
if self._next:
@@ -4776,7 +4801,9 @@ class Parser(metaclass=_Parser):
this=this,
exists=exists,
actions=actions,
+ only=only,
)
+
return self._parse_as_command(start)
def _parse_merge(self) -> exp.Merge:
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index 83b97d6..3ba8195 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -86,6 +86,7 @@ class TokenType(AutoName):
SMALLINT = auto()
USMALLINT = auto()
MEDIUMINT = auto()
+ UMEDIUMINT = auto()
INT = auto()
UINT = auto()
BIGINT = auto()
diff --git a/sqlglot/transforms.py b/sqlglot/transforms.py
index 48ea8dc..66ab884 100644
--- a/sqlglot/transforms.py
+++ b/sqlglot/transforms.py
@@ -76,7 +76,7 @@ def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
return (
exp.select(*outer_selects)
- .from_(expression.subquery())
+ .from_(expression.subquery("_t"))
.where(exp.column(row_number).eq(1))
)