summaryrefslogtreecommitdiffstats
path: root/third_party/rust/uniffi_bindgen/src/bindings/swift/templates/Async.swift
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--third_party/rust/uniffi_bindgen/src/bindings/swift/templates/Async.swift100
1 files changed, 77 insertions, 23 deletions
diff --git a/third_party/rust/uniffi_bindgen/src/bindings/swift/templates/Async.swift b/third_party/rust/uniffi_bindgen/src/bindings/swift/templates/Async.swift
index 695208861d..e16f3108e1 100644
--- a/third_party/rust/uniffi_bindgen/src/bindings/swift/templates/Async.swift
+++ b/third_party/rust/uniffi_bindgen/src/bindings/swift/templates/Async.swift
@@ -1,11 +1,13 @@
private let UNIFFI_RUST_FUTURE_POLL_READY: Int8 = 0
private let UNIFFI_RUST_FUTURE_POLL_MAYBE_READY: Int8 = 1
+fileprivate let uniffiContinuationHandleMap = UniffiHandleMap<UnsafeContinuation<Int8, Never>>()
+
fileprivate func uniffiRustCallAsync<F, T>(
- rustFutureFunc: () -> UnsafeMutableRawPointer,
- pollFunc: (UnsafeMutableRawPointer, UnsafeMutableRawPointer) -> (),
- completeFunc: (UnsafeMutableRawPointer, UnsafeMutablePointer<RustCallStatus>) -> F,
- freeFunc: (UnsafeMutableRawPointer) -> (),
+ rustFutureFunc: () -> UInt64,
+ pollFunc: (UInt64, @escaping UniffiRustFutureContinuationCallback, UInt64) -> (),
+ completeFunc: (UInt64, UnsafeMutablePointer<RustCallStatus>) -> F,
+ freeFunc: (UInt64) -> (),
liftFunc: (F) throws -> T,
errorHandler: ((RustBuffer) throws -> Error)?
) async throws -> T {
@@ -19,7 +21,11 @@ fileprivate func uniffiRustCallAsync<F, T>(
var pollResult: Int8;
repeat {
pollResult = await withUnsafeContinuation {
- pollFunc(rustFuture, ContinuationHolder($0).toOpaque())
+ pollFunc(
+ rustFuture,
+ uniffiFutureContinuationCallback,
+ uniffiContinuationHandleMap.insert(obj: $0)
+ )
}
} while pollResult != UNIFFI_RUST_FUTURE_POLL_READY
@@ -31,32 +37,80 @@ fileprivate func uniffiRustCallAsync<F, T>(
// Callback handlers for an async calls. These are invoked by Rust when the future is ready. They
// lift the return value or error and resume the suspended function.
-fileprivate func uniffiFutureContinuationCallback(ptr: UnsafeMutableRawPointer, pollResult: Int8) {
- ContinuationHolder.fromOpaque(ptr).resume(pollResult)
+fileprivate func uniffiFutureContinuationCallback(handle: UInt64, pollResult: Int8) {
+ if let continuation = try? uniffiContinuationHandleMap.remove(handle: handle) {
+ continuation.resume(returning: pollResult)
+ } else {
+ print("uniffiFutureContinuationCallback invalid handle")
+ }
}
-// Wraps UnsafeContinuation in a class so that we can use reference counting when passing it across
-// the FFI
-fileprivate class ContinuationHolder {
- let continuation: UnsafeContinuation<Int8, Never>
-
- init(_ continuation: UnsafeContinuation<Int8, Never>) {
- self.continuation = continuation
+{%- if ci.has_async_callback_interface_definition() %}
+private func uniffiTraitInterfaceCallAsync<T>(
+ makeCall: @escaping () async throws -> T,
+ handleSuccess: @escaping (T) -> (),
+ handleError: @escaping (Int8, RustBuffer) -> ()
+) -> UniffiForeignFuture {
+ let task = Task {
+ do {
+ handleSuccess(try await makeCall())
+ } catch {
+ handleError(CALL_UNEXPECTED_ERROR, {{ Type::String.borrow()|lower_fn }}(String(describing: error)))
+ }
}
+ let handle = UNIFFI_FOREIGN_FUTURE_HANDLE_MAP.insert(obj: task)
+ return UniffiForeignFuture(handle: handle, free: uniffiForeignFutureFree)
- func resume(_ pollResult: Int8) {
- self.continuation.resume(returning: pollResult)
- }
+}
- func toOpaque() -> UnsafeMutableRawPointer {
- return Unmanaged<ContinuationHolder>.passRetained(self).toOpaque()
+private func uniffiTraitInterfaceCallAsyncWithError<T, E>(
+ makeCall: @escaping () async throws -> T,
+ handleSuccess: @escaping (T) -> (),
+ handleError: @escaping (Int8, RustBuffer) -> (),
+ lowerError: @escaping (E) -> RustBuffer
+) -> UniffiForeignFuture {
+ let task = Task {
+ do {
+ handleSuccess(try await makeCall())
+ } catch let error as E {
+ handleError(CALL_ERROR, lowerError(error))
+ } catch {
+ handleError(CALL_UNEXPECTED_ERROR, {{ Type::String.borrow()|lower_fn }}(String(describing: error)))
+ }
}
+ let handle = UNIFFI_FOREIGN_FUTURE_HANDLE_MAP.insert(obj: task)
+ return UniffiForeignFuture(handle: handle, free: uniffiForeignFutureFree)
+}
+
+// Borrow the callback handle map implementation to store foreign future handles
+// TODO: consolidate the handle-map code (https://github.com/mozilla/uniffi-rs/pull/1823)
+fileprivate var UNIFFI_FOREIGN_FUTURE_HANDLE_MAP = UniffiHandleMap<UniffiForeignFutureTask>()
+
+// Protocol for tasks that handle foreign futures.
+//
+// Defining a protocol allows all tasks to be stored in the same handle map. This can't be done
+// with the task object itself, since has generic parameters.
+protocol UniffiForeignFutureTask {
+ func cancel()
+}
+
+extension Task: UniffiForeignFutureTask {}
- static func fromOpaque(_ ptr: UnsafeRawPointer) -> ContinuationHolder {
- return Unmanaged<ContinuationHolder>.fromOpaque(ptr).takeRetainedValue()
+private func uniffiForeignFutureFree(handle: UInt64) {
+ do {
+ let task = try UNIFFI_FOREIGN_FUTURE_HANDLE_MAP.remove(handle: handle)
+ // Set the cancellation flag on the task. If it's still running, the code can check the
+ // cancellation flag or call `Task.checkCancellation()`. If the task has completed, this is
+ // a no-op.
+ task.cancel()
+ } catch {
+ print("uniffiForeignFutureFree: handle missing from handlemap")
}
}
-fileprivate func uniffiInitContinuationCallback() {
- {{ ci.ffi_rust_future_continuation_callback_set().name() }}(uniffiFutureContinuationCallback)
+// For testing
+public func uniffiForeignFutureHandleCount{{ ci.namespace()|class_name }}() -> Int {
+ UNIFFI_FOREIGN_FUTURE_HANDLE_MAP.count
}
+
+{%- endif %}