diff options
Diffstat (limited to 'sqlglot/dialects')
-rw-r--r-- | sqlglot/dialects/dialect.py | 11 | ||||
-rw-r--r-- | sqlglot/dialects/duckdb.py | 19 | ||||
-rw-r--r-- | sqlglot/dialects/hive.py | 30 | ||||
-rw-r--r-- | sqlglot/dialects/mysql.py | 6 | ||||
-rw-r--r-- | sqlglot/dialects/oracle.py | 10 | ||||
-rw-r--r-- | sqlglot/dialects/presto.py | 5 | ||||
-rw-r--r-- | sqlglot/dialects/snowflake.py | 25 | ||||
-rw-r--r-- | sqlglot/dialects/tsql.py | 7 |
8 files changed, 80 insertions, 33 deletions
diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 4fc93bf..5376dff 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -620,7 +620,16 @@ def concat_to_dpipe_sql(self: Generator, expression: exp.Concat | exp.SafeConcat return self.sql(this) -# Spark, DuckDB use (almost) the same naming scheme for the output columns of the PIVOT operator +def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str: + bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters"))) + if bad_args: + self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}") + + return self.func( + "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group") + ) + + def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: names = [] for agg in aggregations: diff --git a/sqlglot/dialects/duckdb.py b/sqlglot/dialects/duckdb.py index d7e5a43..1d8a7fb 100644 --- a/sqlglot/dialects/duckdb.py +++ b/sqlglot/dialects/duckdb.py @@ -15,6 +15,7 @@ from sqlglot.dialects.dialect import ( no_properties_sql, no_safe_divide_sql, pivot_column_names, + regexp_extract_sql, rename_func, str_position_sql, str_to_time_sql, @@ -88,19 +89,6 @@ def _datatype_sql(self: generator.Generator, expression: exp.DataType) -> str: return self.datatype_sql(expression) -def _regexp_extract_sql(self: generator.Generator, expression: exp.RegexpExtract) -> str: - bad_args = list(filter(expression.args.get, ("position", "occurrence"))) - if bad_args: - self.unsupported(f"REGEXP_EXTRACT does not support arg(s) {bad_args}") - - return self.func( - "REGEXP_EXTRACT", - expression.args.get("this"), - expression.args.get("expression"), - expression.args.get("group"), - ) - - def _json_format_sql(self: generator.Generator, expression: exp.JSONFormat) -> str: sql = self.func("TO_JSON", expression.this, expression.args.get("options")) return f"CAST({sql} AS TEXT)" @@ -156,6 +144,9 @@ class DuckDB(Dialect): "LIST_REVERSE_SORT": _sort_array_reverse, "LIST_SORT": exp.SortArray.from_arg_list, "LIST_VALUE": exp.Array.from_arg_list, + "REGEXP_EXTRACT": lambda args: exp.RegexpExtract( + this=seq_get(args, 0), expression=seq_get(args, 1), group=seq_get(args, 2) + ), "REGEXP_MATCHES": exp.RegexpLike.from_arg_list, "STRFTIME": format_time_lambda(exp.TimeToStr, "duckdb"), "STRING_SPLIT": exp.Split.from_arg_list, @@ -227,7 +218,7 @@ class DuckDB(Dialect): exp.LogicalOr: rename_func("BOOL_OR"), exp.LogicalAnd: rename_func("BOOL_AND"), exp.Properties: no_properties_sql, - exp.RegexpExtract: _regexp_extract_sql, + exp.RegexpExtract: regexp_extract_sql, exp.RegexpLike: rename_func("REGEXP_MATCHES"), exp.RegexpSplit: rename_func("STR_SPLIT_REGEX"), exp.SafeDivide: no_safe_divide_sql, diff --git a/sqlglot/dialects/hive.py b/sqlglot/dialects/hive.py index 5762efb..f968f6a 100644 --- a/sqlglot/dialects/hive.py +++ b/sqlglot/dialects/hive.py @@ -17,6 +17,7 @@ from sqlglot.dialects.dialect import ( no_recursive_cte_sql, no_safe_divide_sql, no_trycast_sql, + regexp_extract_sql, rename_func, right_to_substring_sql, strposition_to_locate_sql, @@ -230,23 +231,24 @@ class Hive(Dialect): **parser.Parser.FUNCTIONS, "BASE64": exp.ToBase64.from_arg_list, "COLLECT_LIST": exp.ArrayAgg.from_arg_list, + "COLLECT_SET": exp.SetAgg.from_arg_list, "DATE_ADD": lambda args: exp.TsOrDsAdd( this=seq_get(args, 0), expression=seq_get(args, 1), unit=exp.Literal.string("DAY") ), - "DATEDIFF": lambda args: exp.DateDiff( - this=exp.TsOrDsToDate(this=seq_get(args, 0)), - expression=exp.TsOrDsToDate(this=seq_get(args, 1)), + "DATE_FORMAT": lambda args: format_time_lambda(exp.TimeToStr, "hive")( + [ + exp.TimeStrToTime(this=seq_get(args, 0)), + seq_get(args, 1), + ] ), "DATE_SUB": lambda args: exp.TsOrDsAdd( this=seq_get(args, 0), expression=exp.Mul(this=seq_get(args, 1), expression=exp.Literal.number(-1)), unit=exp.Literal.string("DAY"), ), - "DATE_FORMAT": lambda args: format_time_lambda(exp.TimeToStr, "hive")( - [ - exp.TimeStrToTime(this=seq_get(args, 0)), - seq_get(args, 1), - ] + "DATEDIFF": lambda args: exp.DateDiff( + this=exp.TsOrDsToDate(this=seq_get(args, 0)), + expression=exp.TsOrDsToDate(this=seq_get(args, 1)), ), "DAY": lambda args: exp.Day(this=exp.TsOrDsToDate(this=seq_get(args, 0))), "FROM_UNIXTIME": format_time_lambda(exp.UnixToStr, "hive", True), @@ -256,7 +258,9 @@ class Hive(Dialect): "MONTH": lambda args: exp.Month(this=exp.TsOrDsToDate.from_arg_list(args)), "PERCENTILE": exp.Quantile.from_arg_list, "PERCENTILE_APPROX": exp.ApproxQuantile.from_arg_list, - "COLLECT_SET": exp.SetAgg.from_arg_list, + "REGEXP_EXTRACT": lambda args: exp.RegexpExtract( + this=seq_get(args, 0), expression=seq_get(args, 1), group=seq_get(args, 2) + ), "SIZE": exp.ArraySize.from_arg_list, "SPLIT": exp.RegexpSplit.from_arg_list, "TO_DATE": format_time_lambda(exp.TsOrDsToDate, "hive"), @@ -363,6 +367,7 @@ class Hive(Dialect): exp.Create: create_with_partitions_sql, exp.Quantile: rename_func("PERCENTILE"), exp.ApproxQuantile: rename_func("PERCENTILE_APPROX"), + exp.RegexpExtract: regexp_extract_sql, exp.RegexpLike: lambda self, e: self.binary(e, "RLIKE"), exp.RegexpSplit: rename_func("SPLIT"), exp.Right: right_to_substring_sql, @@ -422,5 +427,12 @@ class Hive(Dialect): expression = exp.DataType.build("text") elif expression.this in exp.DataType.TEMPORAL_TYPES: expression = exp.DataType.build(expression.this) + elif expression.is_type("float"): + size_expression = expression.find(exp.DataTypeSize) + if size_expression: + size = int(size_expression.name) + expression = ( + exp.DataType.build("float") if size <= 32 else exp.DataType.build("double") + ) return super().datatype_sql(expression) diff --git a/sqlglot/dialects/mysql.py b/sqlglot/dialects/mysql.py index bae0e50..e4de934 100644 --- a/sqlglot/dialects/mysql.py +++ b/sqlglot/dialects/mysql.py @@ -193,6 +193,12 @@ class MySQL(Dialect): TokenType.VALUES, } + CONJUNCTION = { + **parser.Parser.CONJUNCTION, + TokenType.DAMP: exp.And, + TokenType.XOR: exp.Xor, + } + TABLE_ALIAS_TOKENS = ( parser.Parser.TABLE_ALIAS_TOKENS - parser.Parser.TABLE_INDEX_HINT_TOKENS ) diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index 2b77ef9..69da133 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -99,6 +99,9 @@ class Oracle(Dialect): LOCKING_READS_SUPPORTED = True JOIN_HINTS = False TABLE_HINTS = False + COLUMN_JOIN_MARKS_SUPPORTED = True + + LIMIT_FETCH = "FETCH" TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, @@ -110,6 +113,7 @@ class Oracle(Dialect): exp.DataType.Type.DOUBLE: "DOUBLE PRECISION", exp.DataType.Type.VARCHAR: "VARCHAR2", exp.DataType.Type.NVARCHAR: "NVARCHAR2", + exp.DataType.Type.NCHAR: "NCHAR", exp.DataType.Type.TEXT: "CLOB", exp.DataType.Type.BINARY: "BLOB", exp.DataType.Type.VARBINARY: "BLOB", @@ -140,15 +144,9 @@ class Oracle(Dialect): exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } - LIMIT_FETCH = "FETCH" - def offset_sql(self, expression: exp.Offset) -> str: return f"{super().offset_sql(expression)} ROWS" - def column_sql(self, expression: exp.Column) -> str: - column = super().column_sql(expression) - return f"{column} (+)" if expression.args.get("join_mark") else column - def xmltable_sql(self, expression: exp.XMLTable) -> str: this = self.sql(expression, "this") passing = self.expressions(expression, key="passing") diff --git a/sqlglot/dialects/presto.py b/sqlglot/dialects/presto.py index 1721588..7d35c67 100644 --- a/sqlglot/dialects/presto.py +++ b/sqlglot/dialects/presto.py @@ -12,6 +12,7 @@ from sqlglot.dialects.dialect import ( no_ilike_sql, no_pivot_sql, no_safe_divide_sql, + regexp_extract_sql, rename_func, right_to_substring_sql, struct_extract_sql, @@ -215,6 +216,9 @@ class Presto(Dialect): this=seq_get(args, 0), replace=seq_get(args, 1), charset=exp.Literal.string("utf-8") ), "NOW": exp.CurrentTimestamp.from_arg_list, + "REGEXP_EXTRACT": lambda args: exp.RegexpExtract( + this=seq_get(args, 0), expression=seq_get(args, 1), group=seq_get(args, 2) + ), "SEQUENCE": exp.GenerateSeries.from_arg_list, "STRPOS": lambda args: exp.StrPosition( this=seq_get(args, 0), substr=seq_get(args, 1), instance=seq_get(args, 2) @@ -293,6 +297,7 @@ class Presto(Dialect): exp.LogicalOr: rename_func("BOOL_OR"), exp.Pivot: no_pivot_sql, exp.Quantile: _quantile_sql, + exp.RegexpExtract: regexp_extract_sql, exp.Right: right_to_substring_sql, exp.SafeBracket: lambda self, e: self.func( "ELEMENT_AT", e.this, seq_get(apply_index_offset(e.this, e.expressions, 1), 0) diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 19924cd..715a84c 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -223,13 +223,14 @@ class Snowflake(Dialect): "IFF": exp.If.from_arg_list, "NULLIFZERO": _nullifzero_to_if, "OBJECT_CONSTRUCT": _parse_object_construct, + "REGEXP_SUBSTR": exp.RegexpExtract.from_arg_list, "RLIKE": exp.RegexpLike.from_arg_list, "SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)), "TIMEDIFF": _parse_datediff, "TIMESTAMPDIFF": _parse_datediff, "TO_ARRAY": exp.Array.from_arg_list, - "TO_VARCHAR": exp.ToChar.from_arg_list, "TO_TIMESTAMP": _snowflake_to_timestamp, + "TO_VARCHAR": exp.ToChar.from_arg_list, "ZEROIFNULL": _zeroifnull_to_if, } @@ -361,12 +362,12 @@ class Snowflake(Dialect): "OBJECT_CONSTRUCT", *(arg for expression in e.expressions for arg in expression.flatten()), ), + exp.TimestampTrunc: timestamptrunc_sql, exp.TimeStrToTime: timestrtotime_sql, - exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})", exp.TimeToStr: lambda self, e: self.func( "TO_CHAR", exp.cast(e.this, "timestamp"), self.format_time(e) ), - exp.TimestampTrunc: timestamptrunc_sql, + exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})", exp.ToChar: lambda self, e: self.function_fallback_sql(e), exp.Trim: lambda self, e: self.func("TRIM", e.this, e.expression), exp.TsOrDsToDate: ts_or_ds_to_date_sql("snowflake"), @@ -390,6 +391,24 @@ class Snowflake(Dialect): exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED, } + def regexpextract_sql(self, expression: exp.RegexpExtract) -> str: + # Other dialects don't support all of the following parameters, so we need to + # generate default values as necessary to ensure the transpilation is correct + group = expression.args.get("group") + parameters = expression.args.get("parameters") or (group and exp.Literal.string("c")) + occurrence = expression.args.get("occurrence") or (parameters and exp.Literal.number(1)) + position = expression.args.get("position") or (occurrence and exp.Literal.number(1)) + + return self.func( + "REGEXP_SUBSTR", + expression.this, + expression.expression, + position, + occurrence, + parameters, + group, + ) + def except_op(self, expression: exp.Except) -> str: if not expression.args.get("distinct", False): self.unsupported("EXCEPT with All is not supported in Snowflake") diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index 92bb755..b77c2c0 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -302,6 +302,7 @@ class TSQL(Dialect): "UNIQUEIDENTIFIER": TokenType.UNIQUEIDENTIFIER, "VARCHAR(MAX)": TokenType.TEXT, "XML": TokenType.XML, + "OUTPUT": TokenType.RETURNING, "SYSTEM_USER": TokenType.CURRENT_USER, } @@ -469,6 +470,7 @@ class TSQL(Dialect): LOCKING_READS_SUPPORTED = True LIMIT_IS_TOP = True QUERY_HINTS = False + RETURNING_END = False TYPE_MAPPING = { **generator.Generator.TYPE_MAPPING, @@ -532,3 +534,8 @@ class TSQL(Dialect): table = expression.args.get("table") table = f"{table} " if table else "" return f"RETURNS {table}{self.sql(expression, 'this')}" + + def returning_sql(self, expression: exp.Returning) -> str: + into = self.sql(expression, "into") + into = self.seg(f"INTO {into}") if into else "" + return f"{self.seg('OUTPUT')} {self.expressions(expression, flat=True)}{into}" |