diff options
Diffstat (limited to 'sqlglot/dialects/tsql.py')
-rw-r--r-- | sqlglot/dialects/tsql.py | 83 |
1 files changed, 80 insertions, 3 deletions
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index b77c2c0..01d5001 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -138,7 +138,8 @@ def _format_sql(self: generator.Generator, expression: exp.NumberToStr | exp.Tim if isinstance(expression, exp.NumberToStr) else exp.Literal.string( format_time( - expression.text("format"), t.cast(t.Dict[str, str], TSQL.INVERSE_TIME_MAPPING) + expression.text("format"), + t.cast(t.Dict[str, str], TSQL.INVERSE_TIME_MAPPING), ) ) ) @@ -314,7 +315,9 @@ class TSQL(Dialect): FUNCTIONS = { **parser.Parser.FUNCTIONS, "CHARINDEX": lambda args: exp.StrPosition( - this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2) + this=seq_get(args, 1), + substr=seq_get(args, 0), + position=seq_get(args, 2), ), "DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL), "DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL), @@ -365,6 +368,55 @@ class TSQL(Dialect): CONCAT_NULL_OUTPUTS_STRING = True + def _parse_commit_or_rollback(self) -> exp.Commit | exp.Rollback: + """Applies to SQL Server and Azure SQL Database + COMMIT [ { TRAN | TRANSACTION } + [ transaction_name | @tran_name_variable ] ] + [ WITH ( DELAYED_DURABILITY = { OFF | ON } ) ] + + ROLLBACK { TRAN | TRANSACTION } + [ transaction_name | @tran_name_variable + | savepoint_name | @savepoint_variable ] + """ + rollback = self._prev.token_type == TokenType.ROLLBACK + + self._match_texts({"TRAN", "TRANSACTION"}) + this = self._parse_id_var() + + if rollback: + return self.expression(exp.Rollback, this=this) + + durability = None + if self._match_pair(TokenType.WITH, TokenType.L_PAREN): + self._match_text_seq("DELAYED_DURABILITY") + self._match(TokenType.EQ) + + if self._match_text_seq("OFF"): + durability = False + else: + self._match(TokenType.ON) + durability = True + + self._match_r_paren() + + return self.expression(exp.Commit, this=this, durability=durability) + + def _parse_transaction(self) -> exp.Transaction | exp.Command: + """Applies to SQL Server and Azure SQL Database + BEGIN { TRAN | TRANSACTION } + [ { transaction_name | @tran_name_variable } + [ WITH MARK [ 'description' ] ] + ] + """ + if self._match_texts(("TRAN", "TRANSACTION")): + transaction = self.expression(exp.Transaction, this=self._parse_id_var()) + if self._match_text_seq("WITH", "MARK"): + transaction.set("mark", self._parse_string()) + + return transaction + + return self._parse_as_command(self._prev) + def _parse_system_time(self) -> t.Optional[exp.Expression]: if not self._match_text_seq("FOR", "SYSTEM_TIME"): return None @@ -496,7 +548,9 @@ class TSQL(Dialect): exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]), exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this), exp.SHA2: lambda self, e: self.func( - "HASHBYTES", exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), e.this + "HASHBYTES", + exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), + e.this, ), exp.TimeToStr: _format_sql, } @@ -539,3 +593,26 @@ class TSQL(Dialect): into = self.sql(expression, "into") into = self.seg(f"INTO {into}") if into else "" return f"{self.seg('OUTPUT')} {self.expressions(expression, flat=True)}{into}" + + def transaction_sql(self, expression: exp.Transaction) -> str: + this = self.sql(expression, "this") + this = f" {this}" if this else "" + mark = self.sql(expression, "mark") + mark = f" WITH MARK {mark}" if mark else "" + return f"BEGIN TRANSACTION{this}{mark}" + + def commit_sql(self, expression: exp.Commit) -> str: + this = self.sql(expression, "this") + this = f" {this}" if this else "" + durability = expression.args.get("durability") + durability = ( + f" WITH (DELAYED_DURABILITY = {'ON' if durability else 'OFF'})" + if durability is not None + else "" + ) + return f"COMMIT TRANSACTION{this}{durability}" + + def rollback_sql(self, expression: exp.Rollback) -> str: + this = self.sql(expression, "this") + this = f" {this}" if this else "" + return f"ROLLBACK TRANSACTION{this}" |