summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2023-01-23 08:42:55 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2023-01-23 08:42:55 +0000
commitade4a78e8fabcaa7270b6d4be2187457a3fa115f (patch)
tree018225e76010479b3a568bb6d9ef5df457802885
parentAdding upstream version 10.5.2. (diff)
downloadsqlglot-upstream/10.5.6.tar.xz
sqlglot-upstream/10.5.6.zip
Adding upstream version 10.5.6.upstream/10.5.6
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
-rw-r--r--README.md2
-rw-r--r--sqlglot/__init__.py2
-rw-r--r--sqlglot/dialects/__init__.py1
-rw-r--r--sqlglot/dialects/bigquery.py5
-rw-r--r--sqlglot/dialects/clickhouse.py31
-rw-r--r--sqlglot/dialects/dialect.py3
-rw-r--r--sqlglot/dialects/hive.py4
-rw-r--r--sqlglot/dialects/mysql.py10
-rw-r--r--sqlglot/dialects/postgres.py5
-rw-r--r--sqlglot/dialects/presto.py62
-rw-r--r--sqlglot/dialects/redshift.py73
-rw-r--r--sqlglot/dialects/snowflake.py3
-rw-r--r--sqlglot/dialects/spark.py1
-rw-r--r--sqlglot/dialects/sqlite.py16
-rw-r--r--sqlglot/dialects/teradata.py87
-rw-r--r--sqlglot/dialects/tsql.py28
-rw-r--r--sqlglot/expressions.py68
-rw-r--r--sqlglot/generator.py82
-rw-r--r--sqlglot/optimizer/optimizer.py5
-rw-r--r--sqlglot/optimizer/qualify_columns.py4
-rw-r--r--sqlglot/parser.py80
-rw-r--r--sqlglot/tokens.py40
-rw-r--r--tests/dialects/test_bigquery.py11
-rw-r--r--tests/dialects/test_clickhouse.py7
-rw-r--r--tests/dialects/test_dialect.py31
-rw-r--r--tests/dialects/test_hive.py26
-rw-r--r--tests/dialects/test_mysql.py11
-rw-r--r--tests/dialects/test_postgres.py8
-rw-r--r--tests/dialects/test_presto.py13
-rw-r--r--tests/dialects/test_redshift.py82
-rw-r--r--tests/dialects/test_spark.py9
-rw-r--r--tests/dialects/test_teradata.py23
-rw-r--r--tests/dialects/test_tsql.py7
-rw-r--r--tests/fixtures/identity.sql29
-rw-r--r--tests/fixtures/pretty.sql20
-rw-r--r--tests/test_expressions.py10
-rw-r--r--tests/test_optimizer.py4
-rw-r--r--tests/test_parser.py6
-rw-r--r--tests/test_transpile.py8
39 files changed, 785 insertions, 132 deletions
diff --git a/README.md b/README.md
index 85a76e5..0416521 100644
--- a/README.md
+++ b/README.md
@@ -462,7 +462,7 @@ make check # Set SKIP_INTEGRATION=1 to skip integration tests
| Query | sqlglot | sqlfluff | sqltree | sqlparse | moz_sql_parser | sqloxide |
| --------------- | --------------- | --------------- | --------------- | --------------- | --------------- | --------------- |
| tpch | 0.01308 (1.0) | 1.60626 (122.7) | 0.01168 (0.893) | 0.04958 (3.791) | 0.08543 (6.531) | 0.00136 (0.104) |
-| short | 0.00109 (1.0) | 0.14134 (129.2) | 0.00099 (0.906) | 0.00342 (3.131) | 0.00652 (5.970) | 8.76621 (0.080) |
+| short | 0.00109 (1.0) | 0.14134 (129.2) | 0.00099 (0.906) | 0.00342 (3.131) | 0.00652 (5.970) | 8.76E-5 (0.080) |
| long | 0.01399 (1.0) | 2.12632 (151.9) | 0.01126 (0.805) | 0.04410 (3.151) | 0.06671 (4.767) | 0.00107 (0.076) |
| crazy | 0.03969 (1.0) | 24.3777 (614.1) | 0.03917 (0.987) | 11.7043 (294.8) | 1.03280 (26.02) | 0.00625 (0.157) |
diff --git a/sqlglot/__init__.py b/sqlglot/__init__.py
index 87fa081..f2db4f1 100644
--- a/sqlglot/__init__.py
+++ b/sqlglot/__init__.py
@@ -32,7 +32,7 @@ from sqlglot.parser import Parser
from sqlglot.schema import MappingSchema
from sqlglot.tokens import Tokenizer, TokenType
-__version__ = "10.5.2"
+__version__ = "10.5.6"
pretty = False
diff --git a/sqlglot/dialects/__init__.py b/sqlglot/dialects/__init__.py
index 2e42e7d..2084681 100644
--- a/sqlglot/dialects/__init__.py
+++ b/sqlglot/dialects/__init__.py
@@ -15,5 +15,6 @@ from sqlglot.dialects.spark import Spark
from sqlglot.dialects.sqlite import SQLite
from sqlglot.dialects.starrocks import StarRocks
from sqlglot.dialects.tableau import Tableau
+from sqlglot.dialects.teradata import Teradata
from sqlglot.dialects.trino import Trino
from sqlglot.dialects.tsql import TSQL
diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py
index f0089e1..9ddfbea 100644
--- a/sqlglot/dialects/bigquery.py
+++ b/sqlglot/dialects/bigquery.py
@@ -165,6 +165,11 @@ class BigQuery(Dialect):
TokenType.TABLE,
}
+ ID_VAR_TOKENS = {
+ *parser.Parser.ID_VAR_TOKENS, # type: ignore
+ TokenType.VALUES,
+ }
+
class Generator(generator.Generator):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py
index 04d46d2..1c173a4 100644
--- a/sqlglot/dialects/clickhouse.py
+++ b/sqlglot/dialects/clickhouse.py
@@ -4,6 +4,7 @@ import typing as t
from sqlglot import exp, generator, parser, tokens
from sqlglot.dialects.dialect import Dialect, inline_array_sql, var_map_sql
+from sqlglot.errors import ParseError
from sqlglot.parser import parse_var_map
from sqlglot.tokens import TokenType
@@ -72,6 +73,30 @@ class ClickHouse(Dialect):
return this
+ def _parse_position(self) -> exp.Expression:
+ this = super()._parse_position()
+ # clickhouse position args are swapped
+ substr = this.this
+ this.args["this"] = this.args.get("substr")
+ this.args["substr"] = substr
+ return this
+
+ # https://clickhouse.com/docs/en/sql-reference/statements/select/with/
+ def _parse_cte(self) -> exp.Expression:
+ index = self._index
+ try:
+ # WITH <identifier> AS <subquery expression>
+ return super()._parse_cte()
+ except ParseError:
+ # WITH <expression> AS <identifier>
+ self._retreat(index)
+ statement = self._parse_statement()
+
+ if statement and isinstance(statement.this, exp.Alias):
+ self.raise_error("Expected CTE to have alias")
+
+ return self.expression(exp.CTE, this=statement, alias=statement and statement.this)
+
class Generator(generator.Generator):
STRUCT_DELIMITER = ("(", ")")
@@ -110,3 +135,9 @@ class ClickHouse(Dialect):
params = self.format_args(self.expressions(expression, params_name))
args = self.format_args(self.expressions(expression, args_name))
return f"({params})({args})"
+
+ def cte_sql(self, expression: exp.CTE) -> str:
+ if isinstance(expression.this, exp.Alias):
+ return self.sql(expression, "this")
+
+ return super().cte_sql(expression)
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py
index 1c840da..0c2beba 100644
--- a/sqlglot/dialects/dialect.py
+++ b/sqlglot/dialects/dialect.py
@@ -33,6 +33,7 @@ class Dialects(str, Enum):
TSQL = "tsql"
DATABRICKS = "databricks"
DRILL = "drill"
+ TERADATA = "teradata"
class _Dialect(type):
@@ -368,7 +369,7 @@ def locate_to_strposition(args):
)
-def strposition_to_local_sql(self, expression):
+def strposition_to_locate_sql(self, expression):
args = self.format_args(
expression.args.get("substr"), expression.this, expression.args.get("position")
)
diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py
index ead13b1..ddfd1e8 100644
--- a/sqlglot/dialects/hive.py
+++ b/sqlglot/dialects/hive.py
@@ -13,7 +13,7 @@ from sqlglot.dialects.dialect import (
no_safe_divide_sql,
no_trycast_sql,
rename_func,
- strposition_to_local_sql,
+ strposition_to_locate_sql,
struct_extract_sql,
timestrtotime_sql,
var_map_sql,
@@ -297,7 +297,7 @@ class Hive(Dialect):
exp.SchemaCommentProperty: lambda self, e: self.naked_property(e),
exp.SetAgg: rename_func("COLLECT_SET"),
exp.Split: lambda self, e: f"SPLIT({self.sql(e, 'this')}, CONCAT('\\\\Q', {self.sql(e, 'expression')}))",
- exp.StrPosition: strposition_to_local_sql,
+ exp.StrPosition: strposition_to_locate_sql,
exp.StrToDate: _str_to_date,
exp.StrToTime: _str_to_time,
exp.StrToUnix: _str_to_unix,
diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py
index 0fd7992..1bddfe1 100644
--- a/sqlglot/dialects/mysql.py
+++ b/sqlglot/dialects/mysql.py
@@ -10,7 +10,7 @@ from sqlglot.dialects.dialect import (
no_paren_current_date_sql,
no_tablesample_sql,
no_trycast_sql,
- strposition_to_local_sql,
+ strposition_to_locate_sql,
)
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType
@@ -122,6 +122,8 @@ class MySQL(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
+ "MEDIUMTEXT": TokenType.MEDIUMTEXT,
+ "LONGTEXT": TokenType.LONGTEXT,
"START": TokenType.BEGIN,
"SEPARATOR": TokenType.SEPARATOR,
"_ARMSCII8": TokenType.INTRODUCER,
@@ -442,7 +444,7 @@ class MySQL(Dialect):
exp.Trim: _trim_sql,
exp.NullSafeEQ: lambda self, e: self.binary(e, "<=>"),
exp.NullSafeNEQ: lambda self, e: self.not_sql(self.binary(e, "<=>")),
- exp.StrPosition: strposition_to_local_sql,
+ exp.StrPosition: strposition_to_locate_sql,
}
ROOT_PROPERTIES = {
@@ -454,6 +456,10 @@ class MySQL(Dialect):
exp.LikeProperty,
}
+ TYPE_MAPPING = generator.Generator.TYPE_MAPPING.copy()
+ TYPE_MAPPING.pop(exp.DataType.Type.MEDIUMTEXT)
+ TYPE_MAPPING.pop(exp.DataType.Type.LONGTEXT)
+
WITH_PROPERTIES: t.Set[t.Type[exp.Property]] = set()
def show_sql(self, expression):
diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py
index f3fec31..6f597f1 100644
--- a/sqlglot/dialects/postgres.py
+++ b/sqlglot/dialects/postgres.py
@@ -223,19 +223,15 @@ class Postgres(Dialect):
"~~*": TokenType.ILIKE,
"~*": TokenType.IRLIKE,
"~": TokenType.RLIKE,
- "ALWAYS": TokenType.ALWAYS,
"BEGIN": TokenType.COMMAND,
"BEGIN TRANSACTION": TokenType.BEGIN,
"BIGSERIAL": TokenType.BIGSERIAL,
- "BY DEFAULT": TokenType.BY_DEFAULT,
"CHARACTER VARYING": TokenType.VARCHAR,
"COMMENT ON": TokenType.COMMAND,
"DECLARE": TokenType.COMMAND,
"DO": TokenType.COMMAND,
- "GENERATED": TokenType.GENERATED,
"GRANT": TokenType.COMMAND,
"HSTORE": TokenType.HSTORE,
- "IDENTITY": TokenType.IDENTITY,
"JSONB": TokenType.JSONB,
"REFRESH": TokenType.COMMAND,
"REINDEX": TokenType.COMMAND,
@@ -299,6 +295,7 @@ class Postgres(Dialect):
exp.StrPosition: str_position_sql,
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Substring: _substring_sql,
+ exp.TimeStrToTime: lambda self, e: f"CAST({self.sql(e, 'this')} AS TIMESTAMP)",
exp.TimeToStr: lambda self, e: f"TO_CHAR({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TableSample: no_tablesample_sql,
exp.Trim: trim_sql,
diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py
index e16ea1d..a79a9f9 100644
--- a/sqlglot/dialects/presto.py
+++ b/sqlglot/dialects/presto.py
@@ -8,7 +8,6 @@ from sqlglot.dialects.dialect import (
no_ilike_sql,
no_safe_divide_sql,
rename_func,
- str_position_sql,
struct_extract_sql,
timestrtotime_sql,
)
@@ -24,14 +23,6 @@ def _approx_distinct_sql(self, expression):
return f"APPROX_DISTINCT({self.sql(expression, 'this')}{accuracy})"
-def _concat_ws_sql(self, expression):
- sep, *args = expression.expressions
- sep = self.sql(sep)
- if len(args) > 1:
- return f"ARRAY_JOIN(ARRAY[{self.format_args(*args)}], {sep})"
- return f"ARRAY_JOIN({self.sql(args[0])}, {sep})"
-
-
def _datatype_sql(self, expression):
sql = self.datatype_sql(expression)
if expression.this == exp.DataType.Type.TIMESTAMPTZ:
@@ -61,7 +52,7 @@ def _initcap_sql(self, expression):
def _decode_sql(self, expression):
_ensure_utf8(expression.args.get("charset"))
- return f"FROM_UTF8({self.sql(expression, 'this')})"
+ return f"FROM_UTF8({self.format_args(expression.this, expression.args.get('replace'))})"
def _encode_sql(self, expression):
@@ -119,6 +110,38 @@ def _ensure_utf8(charset):
raise UnsupportedError(f"Unsupported charset {charset}")
+def _approx_percentile(args):
+ if len(args) == 4:
+ return exp.ApproxQuantile(
+ this=seq_get(args, 0),
+ weight=seq_get(args, 1),
+ quantile=seq_get(args, 2),
+ accuracy=seq_get(args, 3),
+ )
+ if len(args) == 3:
+ return exp.ApproxQuantile(
+ this=seq_get(args, 0),
+ quantile=seq_get(args, 1),
+ accuracy=seq_get(args, 2),
+ )
+ return exp.ApproxQuantile.from_arg_list(args)
+
+
+def _from_unixtime(args):
+ if len(args) == 3:
+ return exp.UnixToTime(
+ this=seq_get(args, 0),
+ hours=seq_get(args, 1),
+ minutes=seq_get(args, 2),
+ )
+ if len(args) == 2:
+ return exp.UnixToTime(
+ this=seq_get(args, 0),
+ zone=seq_get(args, 1),
+ )
+ return exp.UnixToTime.from_arg_list(args)
+
+
class Presto(Dialect):
index_offset = 1
null_ordering = "nulls_are_last"
@@ -150,19 +173,25 @@ class Presto(Dialect):
),
"DATE_FORMAT": format_time_lambda(exp.TimeToStr, "presto"),
"DATE_PARSE": format_time_lambda(exp.StrToTime, "presto"),
- "FROM_UNIXTIME": exp.UnixToTime.from_arg_list,
- "STRPOS": exp.StrPosition.from_arg_list,
+ "FROM_UNIXTIME": _from_unixtime,
+ "STRPOS": lambda args: exp.StrPosition(
+ this=seq_get(args, 0),
+ substr=seq_get(args, 1),
+ instance=seq_get(args, 2),
+ ),
"TO_UNIXTIME": exp.TimeToUnix.from_arg_list,
- "APPROX_PERCENTILE": exp.ApproxQuantile.from_arg_list,
+ "APPROX_PERCENTILE": _approx_percentile,
"FROM_HEX": exp.Unhex.from_arg_list,
"TO_HEX": exp.Hex.from_arg_list,
"TO_UTF8": lambda args: exp.Encode(
this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
),
"FROM_UTF8": lambda args: exp.Decode(
- this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
+ this=seq_get(args, 0), replace=seq_get(args, 1), charset=exp.Literal.string("utf-8")
),
}
+ FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy()
+ FUNCTION_PARSERS.pop("TRIM")
class Generator(generator.Generator):
@@ -194,7 +223,6 @@ class Presto(Dialect):
exp.BitwiseOr: lambda self, e: f"BITWISE_OR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.BitwiseRightShift: lambda self, e: f"BITWISE_ARITHMETIC_SHIFT_RIGHT({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
exp.BitwiseXor: lambda self, e: f"BITWISE_XOR({self.sql(e, 'this')}, {self.sql(e, 'expression')})",
- exp.ConcatWs: _concat_ws_sql,
exp.DataType: _datatype_sql,
exp.DateAdd: lambda self, e: f"""DATE_ADD({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""",
exp.DateDiff: lambda self, e: f"""DATE_DIFF({self.sql(e, 'unit') or "'day'"}, {self.sql(e, 'expression')}, {self.sql(e, 'this')})""",
@@ -209,12 +237,13 @@ class Presto(Dialect):
exp.Initcap: _initcap_sql,
exp.Lateral: _explode_to_unnest_sql,
exp.Levenshtein: rename_func("LEVENSHTEIN_DISTANCE"),
+ exp.LogicalOr: rename_func("BOOL_OR"),
exp.Quantile: _quantile_sql,
exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
exp.SafeDivide: no_safe_divide_sql,
exp.Schema: _schema_sql,
exp.SortArray: _no_sort_array,
- exp.StrPosition: str_position_sql,
+ exp.StrPosition: rename_func("STRPOS"),
exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)",
exp.StrToTime: _str_to_time_sql,
exp.StrToUnix: lambda self, e: f"TO_UNIXTIME(DATE_PARSE({self.sql(e, 'this')}, {self.format_time(e)}))",
@@ -233,6 +262,7 @@ class Presto(Dialect):
exp.UnixToStr: lambda self, e: f"DATE_FORMAT(FROM_UNIXTIME({self.sql(e, 'this')}), {self.format_time(e)})",
exp.UnixToTime: rename_func("FROM_UNIXTIME"),
exp.UnixToTimeStr: lambda self, e: f"CAST(FROM_UNIXTIME({self.sql(e, 'this')}) AS VARCHAR)",
+ exp.VariancePop: rename_func("VAR_POP"),
}
def transaction_sql(self, expression):
diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py
index 27dfb93..afd7913 100644
--- a/sqlglot/dialects/redshift.py
+++ b/sqlglot/dialects/redshift.py
@@ -1,5 +1,7 @@
from __future__ import annotations
+import typing as t
+
from sqlglot import exp, transforms
from sqlglot.dialects.dialect import rename_func
from sqlglot.dialects.postgres import Postgres
@@ -21,6 +23,19 @@ class Redshift(Postgres):
"NVL": exp.Coalesce.from_arg_list,
}
+ def _parse_types(self, check_func: bool = False) -> t.Optional[exp.Expression]:
+ this = super()._parse_types(check_func=check_func)
+
+ if (
+ isinstance(this, exp.DataType)
+ and this.this == exp.DataType.Type.VARCHAR
+ and this.expressions
+ and this.expressions[0] == exp.column("MAX")
+ ):
+ this.set("expressions", [exp.Var(this="MAX")])
+
+ return this
+
class Tokenizer(Postgres.Tokenizer):
ESCAPES = ["\\"]
@@ -52,6 +67,10 @@ class Redshift(Postgres):
exp.DistStyleProperty,
}
+ WITH_PROPERTIES = {
+ exp.LikeProperty,
+ }
+
TRANSFORMS = {
**Postgres.Generator.TRANSFORMS, # type: ignore
**transforms.ELIMINATE_DISTINCT_ON, # type: ignore
@@ -60,3 +79,57 @@ class Redshift(Postgres):
exp.DistStyleProperty: lambda self, e: self.naked_property(e),
exp.Matches: rename_func("DECODE"),
}
+
+ def values_sql(self, expression: exp.Values) -> str:
+ """
+ Converts `VALUES...` expression into a series of unions.
+
+ Note: If you have a lot of unions then this will result in a large number of recursive statements to
+ evaluate the expression. You may need to increase `sys.setrecursionlimit` to run and it can also be
+ very slow.
+ """
+ if not isinstance(expression.unnest().parent, exp.From):
+ return super().values_sql(expression)
+ rows = [tuple_exp.expressions for tuple_exp in expression.expressions]
+ selects = []
+ for i, row in enumerate(rows):
+ if i == 0:
+ row = [
+ exp.alias_(value, column_name)
+ for value, column_name in zip(row, expression.args["alias"].args["columns"])
+ ]
+ selects.append(exp.Select(expressions=row))
+ subquery_expression = selects[0]
+ if len(selects) > 1:
+ for select in selects[1:]:
+ subquery_expression = exp.union(subquery_expression, select, distinct=False)
+ return self.subquery_sql(subquery_expression.subquery(expression.alias))
+
+ def with_properties(self, properties: exp.Properties) -> str:
+ """Redshift doesn't have `WITH` as part of their with_properties so we remove it"""
+ return self.properties(properties, prefix=" ", suffix="")
+
+ def renametable_sql(self, expression: exp.RenameTable) -> str:
+ """Redshift only supports defining the table name itself (not the db) when renaming tables"""
+ expression = expression.copy()
+ target_table = expression.this
+ for arg in target_table.args:
+ if arg != "this":
+ target_table.set(arg, None)
+ this = self.sql(expression, "this")
+ return f"RENAME TO {this}"
+
+ def datatype_sql(self, expression: exp.DataType) -> str:
+ """
+ Redshift converts the `TEXT` data type to `VARCHAR(255)` by default when people more generally mean
+ VARCHAR of max length which is `VARCHAR(max)` in Redshift. Therefore if we get a `TEXT` data type
+ without precision we convert it to `VARCHAR(max)` and if it does have precision then we just convert
+ `TEXT` to `VARCHAR`.
+ """
+ if expression.this == exp.DataType.Type.TEXT:
+ expression = expression.copy()
+ expression.set("this", exp.DataType.Type.VARCHAR)
+ precision = expression.args.get("expressions")
+ if not precision:
+ expression.append("expressions", exp.Var(this="MAX"))
+ return super().datatype_sql(expression)
diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py
index 24d3bdf..c44950a 100644
--- a/sqlglot/dialects/snowflake.py
+++ b/sqlglot/dialects/snowflake.py
@@ -210,6 +210,7 @@ class Snowflake(Dialect):
**generator.Generator.TRANSFORMS, # type: ignore
exp.Array: inline_array_sql,
exp.ArrayConcat: rename_func("ARRAY_CAT"),
+ exp.DateAdd: rename_func("DATEADD"),
exp.DateStrToDate: datestrtodate_sql,
exp.DataType: _datatype_sql,
exp.If: rename_func("IFF"),
@@ -218,7 +219,7 @@ class Snowflake(Dialect):
exp.Parameter: lambda self, e: f"${self.sql(e, 'this')}",
exp.PartitionedByProperty: lambda self, e: f"PARTITION BY {self.sql(e, 'this')}",
exp.Matches: rename_func("DECODE"),
- exp.StrPosition: rename_func("POSITION"),
+ exp.StrPosition: lambda self, e: f"{self.normalize_func('POSITION')}({self.format_args(e.args.get('substr'), e.this, e.args.get('position'))})",
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
diff --git a/sqlglot/dialects/spark.py b/sqlglot/dialects/spark.py
index 7f05dea..42d34c2 100644
--- a/sqlglot/dialects/spark.py
+++ b/sqlglot/dialects/spark.py
@@ -124,6 +124,7 @@ class Spark(Hive):
exp.TimestampTrunc: lambda self, e: f"DATE_TRUNC({self.sql(e, 'unit')}, {self.sql(e, 'this')})",
exp.VariancePop: rename_func("VAR_POP"),
exp.DateFromParts: rename_func("MAKE_DATE"),
+ exp.LogicalOr: rename_func("BOOL_OR"),
}
TRANSFORMS.pop(exp.ArraySort)
TRANSFORMS.pop(exp.ILike)
diff --git a/sqlglot/dialects/sqlite.py b/sqlglot/dialects/sqlite.py
index a0c4942..1b39449 100644
--- a/sqlglot/dialects/sqlite.py
+++ b/sqlglot/dialects/sqlite.py
@@ -13,6 +13,10 @@ from sqlglot.dialects.dialect import (
from sqlglot.tokens import TokenType
+def _fetch_sql(self, expression):
+ return self.limit_sql(exp.Limit(expression=expression.args.get("count")))
+
+
# https://www.sqlite.org/lang_aggfunc.html#group_concat
def _group_concat_sql(self, expression):
this = expression.this
@@ -30,6 +34,14 @@ def _group_concat_sql(self, expression):
return f"GROUP_CONCAT({distinct or ''}{self.format_args(this, separator)})"
+def _date_add_sql(self, expression):
+ modifier = expression.expression
+ modifier = expression.name if modifier.is_string else self.sql(modifier)
+ unit = expression.args.get("unit")
+ modifier = f"'{modifier} {unit.name}'" if unit else f"'{modifier}'"
+ return f"{self.normalize_func('DATE')}({self.format_args(expression.this, modifier)})"
+
+
class SQLite(Dialect):
class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = ['"', ("[", "]"), "`"]
@@ -71,6 +83,7 @@ class SQLite(Dialect):
TRANSFORMS = {
**generator.Generator.TRANSFORMS, # type: ignore
+ exp.DateAdd: _date_add_sql,
exp.ILike: no_ilike_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
@@ -78,8 +91,11 @@ class SQLite(Dialect):
exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
exp.Levenshtein: rename_func("EDITDIST3"),
exp.TableSample: no_tablesample_sql,
+ exp.DateStrToDate: lambda self, e: self.sql(e, "this"),
+ exp.TimeStrToTime: lambda self, e: self.sql(e, "this"),
exp.TryCast: no_trycast_sql,
exp.GroupConcat: _group_concat_sql,
+ exp.Fetch: _fetch_sql,
}
def transaction_sql(self, expression):
diff --git a/sqlglot/dialects/teradata.py b/sqlglot/dialects/teradata.py
new file mode 100644
index 0000000..4340820
--- /dev/null
+++ b/sqlglot/dialects/teradata.py
@@ -0,0 +1,87 @@
+from __future__ import annotations
+
+from sqlglot import exp, generator, parser
+from sqlglot.dialects.dialect import Dialect
+from sqlglot.tokens import TokenType
+
+
+class Teradata(Dialect):
+ class Parser(parser.Parser):
+ CHARSET_TRANSLATORS = {
+ "GRAPHIC_TO_KANJISJIS",
+ "GRAPHIC_TO_LATIN",
+ "GRAPHIC_TO_UNICODE",
+ "GRAPHIC_TO_UNICODE_PadSpace",
+ "KANJI1_KanjiEBCDIC_TO_UNICODE",
+ "KANJI1_KanjiEUC_TO_UNICODE",
+ "KANJI1_KANJISJIS_TO_UNICODE",
+ "KANJI1_SBC_TO_UNICODE",
+ "KANJISJIS_TO_GRAPHIC",
+ "KANJISJIS_TO_LATIN",
+ "KANJISJIS_TO_UNICODE",
+ "LATIN_TO_GRAPHIC",
+ "LATIN_TO_KANJISJIS",
+ "LATIN_TO_UNICODE",
+ "LOCALE_TO_UNICODE",
+ "UNICODE_TO_GRAPHIC",
+ "UNICODE_TO_GRAPHIC_PadGraphic",
+ "UNICODE_TO_GRAPHIC_VarGraphic",
+ "UNICODE_TO_KANJI1_KanjiEBCDIC",
+ "UNICODE_TO_KANJI1_KanjiEUC",
+ "UNICODE_TO_KANJI1_KANJISJIS",
+ "UNICODE_TO_KANJI1_SBC",
+ "UNICODE_TO_KANJISJIS",
+ "UNICODE_TO_LATIN",
+ "UNICODE_TO_LOCALE",
+ "UNICODE_TO_UNICODE_FoldSpace",
+ "UNICODE_TO_UNICODE_Fullwidth",
+ "UNICODE_TO_UNICODE_Halfwidth",
+ "UNICODE_TO_UNICODE_NFC",
+ "UNICODE_TO_UNICODE_NFD",
+ "UNICODE_TO_UNICODE_NFKC",
+ "UNICODE_TO_UNICODE_NFKD",
+ }
+
+ FUNCTION_PARSERS = {
+ **parser.Parser.FUNCTION_PARSERS, # type: ignore
+ "TRANSLATE": lambda self: self._parse_translate(self.STRICT_CAST),
+ }
+
+ def _parse_translate(self, strict: bool) -> exp.Expression:
+ this = self._parse_conjunction()
+
+ if not self._match(TokenType.USING):
+ self.raise_error("Expected USING in TRANSLATE")
+
+ if self._match_texts(self.CHARSET_TRANSLATORS):
+ charset_split = self._prev.text.split("_TO_")
+ to = self.expression(exp.CharacterSet, this=charset_split[1])
+ else:
+ self.raise_error("Expected a character set translator after USING in TRANSLATE")
+
+ return self.expression(exp.Cast if strict else exp.TryCast, this=this, to=to)
+
+ # FROM before SET in Teradata UPDATE syntax
+ # https://docs.teradata.com/r/Enterprise_IntelliFlex_VMware/Teradata-VantageTM-SQL-Data-Manipulation-Language-17.20/Statement-Syntax/UPDATE/UPDATE-Syntax-Basic-Form-FROM-Clause
+ def _parse_update(self) -> exp.Expression:
+ return self.expression(
+ exp.Update,
+ **{ # type: ignore
+ "this": self._parse_table(alias_tokens=self.UPDATE_ALIAS_TOKENS),
+ "from": self._parse_from(),
+ "expressions": self._match(TokenType.SET)
+ and self._parse_csv(self._parse_equality),
+ "where": self._parse_where(),
+ },
+ )
+
+ class Generator(generator.Generator):
+ # FROM before SET in Teradata UPDATE syntax
+ # https://docs.teradata.com/r/Enterprise_IntelliFlex_VMware/Teradata-VantageTM-SQL-Data-Manipulation-Language-17.20/Statement-Syntax/UPDATE/UPDATE-Syntax-Basic-Form-FROM-Clause
+ def update_sql(self, expression: exp.Update) -> str:
+ this = self.sql(expression, "this")
+ from_sql = self.sql(expression, "from")
+ set_sql = self.expressions(expression, flat=True)
+ where_sql = self.sql(expression, "where")
+ sql = f"UPDATE {this}{from_sql} SET {set_sql}{where_sql}"
+ return self.prepend_ctes(expression, sql)
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index 465f534..9342e6b 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -243,28 +243,34 @@ class TSQL(Dialect):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
"BIT": TokenType.BOOLEAN,
- "REAL": TokenType.FLOAT,
- "NTEXT": TokenType.TEXT,
- "SMALLDATETIME": TokenType.DATETIME,
"DATETIME2": TokenType.DATETIME,
"DATETIMEOFFSET": TokenType.TIMESTAMPTZ,
- "TIME": TokenType.TIMESTAMP,
+ "DECLARE": TokenType.COMMAND,
"IMAGE": TokenType.IMAGE,
"MONEY": TokenType.MONEY,
- "SMALLMONEY": TokenType.SMALLMONEY,
+ "NTEXT": TokenType.TEXT,
+ "NVARCHAR(MAX)": TokenType.TEXT,
+ "PRINT": TokenType.COMMAND,
+ "REAL": TokenType.FLOAT,
"ROWVERSION": TokenType.ROWVERSION,
- "UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER,
- "XML": TokenType.XML,
+ "SMALLDATETIME": TokenType.DATETIME,
+ "SMALLMONEY": TokenType.SMALLMONEY,
"SQL_VARIANT": TokenType.VARIANT,
- "NVARCHAR(MAX)": TokenType.TEXT,
- "VARCHAR(MAX)": TokenType.TEXT,
+ "TIME": TokenType.TIMESTAMP,
"TOP": TokenType.TOP,
+ "UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER,
+ "VARCHAR(MAX)": TokenType.TEXT,
+ "XML": TokenType.XML,
}
class Parser(parser.Parser):
FUNCTIONS = {
**parser.Parser.FUNCTIONS, # type: ignore
- "CHARINDEX": exp.StrPosition.from_arg_list,
+ "CHARINDEX": lambda args: exp.StrPosition(
+ this=seq_get(args, 1),
+ substr=seq_get(args, 0),
+ position=seq_get(args, 2),
+ ),
"ISNULL": exp.Coalesce.from_arg_list,
"DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
"DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
@@ -288,7 +294,7 @@ class TSQL(Dialect):
}
# https://learn.microsoft.com/en-us/azure/synapse-analytics/sql-data-warehouse/sql-data-warehouse-tables-temporary#create-a-temporary-table
- TABLE_PREFIX_TOKENS = {TokenType.HASH}
+ TABLE_PREFIX_TOKENS = {TokenType.HASH, TokenType.PARAMETER}
def _parse_convert(self, strict):
to = self._parse_types()
diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py
index d093e29..be99fe2 100644
--- a/sqlglot/expressions.py
+++ b/sqlglot/expressions.py
@@ -653,6 +653,7 @@ class Create(Expression):
"statistics": False,
"no_primary_index": False,
"indexes": False,
+ "no_schema_binding": False,
}
@@ -770,6 +771,10 @@ class AlterColumn(Expression):
}
+class RenameTable(Expression):
+ pass
+
+
class ColumnConstraint(Expression):
arg_types = {"this": False, "kind": True}
@@ -804,7 +809,7 @@ class EncodeColumnConstraint(ColumnConstraintKind):
class GeneratedAsIdentityColumnConstraint(ColumnConstraintKind):
# this: True -> ALWAYS, this: False -> BY DEFAULT
- arg_types = {"this": True, "expression": False}
+ arg_types = {"this": True, "start": False, "increment": False}
class NotNullColumnConstraint(ColumnConstraintKind):
@@ -1266,7 +1271,7 @@ class Tuple(Expression):
class Subqueryable(Unionable):
- def subquery(self, alias=None, copy=True):
+ def subquery(self, alias=None, copy=True) -> Subquery:
"""
Convert this expression to an aliased expression that can be used as a Subquery.
@@ -1460,6 +1465,7 @@ class Unnest(UDTF):
"expressions": True,
"ordinality": False,
"alias": False,
+ "offset": False,
}
@@ -2126,6 +2132,7 @@ class DataType(Expression):
"this": True,
"expressions": False,
"nested": False,
+ "values": False,
}
class Type(AutoName):
@@ -2134,6 +2141,8 @@ class DataType(Expression):
VARCHAR = auto()
NVARCHAR = auto()
TEXT = auto()
+ MEDIUMTEXT = auto()
+ LONGTEXT = auto()
BINARY = auto()
VARBINARY = auto()
INT = auto()
@@ -2791,7 +2800,7 @@ class Day(Func):
class Decode(Func):
- arg_types = {"this": True, "charset": True}
+ arg_types = {"this": True, "charset": True, "replace": False}
class DiToDate(Func):
@@ -2815,7 +2824,7 @@ class Floor(Func):
class Greatest(Func):
- arg_types = {"this": True, "expressions": True}
+ arg_types = {"this": True, "expressions": False}
is_var_len_args = True
@@ -2861,7 +2870,7 @@ class JSONBExtractScalar(JSONExtract):
class Least(Func):
- arg_types = {"this": True, "expressions": True}
+ arg_types = {"this": True, "expressions": False}
is_var_len_args = True
@@ -2904,7 +2913,7 @@ class Lower(Func):
class Map(Func):
- arg_types = {"keys": True, "values": True}
+ arg_types = {"keys": False, "values": False}
class VarMap(Func):
@@ -2923,11 +2932,11 @@ class Matches(Func):
class Max(AggFunc):
- pass
+ arg_types = {"this": True, "expression": False}
class Min(AggFunc):
- pass
+ arg_types = {"this": True, "expression": False}
class Month(Func):
@@ -2962,7 +2971,7 @@ class QuantileIf(AggFunc):
class ApproxQuantile(Quantile):
- arg_types = {"this": True, "quantile": True, "accuracy": False}
+ arg_types = {"this": True, "quantile": True, "accuracy": False, "weight": False}
class ReadCSV(Func):
@@ -3022,7 +3031,12 @@ class Substring(Func):
class StrPosition(Func):
- arg_types = {"substr": True, "this": True, "position": False}
+ arg_types = {
+ "this": True,
+ "substr": True,
+ "position": False,
+ "instance": False,
+ }
class StrToDate(Func):
@@ -3129,8 +3143,10 @@ class UnixToStr(Func):
arg_types = {"this": True, "format": False}
+# https://prestodb.io/docs/current/functions/datetime.html
+# presto has weird zone/hours/minutes
class UnixToTime(Func):
- arg_types = {"this": True, "scale": False}
+ arg_types = {"this": True, "scale": False, "zone": False, "hours": False, "minutes": False}
SECONDS = Literal.string("seconds")
MILLIS = Literal.string("millis")
@@ -3684,6 +3700,16 @@ def to_identifier(alias, quoted=None) -> t.Optional[Identifier]:
return identifier
+@t.overload
+def to_table(sql_path: str | Table, **kwargs) -> Table:
+ ...
+
+
+@t.overload
+def to_table(sql_path: None, **kwargs) -> None:
+ ...
+
+
def to_table(sql_path: t.Optional[str | Table], **kwargs) -> t.Optional[Table]:
"""
Create a table expression from a `[catalog].[schema].[table]` sql path. Catalog and schema are optional.
@@ -3860,6 +3886,26 @@ def values(
)
+def rename_table(old_name: str | Table, new_name: str | Table) -> AlterTable:
+ """Build ALTER TABLE... RENAME... expression
+
+ Args:
+ old_name: The old name of the table
+ new_name: The new name of the table
+
+ Returns:
+ Alter table expression
+ """
+ old_table = to_table(old_name)
+ new_table = to_table(new_name)
+ return AlterTable(
+ this=old_table,
+ actions=[
+ RenameTable(this=new_table),
+ ],
+ )
+
+
def convert(value) -> Expression:
"""Convert a python value into an expression object.
diff --git a/sqlglot/generator.py b/sqlglot/generator.py
index 3935133..6375d92 100644
--- a/sqlglot/generator.py
+++ b/sqlglot/generator.py
@@ -82,6 +82,8 @@ class Generator:
TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
+ exp.DataType.Type.MEDIUMTEXT: "TEXT",
+ exp.DataType.Type.LONGTEXT: "TEXT",
}
TOKEN_MAPPING: t.Dict[TokenType, str] = {}
@@ -105,6 +107,7 @@ class Generator:
}
WITH_SEPARATED_COMMENTS = (exp.Select, exp.From, exp.Where, exp.Binary)
+ SENTINEL_LINE_BREAK = "__SQLGLOT__LB__"
__slots__ = (
"time_mapping",
@@ -211,6 +214,8 @@ class Generator:
elif self.unsupported_level == ErrorLevel.RAISE and self.unsupported_messages:
raise UnsupportedError(concat_messages(self.unsupported_messages, self.max_unsupported))
+ if self.pretty:
+ sql = sql.replace(self.SENTINEL_LINE_BREAK, "\n")
return sql
def unsupported(self, message: str) -> None:
@@ -401,7 +406,17 @@ class Generator:
def generatedasidentitycolumnconstraint_sql(
self, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
- return f"GENERATED {'ALWAYS' if expression.this else 'BY DEFAULT'} AS IDENTITY"
+ start = expression.args.get("start")
+ start = f"START WITH {start}" if start else ""
+ increment = expression.args.get("increment")
+ increment = f"INCREMENT BY {increment}" if increment else ""
+ sequence_opts = ""
+ if start or increment:
+ sequence_opts = f"{start} {increment}"
+ sequence_opts = f" ({sequence_opts.strip()})"
+ return (
+ f"GENERATED {'ALWAYS' if expression.this else 'BY DEFAULT'} AS IDENTITY{sequence_opts}"
+ )
def notnullcolumnconstraint_sql(self, expression: exp.NotNullColumnConstraint) -> str:
return f"{'' if expression.args.get('allow_null') else 'NOT '}NULL"
@@ -475,10 +490,13 @@ class Generator:
materialized,
)
)
+ no_schema_binding = (
+ " WITH NO SCHEMA BINDING" if expression.args.get("no_schema_binding") else ""
+ )
post_expression_modifiers = "".join((data, statistics, no_primary_index))
- expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties}{expression_sql}{post_expression_modifiers}{index_sql}"
+ expression_sql = f"CREATE{modifiers} {kind}{exists_sql} {this}{properties}{expression_sql}{post_expression_modifiers}{index_sql}{no_schema_binding}"
return self.prepend_ctes(expression, expression_sql)
def describe_sql(self, expression: exp.Describe) -> str:
@@ -517,13 +535,19 @@ class Generator:
type_sql = self.TYPE_MAPPING.get(type_value, type_value.value)
nested = ""
interior = self.expressions(expression, flat=True)
+ values = ""
if interior:
- nested = (
- f"{self.STRUCT_DELIMITER[0]}{interior}{self.STRUCT_DELIMITER[1]}"
- if expression.args.get("nested")
- else f"({interior})"
- )
- return f"{type_sql}{nested}"
+ if expression.args.get("nested"):
+ nested = f"{self.STRUCT_DELIMITER[0]}{interior}{self.STRUCT_DELIMITER[1]}"
+ if expression.args.get("values") is not None:
+ delimiters = ("[", "]") if type_value == exp.DataType.Type.ARRAY else ("(", ")")
+ values = (
+ f"{delimiters[0]}{self.expressions(expression, 'values')}{delimiters[1]}"
+ )
+ else:
+ nested = f"({interior})"
+
+ return f"{type_sql}{nested}{values}"
def directory_sql(self, expression: exp.Directory) -> str:
local = "LOCAL " if expression.args.get("local") else ""
@@ -622,10 +646,14 @@ class Generator:
return self.sep() + self.expressions(properties, indent=False, sep=" ")
return ""
- def properties(self, properties: exp.Properties, prefix: str = "", sep: str = ", ") -> str:
+ def properties(
+ self, properties: exp.Properties, prefix: str = "", sep: str = ", ", suffix: str = ""
+ ) -> str:
if properties.expressions:
expressions = self.expressions(properties, sep=sep, indent=False)
- return f"{prefix}{' ' if prefix else ''}{self.wrap(expressions)}"
+ return (
+ f"{prefix}{' ' if prefix and prefix != ' ' else ''}{self.wrap(expressions)}{suffix}"
+ )
return ""
def with_properties(self, properties: exp.Properties) -> str:
@@ -763,14 +791,15 @@ class Generator:
return self.prepend_ctes(expression, sql)
def values_sql(self, expression: exp.Values) -> str:
- alias = self.sql(expression, "alias")
args = self.expressions(expression)
- if not alias:
- return f"VALUES{self.seg('')}{args}"
- alias = f" AS {alias}" if alias else alias
- if self.WRAP_DERIVED_VALUES:
- return f"(VALUES{self.seg('')}{args}){alias}"
- return f"VALUES{self.seg('')}{args}{alias}"
+ alias = self.sql(expression, "alias")
+ values = f"VALUES{self.seg('')}{args}"
+ values = (
+ f"({values})"
+ if self.WRAP_DERIVED_VALUES and (alias or isinstance(expression.parent, exp.From))
+ else values
+ )
+ return f"{values} AS {alias}" if alias else values
def var_sql(self, expression: exp.Var) -> str:
return self.sql(expression, "this")
@@ -868,6 +897,8 @@ class Generator:
if self._replace_backslash:
text = text.replace("\\", "\\\\")
text = text.replace(self.quote_end, self._escaped_quote_end)
+ if self.pretty:
+ text = text.replace("\n", self.SENTINEL_LINE_BREAK)
text = f"{self.quote_start}{text}{self.quote_end}"
return text
@@ -1036,7 +1067,9 @@ class Generator:
alias = self.sql(expression, "alias")
alias = f" AS {alias}" if alias else alias
ordinality = " WITH ORDINALITY" if expression.args.get("ordinality") else ""
- return f"UNNEST({args}){ordinality}{alias}"
+ offset = expression.args.get("offset")
+ offset = f" WITH OFFSET AS {self.sql(offset)}" if offset else ""
+ return f"UNNEST({args}){ordinality}{alias}{offset}"
def where_sql(self, expression: exp.Where) -> str:
this = self.indent(self.sql(expression, "this"))
@@ -1132,15 +1165,14 @@ class Generator:
return f"EXTRACT({this} FROM {expression_sql})"
def trim_sql(self, expression: exp.Trim) -> str:
- target = self.sql(expression, "this")
trim_type = self.sql(expression, "position")
if trim_type == "LEADING":
- return f"LTRIM({target})"
+ return f"{self.normalize_func('LTRIM')}({self.format_args(expression.this)})"
elif trim_type == "TRAILING":
- return f"RTRIM({target})"
+ return f"{self.normalize_func('RTRIM')}({self.format_args(expression.this)})"
else:
- return f"TRIM({target})"
+ return f"{self.normalize_func('TRIM')}({self.format_args(expression.this, expression.expression)})"
def concat_sql(self, expression: exp.Concat) -> str:
if len(expression.expressions) == 1:
@@ -1317,6 +1349,10 @@ class Generator:
return f"ALTER COLUMN {this} DROP DEFAULT"
+ def renametable_sql(self, expression: exp.RenameTable) -> str:
+ this = self.sql(expression, "this")
+ return f"RENAME TO {this}"
+
def altertable_sql(self, expression: exp.AlterTable) -> str:
actions = expression.args["actions"]
@@ -1326,7 +1362,7 @@ class Generator:
actions = self.expressions(expression, "actions", prefix="ADD COLUMNS ")
elif isinstance(actions[0], exp.Drop):
actions = self.expressions(expression, "actions")
- elif isinstance(actions[0], exp.AlterColumn):
+ elif isinstance(actions[0], (exp.AlterColumn, exp.RenameTable)):
actions = self.sql(actions[0])
else:
self.unsupported(f"Unsupported ALTER TABLE action {actions[0].__class__.__name__}")
diff --git a/sqlglot/optimizer/optimizer.py b/sqlglot/optimizer/optimizer.py
index 46b6b30..5258c2b 100644
--- a/sqlglot/optimizer/optimizer.py
+++ b/sqlglot/optimizer/optimizer.py
@@ -52,7 +52,10 @@ def optimize(expression, schema=None, db=None, catalog=None, rules=RULES, **kwar
If no schema is provided then the default schema defined at `sqlgot.schema` will be used
db (str): specify the default database, as might be set by a `USE DATABASE db` statement
catalog (str): specify the default catalog, as might be set by a `USE CATALOG c` statement
- rules (sequence): sequence of optimizer rules to use
+ rules (sequence): sequence of optimizer rules to use.
+ Many of the rules require tables and columns to be qualified.
+ Do not remove qualify_tables or qualify_columns from the sequence of rules unless you know
+ what you're doing!
**kwargs: If a rule has a keyword argument with a same name in **kwargs, it will be passed in.
Returns:
sqlglot.Expression: optimized expression
diff --git a/sqlglot/optimizer/qualify_columns.py b/sqlglot/optimizer/qualify_columns.py
index f4568c2..8da4e43 100644
--- a/sqlglot/optimizer/qualify_columns.py
+++ b/sqlglot/optimizer/qualify_columns.py
@@ -1,7 +1,7 @@
import itertools
from sqlglot import alias, exp
-from sqlglot.errors import OptimizeError
+from sqlglot.errors import OptimizeError, SchemaError
from sqlglot.optimizer.scope import Scope, traverse_scope
from sqlglot.schema import ensure_schema
@@ -382,7 +382,7 @@ class _Resolver:
try:
return self.schema.column_names(source, only_visible)
except Exception as e:
- raise OptimizeError(str(e)) from e
+ raise SchemaError(str(e)) from e
if isinstance(source, Scope) and isinstance(source.expression, exp.Values):
return source.expression.alias_column_names
diff --git a/sqlglot/parser.py b/sqlglot/parser.py
index bd95db8..c97b19a 100644
--- a/sqlglot/parser.py
+++ b/sqlglot/parser.py
@@ -107,6 +107,8 @@ class Parser(metaclass=_Parser):
TokenType.VARCHAR,
TokenType.NVARCHAR,
TokenType.TEXT,
+ TokenType.MEDIUMTEXT,
+ TokenType.LONGTEXT,
TokenType.BINARY,
TokenType.VARBINARY,
TokenType.JSON,
@@ -233,6 +235,7 @@ class Parser(metaclass=_Parser):
TokenType.UNPIVOT,
TokenType.PROPERTIES,
TokenType.PROCEDURE,
+ TokenType.VIEW,
TokenType.VOLATILE,
TokenType.WINDOW,
*SUBQUERY_PREDICATES,
@@ -252,6 +255,7 @@ class Parser(metaclass=_Parser):
TRIM_TYPES = {TokenType.LEADING, TokenType.TRAILING, TokenType.BOTH}
FUNC_TOKENS = {
+ TokenType.COMMAND,
TokenType.CURRENT_DATE,
TokenType.CURRENT_DATETIME,
TokenType.CURRENT_TIMESTAMP,
@@ -552,7 +556,7 @@ class Parser(metaclass=_Parser):
TokenType.IF: lambda self: self._parse_if(),
}
- FUNCTION_PARSERS = {
+ FUNCTION_PARSERS: t.Dict[str, t.Callable] = {
"CONVERT": lambda self: self._parse_convert(self.STRICT_CAST),
"TRY_CONVERT": lambda self: self._parse_convert(False),
"EXTRACT": lambda self: self._parse_extract(),
@@ -937,6 +941,7 @@ class Parser(metaclass=_Parser):
statistics = None
no_primary_index = None
indexes = None
+ no_schema_binding = None
if create_token.token_type in (TokenType.FUNCTION, TokenType.PROCEDURE):
this = self._parse_user_defined_function()
@@ -975,6 +980,9 @@ class Parser(metaclass=_Parser):
break
else:
indexes.append(index)
+ elif create_token.token_type == TokenType.VIEW:
+ if self._match_text_seq("WITH", "NO", "SCHEMA", "BINDING"):
+ no_schema_binding = True
return self.expression(
exp.Create,
@@ -993,6 +1001,7 @@ class Parser(metaclass=_Parser):
statistics=statistics,
no_primary_index=no_primary_index,
indexes=indexes,
+ no_schema_binding=no_schema_binding,
)
def _parse_property(self) -> t.Optional[exp.Expression]:
@@ -1246,8 +1255,14 @@ class Parser(metaclass=_Parser):
return self.expression(exp.Partition, this=self._parse_wrapped_csv(parse_values))
def _parse_value(self) -> exp.Expression:
- expressions = self._parse_wrapped_csv(self._parse_conjunction)
- return self.expression(exp.Tuple, expressions=expressions)
+ if self._match(TokenType.L_PAREN):
+ expressions = self._parse_csv(self._parse_conjunction)
+ self._match_r_paren()
+ return self.expression(exp.Tuple, expressions=expressions)
+
+ # In presto we can have VALUES 1, 2 which results in 1 column & 2 rows.
+ # Source: https://prestodb.io/docs/current/sql/values.html
+ return self.expression(exp.Tuple, expressions=[self._parse_conjunction()])
def _parse_select(
self, nested: bool = False, table: bool = False, parse_subquery_alias: bool = True
@@ -1313,19 +1328,9 @@ class Parser(metaclass=_Parser):
# Union ALL should be a property of the top select node, not the subquery
return self._parse_subquery(this, parse_alias=parse_subquery_alias)
elif self._match(TokenType.VALUES):
- if self._curr.token_type == TokenType.L_PAREN:
- # We don't consume the left paren because it's consumed in _parse_value
- expressions = self._parse_csv(self._parse_value)
- else:
- # In presto we can have VALUES 1, 2 which results in 1 column & 2 rows.
- # Source: https://prestodb.io/docs/current/sql/values.html
- expressions = self._parse_csv(
- lambda: self.expression(exp.Tuple, expressions=[self._parse_conjunction()])
- )
-
this = self.expression(
exp.Values,
- expressions=expressions,
+ expressions=self._parse_csv(self._parse_value),
alias=self._parse_table_alias(),
)
else:
@@ -1612,13 +1617,12 @@ class Parser(metaclass=_Parser):
if alias:
this.set("alias", alias)
- if self._match(TokenType.WITH):
+ if self._match_pair(TokenType.WITH, TokenType.L_PAREN):
this.set(
"hints",
- self._parse_wrapped_csv(
- lambda: self._parse_function() or self._parse_var(any_token=True)
- ),
+ self._parse_csv(lambda: self._parse_function() or self._parse_var(any_token=True)),
)
+ self._match_r_paren()
if not self.alias_post_tablesample:
table_sample = self._parse_table_sample()
@@ -1643,8 +1647,17 @@ class Parser(metaclass=_Parser):
alias.set("columns", [alias.this])
alias.set("this", None)
+ offset = None
+ if self._match_pair(TokenType.WITH, TokenType.OFFSET):
+ self._match(TokenType.ALIAS)
+ offset = self._parse_conjunction()
+
return self.expression(
- exp.Unnest, expressions=expressions, ordinality=ordinality, alias=alias
+ exp.Unnest,
+ expressions=expressions,
+ ordinality=ordinality,
+ alias=alias,
+ offset=offset,
)
def _parse_derived_table_values(self) -> t.Optional[exp.Expression]:
@@ -1999,7 +2012,7 @@ class Parser(metaclass=_Parser):
this = self._parse_column()
if type_token:
- if this:
+ if this and not isinstance(this, exp.Star):
return self.expression(exp.Cast, this=this, to=type_token)
if not type_token.args.get("expressions"):
self._retreat(index)
@@ -2050,6 +2063,7 @@ class Parser(metaclass=_Parser):
self._retreat(index)
return None
+ values: t.Optional[t.List[t.Optional[exp.Expression]]] = None
if nested and self._match(TokenType.LT):
if is_struct:
expressions = self._parse_csv(self._parse_struct_kwargs)
@@ -2059,6 +2073,10 @@ class Parser(metaclass=_Parser):
if not self._match(TokenType.GT):
self.raise_error("Expecting >")
+ if self._match_set((TokenType.L_BRACKET, TokenType.L_PAREN)):
+ values = self._parse_csv(self._parse_conjunction)
+ self._match_set((TokenType.R_BRACKET, TokenType.R_PAREN))
+
value: t.Optional[exp.Expression] = None
if type_token in self.TIMESTAMPS:
if self._match(TokenType.WITH_TIME_ZONE) or type_token == TokenType.TIMESTAMPTZ:
@@ -2097,9 +2115,13 @@ class Parser(metaclass=_Parser):
this=exp.DataType.Type[type_token.value.upper()],
expressions=expressions,
nested=nested,
+ values=values,
)
def _parse_struct_kwargs(self) -> t.Optional[exp.Expression]:
+ if self._curr and self._curr.token_type in self.TYPE_TOKENS:
+ return self._parse_types()
+
this = self._parse_id_var()
self._match(TokenType.COLON)
data_type = self._parse_types()
@@ -2412,6 +2434,14 @@ class Parser(metaclass=_Parser):
self._match(TokenType.ALWAYS)
kind = self.expression(exp.GeneratedAsIdentityColumnConstraint, this=True)
self._match_pair(TokenType.ALIAS, TokenType.IDENTITY)
+
+ if self._match(TokenType.L_PAREN):
+ if self._match_text_seq("START", "WITH"):
+ kind.set("start", self._parse_bitwise())
+ if self._match_text_seq("INCREMENT", "BY"):
+ kind.set("increment", self._parse_bitwise())
+
+ self._match_r_paren()
else:
return this
@@ -2619,8 +2649,12 @@ class Parser(metaclass=_Parser):
if self._match(TokenType.IN):
args.append(self._parse_bitwise())
- # Note: we're parsing in order needle, haystack, position
- this = exp.StrPosition.from_arg_list(args)
+ this = exp.StrPosition(
+ this=seq_get(args, 1),
+ substr=seq_get(args, 0),
+ position=seq_get(args, 2),
+ )
+
self.validate_expression(this, args)
return this
@@ -2999,6 +3033,8 @@ class Parser(metaclass=_Parser):
actions = self._parse_csv(self._parse_add_column)
elif self._match_text_seq("DROP", advance=False):
actions = self._parse_csv(self._parse_drop_column)
+ elif self._match_text_seq("RENAME", "TO"):
+ actions = self.expression(exp.RenameTable, this=self._parse_table(schema=True))
elif self._match_text_seq("ALTER"):
self._match(TokenType.COLUMN)
column = self._parse_field(any_token=True)
diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py
index 8e312a7..f12528f 100644
--- a/sqlglot/tokens.py
+++ b/sqlglot/tokens.py
@@ -82,6 +82,8 @@ class TokenType(AutoName):
VARCHAR = auto()
NVARCHAR = auto()
TEXT = auto()
+ MEDIUMTEXT = auto()
+ LONGTEXT = auto()
BINARY = auto()
VARBINARY = auto()
JSON = auto()
@@ -434,6 +436,8 @@ class Tokenizer(metaclass=_Tokenizer):
ESCAPES = ["'"]
+ _ESCAPES: t.Set[str] = set()
+
KEYWORDS = {
**{
f"{key}{postfix}": TokenType.BLOCK_START
@@ -461,6 +465,7 @@ class Tokenizer(metaclass=_Tokenizer):
"#>>": TokenType.DHASH_ARROW,
"<->": TokenType.LR_ARROW,
"ALL": TokenType.ALL,
+ "ALWAYS": TokenType.ALWAYS,
"AND": TokenType.AND,
"ANTI": TokenType.ANTI,
"ANY": TokenType.ANY,
@@ -472,6 +477,7 @@ class Tokenizer(metaclass=_Tokenizer):
"BETWEEN": TokenType.BETWEEN,
"BOTH": TokenType.BOTH,
"BUCKET": TokenType.BUCKET,
+ "BY DEFAULT": TokenType.BY_DEFAULT,
"CACHE": TokenType.CACHE,
"UNCACHE": TokenType.UNCACHE,
"CASE": TokenType.CASE,
@@ -521,9 +527,11 @@ class Tokenizer(metaclass=_Tokenizer):
"FOREIGN KEY": TokenType.FOREIGN_KEY,
"FORMAT": TokenType.FORMAT,
"FROM": TokenType.FROM,
+ "GENERATED": TokenType.GENERATED,
"GROUP BY": TokenType.GROUP_BY,
"GROUPING SETS": TokenType.GROUPING_SETS,
"HAVING": TokenType.HAVING,
+ "IDENTITY": TokenType.IDENTITY,
"IF": TokenType.IF,
"ILIKE": TokenType.ILIKE,
"IMMUTABLE": TokenType.IMMUTABLE,
@@ -746,7 +754,7 @@ class Tokenizer(metaclass=_Tokenizer):
)
def __init__(self) -> None:
- self._replace_backslash = "\\" in self._ESCAPES # type: ignore
+ self._replace_backslash = "\\" in self._ESCAPES
self.reset()
def reset(self) -> None:
@@ -771,7 +779,10 @@ class Tokenizer(metaclass=_Tokenizer):
self.reset()
self.sql = sql
self.size = len(sql)
+ self._scan()
+ return self.tokens
+ def _scan(self, until: t.Optional[t.Callable] = None) -> None:
while self.size and not self._end:
self._start = self._current
self._advance()
@@ -792,7 +803,9 @@ class Tokenizer(metaclass=_Tokenizer):
self._scan_identifier(identifier_end)
else:
self._scan_keywords()
- return self.tokens
+
+ if until and until():
+ break
def _chars(self, size: int) -> str:
if size == 1:
@@ -832,11 +845,13 @@ class Tokenizer(metaclass=_Tokenizer):
if token_type in self.COMMANDS and (
len(self.tokens) == 1 or self.tokens[-2].token_type == TokenType.SEMICOLON
):
- self._start = self._current
- while not self._end and self._peek != ";":
- self._advance()
- if self._start < self._current:
- self._add(TokenType.STRING)
+ start = self._current
+ tokens = len(self.tokens)
+ self._scan(lambda: self._peek == ";")
+ self.tokens = self.tokens[:tokens]
+ text = self.sql[start : self._current].strip()
+ if text:
+ self._add(TokenType.STRING, text)
def _scan_keywords(self) -> None:
size = 0
@@ -947,7 +962,8 @@ class Tokenizer(metaclass=_Tokenizer):
elif self._peek.isidentifier(): # type: ignore
number_text = self._text
literal = []
- while self._peek.isidentifier(): # type: ignore
+
+ while self._peek.strip() and self._peek not in self.SINGLE_TOKENS: # type: ignore
literal.append(self._peek.upper()) # type: ignore
self._advance()
@@ -1063,8 +1079,12 @@ class Tokenizer(metaclass=_Tokenizer):
delim_size = len(delimiter)
while True:
- if self._char in self._ESCAPES and self._peek == delimiter: # type: ignore
- text += delimiter
+ if (
+ self._char in self._ESCAPES
+ and self._peek
+ and (self._peek == delimiter or self._peek in self._ESCAPES)
+ ):
+ text += self._peek
self._advance(2)
else:
if self._chars(delim_size) == delimiter:
diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py
index c61a2f3..e5b1c94 100644
--- a/tests/dialects/test_bigquery.py
+++ b/tests/dialects/test_bigquery.py
@@ -6,6 +6,8 @@ class TestBigQuery(Validator):
dialect = "bigquery"
def test_bigquery(self):
+ self.validate_identity("SELECT STRUCT<ARRAY<STRING>>(['2023-01-17'])")
+ self.validate_identity("SELECT * FROM q UNPIVOT(values FOR quarter IN (b, c))")
self.validate_all(
"REGEXP_CONTAINS('foo', '.*')",
read={"bigquery": "REGEXP_CONTAINS('foo', '.*')"},
@@ -42,6 +44,15 @@ class TestBigQuery(Validator):
},
)
self.validate_all(
+ r"'\\'",
+ write={
+ "bigquery": r"'\\'",
+ "duckdb": r"'\'",
+ "presto": r"'\'",
+ "hive": r"'\\'",
+ },
+ )
+ self.validate_all(
R'R"""/\*.*\*/"""',
write={
"bigquery": R"'/\\*.*\\*/'",
diff --git a/tests/dialects/test_clickhouse.py b/tests/dialects/test_clickhouse.py
index 109e9f3..2827dd4 100644
--- a/tests/dialects/test_clickhouse.py
+++ b/tests/dialects/test_clickhouse.py
@@ -17,6 +17,7 @@ class TestClickhouse(Validator):
self.validate_identity("SELECT quantile(0.5)(a)")
self.validate_identity("SELECT quantiles(0.5)(a) AS x FROM t")
self.validate_identity("SELECT * FROM foo WHERE x GLOBAL IN (SELECT * FROM bar)")
+ self.validate_identity("position(a, b)")
self.validate_all(
"SELECT fname, lname, age FROM person ORDER BY age DESC NULLS FIRST, fname ASC NULLS LAST, lname",
@@ -47,3 +48,9 @@ class TestClickhouse(Validator):
"clickhouse": "SELECT quantileIf(0.5)(a, TRUE)",
},
)
+
+ def test_cte(self):
+ self.validate_identity("WITH 'x' AS foo SELECT foo")
+ self.validate_identity("WITH SUM(bytes) AS foo SELECT foo FROM system.parts")
+ self.validate_identity("WITH (SELECT foo) AS bar SELECT bar + 5")
+ self.validate_identity("WITH test1 AS (SELECT i + 1, j + 1 FROM test1) SELECT * FROM test1")
diff --git a/tests/dialects/test_dialect.py b/tests/dialects/test_dialect.py
index 284a30d..b2f4676 100644
--- a/tests/dialects/test_dialect.py
+++ b/tests/dialects/test_dialect.py
@@ -14,7 +14,7 @@ class Validator(unittest.TestCase):
self.assertEqual(write_sql or sql, expression.sql(dialect=self.dialect))
return expression
- def validate_all(self, sql, read=None, write=None, pretty=False):
+ def validate_all(self, sql, read=None, write=None, pretty=False, identify=False):
"""
Validate that:
1. Everything in `read` transpiles to `sql`
@@ -32,7 +32,10 @@ class Validator(unittest.TestCase):
with self.subTest(f"{read_dialect} -> {sql}"):
self.assertEqual(
parse_one(read_sql, read_dialect).sql(
- self.dialect, unsupported_level=ErrorLevel.IGNORE, pretty=pretty
+ self.dialect,
+ unsupported_level=ErrorLevel.IGNORE,
+ pretty=pretty,
+ identify=identify,
),
sql,
)
@@ -48,6 +51,7 @@ class Validator(unittest.TestCase):
write_dialect,
unsupported_level=ErrorLevel.IGNORE,
pretty=pretty,
+ identify=identify,
),
write_sql,
)
@@ -76,7 +80,7 @@ class TestDialect(Validator):
"oracle": "CAST(a AS CLOB)",
"postgres": "CAST(a AS TEXT)",
"presto": "CAST(a AS VARCHAR)",
- "redshift": "CAST(a AS TEXT)",
+ "redshift": "CAST(a AS VARCHAR(MAX))",
"snowflake": "CAST(a AS TEXT)",
"spark": "CAST(a AS STRING)",
"starrocks": "CAST(a AS STRING)",
@@ -155,7 +159,7 @@ class TestDialect(Validator):
"oracle": "CAST(a AS CLOB)",
"postgres": "CAST(a AS TEXT)",
"presto": "CAST(a AS VARCHAR)",
- "redshift": "CAST(a AS TEXT)",
+ "redshift": "CAST(a AS VARCHAR(MAX))",
"snowflake": "CAST(a AS TEXT)",
"spark": "CAST(a AS STRING)",
"starrocks": "CAST(a AS STRING)",
@@ -344,6 +348,7 @@ class TestDialect(Validator):
"duckdb": "CAST('2020-01-01' AS TIMESTAMP)",
"hive": "CAST('2020-01-01' AS TIMESTAMP)",
"presto": "CAST('2020-01-01' AS TIMESTAMP)",
+ "sqlite": "'2020-01-01'",
},
)
self.validate_all(
@@ -373,7 +378,7 @@ class TestDialect(Validator):
"duckdb": "CAST(x AS TEXT)",
"hive": "CAST(x AS STRING)",
"presto": "CAST(x AS VARCHAR)",
- "redshift": "CAST(x AS TEXT)",
+ "redshift": "CAST(x AS VARCHAR(MAX))",
},
)
self.validate_all(
@@ -488,7 +493,9 @@ class TestDialect(Validator):
"mysql": "DATE_ADD(x, INTERVAL 1 DAY)",
"postgres": "x + INTERVAL '1' 'day'",
"presto": "DATE_ADD('day', 1, x)",
+ "snowflake": "DATEADD(x, 1, 'day')",
"spark": "DATE_ADD(x, 1)",
+ "sqlite": "DATE(x, '1 day')",
"starrocks": "DATE_ADD(x, INTERVAL 1 DAY)",
"tsql": "DATEADD(day, 1, x)",
},
@@ -594,6 +601,7 @@ class TestDialect(Validator):
"hive": "TO_DATE(x)",
"presto": "CAST(DATE_PARSE(x, '%Y-%m-%d') AS DATE)",
"spark": "TO_DATE(x)",
+ "sqlite": "x",
},
)
self.validate_all(
@@ -955,7 +963,7 @@ class TestDialect(Validator):
},
)
self.validate_all(
- "STR_POSITION('a', x)",
+ "STR_POSITION(x, 'a')",
write={
"drill": "STRPOS(x, 'a')",
"duckdb": "STRPOS(x, 'a')",
@@ -971,7 +979,7 @@ class TestDialect(Validator):
"POSITION('a', x, 3)",
write={
"drill": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1",
- "presto": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1",
+ "presto": "STRPOS(x, 'a', 3)",
"spark": "LOCATE('a', x, 3)",
"clickhouse": "position(x, 'a', 3)",
"snowflake": "POSITION('a', x, 3)",
@@ -982,9 +990,10 @@ class TestDialect(Validator):
"CONCAT_WS('-', 'a', 'b')",
write={
"duckdb": "CONCAT_WS('-', 'a', 'b')",
- "presto": "ARRAY_JOIN(ARRAY['a', 'b'], '-')",
+ "presto": "CONCAT_WS('-', 'a', 'b')",
"hive": "CONCAT_WS('-', 'a', 'b')",
"spark": "CONCAT_WS('-', 'a', 'b')",
+ "trino": "CONCAT_WS('-', 'a', 'b')",
},
)
@@ -992,9 +1001,10 @@ class TestDialect(Validator):
"CONCAT_WS('-', x)",
write={
"duckdb": "CONCAT_WS('-', x)",
- "presto": "ARRAY_JOIN(x, '-')",
"hive": "CONCAT_WS('-', x)",
+ "presto": "CONCAT_WS('-', x)",
"spark": "CONCAT_WS('-', x)",
+ "trino": "CONCAT_WS('-', x)",
},
)
self.validate_all(
@@ -1118,6 +1128,7 @@ class TestDialect(Validator):
self.validate_all(
"SELECT x FROM y OFFSET 10 FETCH FIRST 3 ROWS ONLY",
write={
+ "sqlite": "SELECT x FROM y LIMIT 3 OFFSET 10",
"oracle": "SELECT x FROM y OFFSET 10 ROWS FETCH FIRST 3 ROWS ONLY",
},
)
@@ -1197,7 +1208,7 @@ class TestDialect(Validator):
"oracle": "CREATE TABLE t (b1 BLOB, b2 BLOB(1024), c1 CLOB, c2 CLOB(1024))",
"postgres": "CREATE TABLE t (b1 BYTEA, b2 BYTEA(1024), c1 TEXT, c2 TEXT(1024))",
"sqlite": "CREATE TABLE t (b1 BLOB, b2 BLOB(1024), c1 TEXT, c2 TEXT(1024))",
- "redshift": "CREATE TABLE t (b1 VARBYTE, b2 VARBYTE(1024), c1 TEXT, c2 TEXT(1024))",
+ "redshift": "CREATE TABLE t (b1 VARBYTE, b2 VARBYTE(1024), c1 VARCHAR(MAX), c2 VARCHAR(1024))",
},
)
diff --git a/tests/dialects/test_hive.py b/tests/dialects/test_hive.py
index bbf00b1..d485593 100644
--- a/tests/dialects/test_hive.py
+++ b/tests/dialects/test_hive.py
@@ -357,6 +357,30 @@ class TestHive(Validator):
},
)
self.validate_all(
+ "SELECT 1a_1a FROM test_a",
+ write={
+ "spark": "SELECT 1a_1a FROM test_a",
+ },
+ )
+ self.validate_all(
+ "SELECT 1a AS 1a_1a FROM test_a",
+ write={
+ "spark": "SELECT 1a AS 1a_1a FROM test_a",
+ },
+ )
+ self.validate_all(
+ "CREATE TABLE test_table (1a STRING)",
+ write={
+ "spark": "CREATE TABLE test_table (1a STRING)",
+ },
+ )
+ self.validate_all(
+ "CREATE TABLE test_table2 (1a_1a STRING)",
+ write={
+ "spark": "CREATE TABLE test_table2 (1a_1a STRING)",
+ },
+ )
+ self.validate_all(
"PERCENTILE(x, 0.5)",
write={
"duckdb": "QUANTILE(x, 0.5)",
@@ -420,7 +444,7 @@ class TestHive(Validator):
"LOCATE('a', x, 3)",
write={
"duckdb": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1",
- "presto": "STRPOS(SUBSTR(x, 3), 'a') + 3 - 1",
+ "presto": "STRPOS(x, 'a', 3)",
"hive": "LOCATE('a', x, 3)",
"spark": "LOCATE('a', x, 3)",
},
diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py
index 7cd686d..dfd2f8e 100644
--- a/tests/dialects/test_mysql.py
+++ b/tests/dialects/test_mysql.py
@@ -65,6 +65,17 @@ class TestMySQL(Validator):
self.validate_identity("SET GLOBAL TRANSACTION ISOLATION LEVEL REPEATABLE READ, READ WRITE")
self.validate_identity("SELECT SCHEMA()")
+ def test_types(self):
+ self.validate_all(
+ "CAST(x AS MEDIUMTEXT) + CAST(y AS LONGTEXT)",
+ read={
+ "mysql": "CAST(x AS MEDIUMTEXT) + CAST(y AS LONGTEXT)",
+ },
+ write={
+ "spark": "CAST(x AS TEXT) + CAST(y AS TEXT)",
+ },
+ )
+
def test_canonical_functions(self):
self.validate_identity("SELECT LEFT('str', 2)", "SELECT SUBSTRING('str', 1, 2)")
self.validate_identity("SELECT INSTR('str', 'substr')", "SELECT LOCATE('substr', 'str')")
diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py
index 583d349..2351e3b 100644
--- a/tests/dialects/test_postgres.py
+++ b/tests/dialects/test_postgres.py
@@ -46,14 +46,6 @@ class TestPostgres(Validator):
" CONSTRAINT valid_discount CHECK (price > discounted_price))"
},
)
- self.validate_all(
- "CREATE TABLE products (x INT GENERATED BY DEFAULT AS IDENTITY)",
- write={"postgres": "CREATE TABLE products (x INT GENERATED BY DEFAULT AS IDENTITY)"},
- )
- self.validate_all(
- "CREATE TABLE products (x INT GENERATED ALWAYS AS IDENTITY)",
- write={"postgres": "CREATE TABLE products (x INT GENERATED ALWAYS AS IDENTITY)"},
- )
with self.assertRaises(ParseError):
transpile("CREATE TABLE products (price DECIMAL CHECK price > 0)", read="postgres")
diff --git a/tests/dialects/test_presto.py b/tests/dialects/test_presto.py
index ee535e9..195e382 100644
--- a/tests/dialects/test_presto.py
+++ b/tests/dialects/test_presto.py
@@ -152,6 +152,10 @@ class TestPresto(Validator):
"spark": "FROM_UNIXTIME(x)",
},
)
+ self.validate_identity("FROM_UNIXTIME(a, b)")
+ self.validate_identity("FROM_UNIXTIME(a, b, c)")
+ self.validate_identity("TRIM(a, b)")
+ self.validate_identity("VAR_POP(a)")
self.validate_all(
"TO_UNIXTIME(x)",
write={
@@ -302,6 +306,7 @@ class TestPresto(Validator):
)
def test_presto(self):
+ self.validate_identity("SELECT BOOL_OR(a > 10) FROM asd AS T(a)")
self.validate_all(
'SELECT a."b" FROM "foo"',
write={
@@ -443,8 +448,10 @@ class TestPresto(Validator):
"spark": UnsupportedError,
},
)
+ self.validate_identity("SELECT * FROM (VALUES (1))")
self.validate_identity("START TRANSACTION READ WRITE, ISOLATION LEVEL SERIALIZABLE")
self.validate_identity("START TRANSACTION ISOLATION LEVEL REPEATABLE READ")
+ self.validate_identity("APPROX_PERCENTILE(a, b, c, d)")
def test_encode_decode(self):
self.validate_all(
@@ -460,6 +467,12 @@ class TestPresto(Validator):
},
)
self.validate_all(
+ "FROM_UTF8(x, y)",
+ write={
+ "presto": "FROM_UTF8(x, y)",
+ },
+ )
+ self.validate_all(
"ENCODE(x, 'utf-8')",
write={
"presto": "TO_UTF8(x)",
diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py
index f650c98..e20661e 100644
--- a/tests/dialects/test_redshift.py
+++ b/tests/dialects/test_redshift.py
@@ -89,7 +89,9 @@ class TestRedshift(Validator):
self.validate_identity(
"SELECT COUNT(*) FROM event WHERE eventname LIKE '%Ring%' OR eventname LIKE '%Die%'"
)
- self.validate_identity("CREATE TABLE SOUP DISTKEY(soup1) SORTKEY(soup2) DISTSTYLE ALL")
+ self.validate_identity(
+ "CREATE TABLE SOUP (LIKE other_table) DISTKEY(soup1) SORTKEY(soup2) DISTSTYLE ALL"
+ )
self.validate_identity(
"CREATE TABLE sales (salesid INTEGER NOT NULL) DISTKEY(listid) COMPOUND SORTKEY(listid, sellerid) DISTSTYLE AUTO"
)
@@ -102,3 +104,81 @@ class TestRedshift(Validator):
self.validate_identity(
"CREATE TABLE SOUP (SOUP1 VARCHAR(50) NOT NULL ENCODE ZSTD, SOUP2 VARCHAR(70) NULL ENCODE DELTA)"
)
+
+ def test_values(self):
+ self.validate_all(
+ "SELECT a, b FROM (VALUES (1, 2)) AS t (a, b)",
+ write={
+ "redshift": "SELECT a, b FROM (SELECT 1 AS a, 2 AS b) AS t",
+ },
+ )
+ self.validate_all(
+ "SELECT a, b FROM (VALUES (1, 2), (3, 4)) AS t (a, b)",
+ write={
+ "redshift": "SELECT a, b FROM (SELECT 1 AS a, 2 AS b UNION ALL SELECT 3, 4) AS t",
+ },
+ )
+ self.validate_all(
+ "SELECT a, b FROM (VALUES (1, 2), (3, 4), (5, 6), (7, 8)) AS t (a, b)",
+ write={
+ "redshift": "SELECT a, b FROM (SELECT 1 AS a, 2 AS b UNION ALL SELECT 3, 4 UNION ALL SELECT 5, 6 UNION ALL SELECT 7, 8) AS t",
+ },
+ )
+ self.validate_all(
+ "INSERT INTO t(a) VALUES (1), (2), (3)",
+ write={
+ "redshift": "INSERT INTO t (a) VALUES (1), (2), (3)",
+ },
+ )
+ self.validate_all(
+ "INSERT INTO t(a, b) SELECT a, b FROM (VALUES (1, 2), (3, 4)) AS t (a, b)",
+ write={
+ "redshift": "INSERT INTO t (a, b) SELECT a, b FROM (SELECT 1 AS a, 2 AS b UNION ALL SELECT 3, 4) AS t",
+ },
+ )
+ self.validate_all(
+ "INSERT INTO t(a, b) VALUES (1, 2), (3, 4)",
+ write={
+ "redshift": "INSERT INTO t (a, b) VALUES (1, 2), (3, 4)",
+ },
+ )
+
+ def test_create_table_like(self):
+ self.validate_all(
+ "CREATE TABLE t1 LIKE t2",
+ write={
+ "redshift": "CREATE TABLE t1 (LIKE t2)",
+ },
+ )
+ self.validate_all(
+ "CREATE TABLE SOUP (LIKE other_table) DISTKEY(soup1) SORTKEY(soup2) DISTSTYLE ALL",
+ write={
+ "redshift": "CREATE TABLE SOUP (LIKE other_table) DISTKEY(soup1) SORTKEY(soup2) DISTSTYLE ALL",
+ },
+ )
+
+ def test_rename_table(self):
+ self.validate_all(
+ "ALTER TABLE db.t1 RENAME TO db.t2",
+ write={
+ "spark": "ALTER TABLE db.t1 RENAME TO db.t2",
+ "redshift": "ALTER TABLE db.t1 RENAME TO t2",
+ },
+ )
+
+ def test_varchar_max(self):
+ self.validate_all(
+ "CREATE TABLE TEST (cola VARCHAR(MAX))",
+ write={
+ "redshift": 'CREATE TABLE "TEST" ("cola" VARCHAR(MAX))',
+ },
+ identify=True,
+ )
+
+ def test_no_schema_binding(self):
+ self.validate_all(
+ "CREATE OR REPLACE VIEW v1 AS SELECT cola, colb FROM t1 WITH NO SCHEMA BINDING",
+ write={
+ "redshift": "CREATE OR REPLACE VIEW v1 AS SELECT cola, colb FROM t1 WITH NO SCHEMA BINDING",
+ },
+ )
diff --git a/tests/dialects/test_spark.py b/tests/dialects/test_spark.py
index f287a89..fad858c 100644
--- a/tests/dialects/test_spark.py
+++ b/tests/dialects/test_spark.py
@@ -307,5 +307,12 @@ TBLPROPERTIES (
def test_iif(self):
self.validate_all(
- "SELECT IIF(cond, 'True', 'False')", write={"spark": "SELECT IF(cond, 'True', 'False')"}
+ "SELECT IIF(cond, 'True', 'False')",
+ write={"spark": "SELECT IF(cond, 'True', 'False')"},
+ )
+
+ def test_bool_or(self):
+ self.validate_all(
+ "SELECT a, LOGICAL_OR(b) FROM table GROUP BY a",
+ write={"duckdb": "SELECT a, BOOL_OR(b) FROM table GROUP BY a"},
)
diff --git a/tests/dialects/test_teradata.py b/tests/dialects/test_teradata.py
new file mode 100644
index 0000000..e56de25
--- /dev/null
+++ b/tests/dialects/test_teradata.py
@@ -0,0 +1,23 @@
+from tests.dialects.test_dialect import Validator
+
+
+class TestTeradata(Validator):
+ dialect = "teradata"
+
+ def test_translate(self):
+ self.validate_all(
+ "TRANSLATE(x USING LATIN_TO_UNICODE)",
+ write={
+ "teradata": "CAST(x AS CHAR CHARACTER SET UNICODE)",
+ },
+ )
+ self.validate_identity("CAST(x AS CHAR CHARACTER SET UNICODE)")
+
+ def test_update(self):
+ self.validate_all(
+ "UPDATE A FROM schema.tableA AS A, (SELECT col1 FROM schema.tableA GROUP BY col1) AS B SET col2 = '' WHERE A.col1 = B.col1",
+ write={
+ "teradata": "UPDATE A FROM schema.tableA AS A, (SELECT col1 FROM schema.tableA GROUP BY col1) AS B SET col2 = '' WHERE A.col1 = B.col1",
+ "mysql": "UPDATE A SET col2 = '' FROM schema.tableA AS A, (SELECT col1 FROM schema.tableA GROUP BY col1) AS B WHERE A.col1 = B.col1",
+ },
+ )
diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py
index b74c05f..d2972ca 100644
--- a/tests/dialects/test_tsql.py
+++ b/tests/dialects/test_tsql.py
@@ -5,6 +5,13 @@ class TestTSQL(Validator):
dialect = "tsql"
def test_tsql(self):
+ self.validate_identity("DECLARE @TestVariable AS VARCHAR(100)='Save Our Planet'")
+ self.validate_identity("PRINT @TestVariable")
+ self.validate_identity("SELECT Employee_ID, Department_ID FROM @MyTableVar")
+ self.validate_identity("INSERT INTO @TestTable VALUES (1, 'Value1', 12, 20)")
+ self.validate_identity(
+ "SELECT x FROM @MyTableVar AS m JOIN Employee ON m.EmployeeID = Employee.EmployeeID"
+ )
self.validate_identity('SELECT "x"."y" FROM foo')
self.validate_identity("SELECT * FROM #foo")
self.validate_identity("SELECT * FROM ##foo")
diff --git a/tests/fixtures/identity.sql b/tests/fixtures/identity.sql
index beb5703..4e21d2b 100644
--- a/tests/fixtures/identity.sql
+++ b/tests/fixtures/identity.sql
@@ -59,6 +59,8 @@ map.x
SELECT call.x
a.b.INT(1.234)
INT(x / 100)
+time * 100
+int * 100
x IN (-1, 1)
x IN ('a', 'a''a')
x IN ((1))
@@ -69,6 +71,11 @@ x IS TRUE
x IS FALSE
x IS TRUE IS TRUE
x LIKE y IS TRUE
+MAP()
+GREATEST(x)
+LEAST(y)
+MAX(a, b)
+MIN(a, b)
time
zone
ARRAY<TEXT>
@@ -133,6 +140,7 @@ x AT TIME ZONE 'UTC'
CAST('2025-11-20 00:00:00+00' AS TIMESTAMP) AT TIME ZONE 'Africa/Cairo'
SET x = 1
SET -v
+SET x = ';'
COMMIT
USE db
NOT 1
@@ -170,6 +178,7 @@ SELECT COUNT(DISTINCT a, b)
SELECT COUNT(DISTINCT a, b + 1)
SELECT SUM(DISTINCT x)
SELECT SUM(x IGNORE NULLS) AS x
+SELECT TRUNCATE(a, b)
SELECT ARRAY_AGG(DISTINCT x IGNORE NULLS ORDER BY a, b DESC LIMIT 10) AS x
SELECT ARRAY_AGG(STRUCT(x, x AS y) ORDER BY z DESC) AS x
SELECT LAST_VALUE(x IGNORE NULLS) OVER y AS x
@@ -622,7 +631,7 @@ SELECT 1 /* c1 */ + 2 /* c2 */ + 3 /* c3 */
SELECT 1 /* c1 */ + 2 /* c2 */, 3 /* c3 */
SELECT x FROM a.b.c /* x */, e.f.g /* x */
SELECT FOO(x /* c */) /* FOO */, b /* b */
-SELECT FOO(x /* c1 */ + y /* c2 */ + BLA(5 /* c3 */)) FROM VALUES (1 /* c4 */, "test" /* c5 */) /* c6 */
+SELECT FOO(x /* c1 */ + y /* c2 */ + BLA(5 /* c3 */)) FROM (VALUES (1 /* c4 */, "test" /* c5 */)) /* c6 */
SELECT a FROM x WHERE a COLLATE 'utf8_general_ci' = 'b'
SELECT x AS INTO FROM bla
SELECT * INTO newevent FROM event
@@ -643,3 +652,21 @@ ALTER TABLE integers ALTER COLUMN i DROP DEFAULT
ALTER TABLE mydataset.mytable DROP COLUMN A, DROP COLUMN IF EXISTS B
ALTER TABLE mydataset.mytable ADD COLUMN A TEXT, ADD COLUMN IF NOT EXISTS B INT
SELECT div.a FROM test_table AS div
+WITH view AS (SELECT 1 AS x) SELECT * FROM view
+CREATE TABLE asd AS SELECT asd FROM asd WITH NO DATA
+CREATE TABLE asd AS SELECT asd FROM asd WITH DATA
+ARRAY<STRUCT<INT, DOUBLE, ARRAY<INT>>>
+ARRAY<INT>[1, 2, 3]
+ARRAY<INT>[]
+STRUCT<x VARCHAR(10)>
+STRUCT<x VARCHAR(10)>("bla")
+STRUCT<VARCHAR(10)>("bla")
+STRUCT<INT>(5)
+STRUCT<DATE>("2011-05-05")
+STRUCT<x INT, y TEXT>(1, t.str_col)
+SELECT CAST(NULL AS ARRAY<INT>) IS NULL AS array_is_null
+CREATE TABLE products (x INT GENERATED BY DEFAULT AS IDENTITY)
+CREATE TABLE products (x INT GENERATED ALWAYS AS IDENTITY)
+CREATE TABLE IF NOT EXISTS customer (pk BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY (INCREMENT BY 1))
+CREATE TABLE customer (pk BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 10 INCREMENT BY 1))
+CREATE TABLE customer (pk BIGINT NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 10))
diff --git a/tests/fixtures/pretty.sql b/tests/fixtures/pretty.sql
index 067fe77..64806eb 100644
--- a/tests/fixtures/pretty.sql
+++ b/tests/fixtures/pretty.sql
@@ -322,3 +322,23 @@ SELECT
* /* multi
line
comment */;
+WITH table_data AS (
+ SELECT 'bob' AS name, ARRAY['banana', 'apple', 'orange'] AS fruit_basket
+)
+SELECT
+ name,
+ fruit,
+ basket_index
+FROM table_data
+CROSS JOIN UNNEST(fruit_basket) AS fruit WITH OFFSET basket_index;
+WITH table_data AS (
+ SELECT
+ 'bob' AS name,
+ ARRAY('banana', 'apple', 'orange') AS fruit_basket
+)
+SELECT
+ name,
+ fruit,
+ basket_index
+FROM table_data
+CROSS JOIN UNNEST(fruit_basket) AS fruit WITH OFFSET AS basket_index;
diff --git a/tests/test_expressions.py b/tests/test_expressions.py
index 906e08c..9e5f988 100644
--- a/tests/test_expressions.py
+++ b/tests/test_expressions.py
@@ -624,6 +624,10 @@ FROM foo""",
self.assertEqual(catalog_db_and_table.args.get("catalog"), exp.to_identifier("catalog"))
with self.assertRaises(ValueError):
exp.to_table(1)
+ empty_string = exp.to_table("")
+ self.assertEqual(empty_string.name, "")
+ self.assertIsNone(table_only.args.get("db"))
+ self.assertIsNone(table_only.args.get("catalog"))
def test_to_column(self):
column_only = exp.to_column("column_name")
@@ -715,3 +719,9 @@ FROM foo""",
self.assertEqual(exp.DataType.build("OBJECT").sql(), "OBJECT")
self.assertEqual(exp.DataType.build("NULL").sql(), "NULL")
self.assertEqual(exp.DataType.build("UNKNOWN").sql(), "UNKNOWN")
+
+ def test_rename_table(self):
+ self.assertEqual(
+ exp.rename_table("t1", "t2").sql(),
+ "ALTER TABLE t1 RENAME TO t2",
+ )
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
index 887f427..af21679 100644
--- a/tests/test_optimizer.py
+++ b/tests/test_optimizer.py
@@ -6,7 +6,7 @@ from pandas.testing import assert_frame_equal
import sqlglot
from sqlglot import exp, optimizer, parse_one
-from sqlglot.errors import OptimizeError
+from sqlglot.errors import OptimizeError, SchemaError
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.scope import build_scope, traverse_scope, walk_in_scope
from sqlglot.schema import MappingSchema
@@ -161,7 +161,7 @@ class TestOptimizer(unittest.TestCase):
def test_qualify_columns__invalid(self):
for sql in load_sql_fixtures("optimizer/qualify_columns__invalid.sql"):
with self.subTest(sql):
- with self.assertRaises(OptimizeError):
+ with self.assertRaises((OptimizeError, SchemaError)):
optimizer.qualify_columns.qualify_columns(parse_one(sql), schema=self.schema)
def test_lower_identities(self):
diff --git a/tests/test_parser.py b/tests/test_parser.py
index 03b801b..dbde437 100644
--- a/tests/test_parser.py
+++ b/tests/test_parser.py
@@ -325,3 +325,9 @@ class TestParser(unittest.TestCase):
"Expected table name",
logger,
)
+
+ def test_rename_table(self):
+ self.assertEqual(
+ parse_one("ALTER TABLE foo RENAME TO bar").sql(),
+ "ALTER TABLE foo RENAME TO bar",
+ )
diff --git a/tests/test_transpile.py b/tests/test_transpile.py
index 3a7fea4..3e094f5 100644
--- a/tests/test_transpile.py
+++ b/tests/test_transpile.py
@@ -272,6 +272,11 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""",
"WITH A(filter) AS (VALUES 1, 2, 3) SELECT * FROM A WHERE filter >= 2",
"WITH A(filter) AS (VALUES (1), (2), (3)) SELECT * FROM A WHERE filter >= 2",
)
+ self.validate(
+ "SELECT BOOL_OR(a > 10) FROM (VALUES 1, 2, 15) AS T(a)",
+ "SELECT BOOL_OR(a > 10) FROM (VALUES (1), (2), (15)) AS T(a)",
+ write="presto",
+ )
def test_alter(self):
self.validate(
@@ -447,6 +452,9 @@ FROM bar /* comment 5 */, tbl /* comment 6 */""",
self.assertEqual(generated, pretty)
self.assertEqual(parse_one(sql), parse_one(pretty))
+ def test_pretty_line_breaks(self):
+ self.assertEqual(transpile("SELECT '1\n2'", pretty=True)[0], "SELECT\n '1\n2'")
+
@mock.patch("sqlglot.parser.logger")
def test_error_level(self, logger):
invalid = "x + 1. ("