summaryrefslogtreecommitdiffstats
path: root/sqlglot/dialects/tsql.py
diff options
context:
space:
mode:
Diffstat (limited to 'sqlglot/dialects/tsql.py')
-rw-r--r--sqlglot/dialects/tsql.py83
1 files changed, 80 insertions, 3 deletions
diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py
index b77c2c0..01d5001 100644
--- a/sqlglot/dialects/tsql.py
+++ b/sqlglot/dialects/tsql.py
@@ -138,7 +138,8 @@ def _format_sql(self: generator.Generator, expression: exp.NumberToStr | exp.Tim
if isinstance(expression, exp.NumberToStr)
else exp.Literal.string(
format_time(
- expression.text("format"), t.cast(t.Dict[str, str], TSQL.INVERSE_TIME_MAPPING)
+ expression.text("format"),
+ t.cast(t.Dict[str, str], TSQL.INVERSE_TIME_MAPPING),
)
)
)
@@ -314,7 +315,9 @@ class TSQL(Dialect):
FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"CHARINDEX": lambda args: exp.StrPosition(
- this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
+ this=seq_get(args, 1),
+ substr=seq_get(args, 0),
+ position=seq_get(args, 2),
),
"DATEADD": parse_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL),
"DATEDIFF": parse_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL),
@@ -365,6 +368,55 @@ class TSQL(Dialect):
CONCAT_NULL_OUTPUTS_STRING = True
+ def _parse_commit_or_rollback(self) -> exp.Commit | exp.Rollback:
+ """Applies to SQL Server and Azure SQL Database
+ COMMIT [ { TRAN | TRANSACTION }
+ [ transaction_name | @tran_name_variable ] ]
+ [ WITH ( DELAYED_DURABILITY = { OFF | ON } ) ]
+
+ ROLLBACK { TRAN | TRANSACTION }
+ [ transaction_name | @tran_name_variable
+ | savepoint_name | @savepoint_variable ]
+ """
+ rollback = self._prev.token_type == TokenType.ROLLBACK
+
+ self._match_texts({"TRAN", "TRANSACTION"})
+ this = self._parse_id_var()
+
+ if rollback:
+ return self.expression(exp.Rollback, this=this)
+
+ durability = None
+ if self._match_pair(TokenType.WITH, TokenType.L_PAREN):
+ self._match_text_seq("DELAYED_DURABILITY")
+ self._match(TokenType.EQ)
+
+ if self._match_text_seq("OFF"):
+ durability = False
+ else:
+ self._match(TokenType.ON)
+ durability = True
+
+ self._match_r_paren()
+
+ return self.expression(exp.Commit, this=this, durability=durability)
+
+ def _parse_transaction(self) -> exp.Transaction | exp.Command:
+ """Applies to SQL Server and Azure SQL Database
+ BEGIN { TRAN | TRANSACTION }
+ [ { transaction_name | @tran_name_variable }
+ [ WITH MARK [ 'description' ] ]
+ ]
+ """
+ if self._match_texts(("TRAN", "TRANSACTION")):
+ transaction = self.expression(exp.Transaction, this=self._parse_id_var())
+ if self._match_text_seq("WITH", "MARK"):
+ transaction.set("mark", self._parse_string())
+
+ return transaction
+
+ return self._parse_as_command(self._prev)
+
def _parse_system_time(self) -> t.Optional[exp.Expression]:
if not self._match_text_seq("FOR", "SYSTEM_TIME"):
return None
@@ -496,7 +548,9 @@ class TSQL(Dialect):
exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
exp.SHA: lambda self, e: self.func("HASHBYTES", exp.Literal.string("SHA1"), e.this),
exp.SHA2: lambda self, e: self.func(
- "HASHBYTES", exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"), e.this
+ "HASHBYTES",
+ exp.Literal.string(f"SHA2_{e.args.get('length', 256)}"),
+ e.this,
),
exp.TimeToStr: _format_sql,
}
@@ -539,3 +593,26 @@ class TSQL(Dialect):
into = self.sql(expression, "into")
into = self.seg(f"INTO {into}") if into else ""
return f"{self.seg('OUTPUT')} {self.expressions(expression, flat=True)}{into}"
+
+ def transaction_sql(self, expression: exp.Transaction) -> str:
+ this = self.sql(expression, "this")
+ this = f" {this}" if this else ""
+ mark = self.sql(expression, "mark")
+ mark = f" WITH MARK {mark}" if mark else ""
+ return f"BEGIN TRANSACTION{this}{mark}"
+
+ def commit_sql(self, expression: exp.Commit) -> str:
+ this = self.sql(expression, "this")
+ this = f" {this}" if this else ""
+ durability = expression.args.get("durability")
+ durability = (
+ f" WITH (DELAYED_DURABILITY = {'ON' if durability else 'OFF'})"
+ if durability is not None
+ else ""
+ )
+ return f"COMMIT TRANSACTION{this}{durability}"
+
+ def rollback_sql(self, expression: exp.Rollback) -> str:
+ this = self.sql(expression, "this")
+ this = f" {this}" if this else ""
+ return f"ROLLBACK TRANSACTION{this}"