diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-05-30 18:31:36 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-05-30 18:31:36 +0000 |
commit | e02c5b5930c2c9ba3e5423fe12e2ef0155017297 (patch) | |
tree | fd60ebbbb5299e16e5fca8c773ddb74f764760db /vendor/lsp-server/src | |
parent | Adding debian version 1.73.0+dfsg1-1. (diff) | |
download | rustc-e02c5b5930c2c9ba3e5423fe12e2ef0155017297.tar.xz rustc-e02c5b5930c2c9ba3e5423fe12e2ef0155017297.zip |
Merging upstream version 1.74.1+dfsg1.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'vendor/lsp-server/src')
-rw-r--r-- | vendor/lsp-server/src/error.rs | 50 | ||||
-rw-r--r-- | vendor/lsp-server/src/lib.rs | 284 | ||||
-rw-r--r-- | vendor/lsp-server/src/msg.rs | 351 | ||||
-rw-r--r-- | vendor/lsp-server/src/req_queue.rs | 69 | ||||
-rw-r--r-- | vendor/lsp-server/src/socket.rs | 46 | ||||
-rw-r--r-- | vendor/lsp-server/src/stdio.rs | 71 |
6 files changed, 871 insertions, 0 deletions
diff --git a/vendor/lsp-server/src/error.rs b/vendor/lsp-server/src/error.rs new file mode 100644 index 000000000..755b3fd95 --- /dev/null +++ b/vendor/lsp-server/src/error.rs @@ -0,0 +1,50 @@ +use std::fmt; + +use crate::{Notification, Request}; + +#[derive(Debug, Clone, PartialEq)] +pub struct ProtocolError(pub(crate) String); + +impl std::error::Error for ProtocolError {} + +impl fmt::Display for ProtocolError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(&self.0, f) + } +} + +#[derive(Debug)] +pub enum ExtractError<T> { + /// The extracted message was of a different method than expected. + MethodMismatch(T), + /// Failed to deserialize the message. + JsonError { method: String, error: serde_json::Error }, +} + +impl std::error::Error for ExtractError<Request> {} +impl fmt::Display for ExtractError<Request> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ExtractError::MethodMismatch(req) => { + write!(f, "Method mismatch for request '{}'", req.method) + } + ExtractError::JsonError { method, error } => { + write!(f, "Invalid request\nMethod: {method}\n error: {error}",) + } + } + } +} + +impl std::error::Error for ExtractError<Notification> {} +impl fmt::Display for ExtractError<Notification> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ExtractError::MethodMismatch(req) => { + write!(f, "Method mismatch for notification '{}'", req.method) + } + ExtractError::JsonError { method, error } => { + write!(f, "Invalid notification\nMethod: {method}\n error: {error}") + } + } + } +} diff --git a/vendor/lsp-server/src/lib.rs b/vendor/lsp-server/src/lib.rs new file mode 100644 index 000000000..affab60a2 --- /dev/null +++ b/vendor/lsp-server/src/lib.rs @@ -0,0 +1,284 @@ +//! A language server scaffold, exposing a synchronous crossbeam-channel based API. +//! This crate handles protocol handshaking and parsing messages, while you +//! control the message dispatch loop yourself. +//! +//! Run with `RUST_LOG=lsp_server=debug` to see all the messages. + +#![warn(rust_2018_idioms, unused_lifetimes, semicolon_in_expressions_from_macros)] + +mod msg; +mod stdio; +mod error; +mod socket; +mod req_queue; + +use std::{ + io, + net::{TcpListener, TcpStream, ToSocketAddrs}, +}; + +use crossbeam_channel::{Receiver, Sender}; + +pub use crate::{ + error::{ExtractError, ProtocolError}, + msg::{ErrorCode, Message, Notification, Request, RequestId, Response, ResponseError}, + req_queue::{Incoming, Outgoing, ReqQueue}, + stdio::IoThreads, +}; + +/// Connection is just a pair of channels of LSP messages. +pub struct Connection { + pub sender: Sender<Message>, + pub receiver: Receiver<Message>, +} + +impl Connection { + /// Create connection over standard in/standard out. + /// + /// Use this to create a real language server. + pub fn stdio() -> (Connection, IoThreads) { + let (sender, receiver, io_threads) = stdio::stdio_transport(); + (Connection { sender, receiver }, io_threads) + } + + /// Open a connection over tcp. + /// This call blocks until a connection is established. + /// + /// Use this to create a real language server. + pub fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<(Connection, IoThreads)> { + let stream = TcpStream::connect(addr)?; + let (sender, receiver, io_threads) = socket::socket_transport(stream); + Ok((Connection { sender, receiver }, io_threads)) + } + + /// Listen for a connection over tcp. + /// This call blocks until a connection is established. + /// + /// Use this to create a real language server. + pub fn listen<A: ToSocketAddrs>(addr: A) -> io::Result<(Connection, IoThreads)> { + let listener = TcpListener::bind(addr)?; + let (stream, _) = listener.accept()?; + let (sender, receiver, io_threads) = socket::socket_transport(stream); + Ok((Connection { sender, receiver }, io_threads)) + } + + /// Creates a pair of connected connections. + /// + /// Use this for testing. + pub fn memory() -> (Connection, Connection) { + let (s1, r1) = crossbeam_channel::unbounded(); + let (s2, r2) = crossbeam_channel::unbounded(); + (Connection { sender: s1, receiver: r2 }, Connection { sender: s2, receiver: r1 }) + } + + /// Starts the initialization process by waiting for an initialize + /// request from the client. Use this for more advanced customization than + /// `initialize` can provide. + /// + /// Returns the request id and serialized `InitializeParams` from the client. + /// + /// # Example + /// + /// ```no_run + /// use std::error::Error; + /// use lsp_types::{ClientCapabilities, InitializeParams, ServerCapabilities}; + /// + /// use lsp_server::{Connection, Message, Request, RequestId, Response}; + /// + /// fn main() -> Result<(), Box<dyn Error + Sync + Send>> { + /// // Create the transport. Includes the stdio (stdin and stdout) versions but this could + /// // also be implemented to use sockets or HTTP. + /// let (connection, io_threads) = Connection::stdio(); + /// + /// // Run the server + /// let (id, params) = connection.initialize_start()?; + /// + /// let init_params: InitializeParams = serde_json::from_value(params).unwrap(); + /// let client_capabilities: ClientCapabilities = init_params.capabilities; + /// let server_capabilities = ServerCapabilities::default(); + /// + /// let initialize_data = serde_json::json!({ + /// "capabilities": server_capabilities, + /// "serverInfo": { + /// "name": "lsp-server-test", + /// "version": "0.1" + /// } + /// }); + /// + /// connection.initialize_finish(id, initialize_data)?; + /// + /// // ... Run main loop ... + /// + /// Ok(()) + /// } + /// ``` + pub fn initialize_start(&self) -> Result<(RequestId, serde_json::Value), ProtocolError> { + loop { + break match self.receiver.recv() { + Ok(Message::Request(req)) if req.is_initialize() => Ok((req.id, req.params)), + // Respond to non-initialize requests with ServerNotInitialized + Ok(Message::Request(req)) => { + let resp = Response::new_err( + req.id.clone(), + ErrorCode::ServerNotInitialized as i32, + format!("expected initialize request, got {req:?}"), + ); + self.sender.send(resp.into()).unwrap(); + continue; + } + Ok(Message::Notification(n)) if !n.is_exit() => { + continue; + } + Ok(msg) => Err(ProtocolError(format!("expected initialize request, got {msg:?}"))), + Err(e) => { + Err(ProtocolError(format!("expected initialize request, got error: {e}"))) + } + }; + } + } + + /// Finishes the initialization process by sending an `InitializeResult` to the client + pub fn initialize_finish( + &self, + initialize_id: RequestId, + initialize_result: serde_json::Value, + ) -> Result<(), ProtocolError> { + let resp = Response::new_ok(initialize_id, initialize_result); + self.sender.send(resp.into()).unwrap(); + match &self.receiver.recv() { + Ok(Message::Notification(n)) if n.is_initialized() => Ok(()), + Ok(msg) => { + Err(ProtocolError(format!(r#"expected initialized notification, got: {msg:?}"#))) + } + Err(e) => { + Err(ProtocolError(format!("expected initialized notification, got error: {e}",))) + } + } + } + + /// Initialize the connection. Sends the server capabilities + /// to the client and returns the serialized client capabilities + /// on success. If more fine-grained initialization is required use + /// `initialize_start`/`initialize_finish`. + /// + /// # Example + /// + /// ```no_run + /// use std::error::Error; + /// use lsp_types::ServerCapabilities; + /// + /// use lsp_server::{Connection, Message, Request, RequestId, Response}; + /// + /// fn main() -> Result<(), Box<dyn Error + Sync + Send>> { + /// // Create the transport. Includes the stdio (stdin and stdout) versions but this could + /// // also be implemented to use sockets or HTTP. + /// let (connection, io_threads) = Connection::stdio(); + /// + /// // Run the server + /// let server_capabilities = serde_json::to_value(&ServerCapabilities::default()).unwrap(); + /// let initialization_params = connection.initialize(server_capabilities)?; + /// + /// // ... Run main loop ... + /// + /// Ok(()) + /// } + /// ``` + pub fn initialize( + &self, + server_capabilities: serde_json::Value, + ) -> Result<serde_json::Value, ProtocolError> { + let (id, params) = self.initialize_start()?; + + let initialize_data = serde_json::json!({ + "capabilities": server_capabilities, + }); + + self.initialize_finish(id, initialize_data)?; + + Ok(params) + } + + /// If `req` is `Shutdown`, respond to it and return `true`, otherwise return `false` + pub fn handle_shutdown(&self, req: &Request) -> Result<bool, ProtocolError> { + if !req.is_shutdown() { + return Ok(false); + } + let resp = Response::new_ok(req.id.clone(), ()); + let _ = self.sender.send(resp.into()); + match &self.receiver.recv_timeout(std::time::Duration::from_secs(30)) { + Ok(Message::Notification(n)) if n.is_exit() => (), + Ok(msg) => { + return Err(ProtocolError(format!("unexpected message during shutdown: {msg:?}"))) + } + Err(e) => return Err(ProtocolError(format!("unexpected error during shutdown: {e}"))), + } + Ok(true) + } +} + +#[cfg(test)] +mod tests { + use crossbeam_channel::unbounded; + use lsp_types::notification::{Exit, Initialized, Notification}; + use lsp_types::request::{Initialize, Request}; + use lsp_types::{InitializeParams, InitializedParams}; + use serde_json::to_value; + + use crate::{Connection, Message, ProtocolError, RequestId}; + + struct TestCase { + test_messages: Vec<Message>, + expected_resp: Result<(RequestId, serde_json::Value), ProtocolError>, + } + + fn initialize_start_test(test_case: TestCase) { + let (reader_sender, reader_receiver) = unbounded::<Message>(); + let (writer_sender, writer_receiver) = unbounded::<Message>(); + let conn = Connection { sender: writer_sender, receiver: reader_receiver }; + + for msg in test_case.test_messages { + assert!(reader_sender.send(msg).is_ok()); + } + + let resp = conn.initialize_start(); + assert_eq!(test_case.expected_resp, resp); + + assert!(writer_receiver.recv_timeout(std::time::Duration::from_secs(1)).is_err()); + } + + #[test] + fn not_exit_notification() { + let notification = crate::Notification { + method: Initialized::METHOD.to_string(), + params: to_value(InitializedParams {}).unwrap(), + }; + + let params_as_value = to_value(InitializeParams::default()).unwrap(); + let req_id = RequestId::from(234); + let request = crate::Request { + id: req_id.clone(), + method: Initialize::METHOD.to_string(), + params: params_as_value.clone(), + }; + + initialize_start_test(TestCase { + test_messages: vec![notification.into(), request.into()], + expected_resp: Ok((req_id, params_as_value)), + }); + } + + #[test] + fn exit_notification() { + let notification = + crate::Notification { method: Exit::METHOD.to_string(), params: to_value(()).unwrap() }; + let notification_msg = Message::from(notification); + + initialize_start_test(TestCase { + test_messages: vec![notification_msg.clone()], + expected_resp: Err(ProtocolError(format!( + "expected initialize request, got {:?}", + notification_msg + ))), + }); + } +} diff --git a/vendor/lsp-server/src/msg.rs b/vendor/lsp-server/src/msg.rs new file mode 100644 index 000000000..730ad51f4 --- /dev/null +++ b/vendor/lsp-server/src/msg.rs @@ -0,0 +1,351 @@ +use std::{ + fmt, + io::{self, BufRead, Write}, +}; + +use serde::{de::DeserializeOwned, Deserialize, Serialize}; + +use crate::error::ExtractError; + +#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(untagged)] +pub enum Message { + Request(Request), + Response(Response), + Notification(Notification), +} + +impl From<Request> for Message { + fn from(request: Request) -> Message { + Message::Request(request) + } +} + +impl From<Response> for Message { + fn from(response: Response) -> Message { + Message::Response(response) + } +} + +impl From<Notification> for Message { + fn from(notification: Notification) -> Message { + Message::Notification(notification) + } +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[serde(transparent)] +pub struct RequestId(IdRepr); + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[serde(untagged)] +enum IdRepr { + I32(i32), + String(String), +} + +impl From<i32> for RequestId { + fn from(id: i32) -> RequestId { + RequestId(IdRepr::I32(id)) + } +} + +impl From<String> for RequestId { + fn from(id: String) -> RequestId { + RequestId(IdRepr::String(id)) + } +} + +impl fmt::Display for RequestId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.0 { + IdRepr::I32(it) => fmt::Display::fmt(it, f), + // Use debug here, to make it clear that `92` and `"92"` are + // different, and to reduce WTF factor if the sever uses `" "` as an + // ID. + IdRepr::String(it) => fmt::Debug::fmt(it, f), + } + } +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Request { + pub id: RequestId, + pub method: String, + #[serde(default = "serde_json::Value::default")] + #[serde(skip_serializing_if = "serde_json::Value::is_null")] + pub params: serde_json::Value, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Response { + // JSON RPC allows this to be null if it was impossible + // to decode the request's id. Ignore this special case + // and just die horribly. + pub id: RequestId, + #[serde(skip_serializing_if = "Option::is_none")] + pub result: Option<serde_json::Value>, + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option<ResponseError>, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ResponseError { + pub code: i32, + pub message: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option<serde_json::Value>, +} + +#[derive(Clone, Copy, Debug)] +#[non_exhaustive] +pub enum ErrorCode { + // Defined by JSON RPC: + ParseError = -32700, + InvalidRequest = -32600, + MethodNotFound = -32601, + InvalidParams = -32602, + InternalError = -32603, + ServerErrorStart = -32099, + ServerErrorEnd = -32000, + + /// Error code indicating that a server received a notification or + /// request before the server has received the `initialize` request. + ServerNotInitialized = -32002, + UnknownErrorCode = -32001, + + // Defined by the protocol: + /// The client has canceled a request and a server has detected + /// the cancel. + RequestCanceled = -32800, + + /// The server detected that the content of a document got + /// modified outside normal conditions. A server should + /// NOT send this error code if it detects a content change + /// in it unprocessed messages. The result even computed + /// on an older state might still be useful for the client. + /// + /// If a client decides that a result is not of any use anymore + /// the client should cancel the request. + ContentModified = -32801, + + /// The server cancelled the request. This error code should + /// only be used for requests that explicitly support being + /// server cancellable. + /// + /// @since 3.17.0 + ServerCancelled = -32802, + + /// A request failed but it was syntactically correct, e.g the + /// method name was known and the parameters were valid. The error + /// message should contain human readable information about why + /// the request failed. + /// + /// @since 3.17.0 + RequestFailed = -32803, +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Notification { + pub method: String, + #[serde(default = "serde_json::Value::default")] + #[serde(skip_serializing_if = "serde_json::Value::is_null")] + pub params: serde_json::Value, +} + +impl Message { + pub fn read(r: &mut impl BufRead) -> io::Result<Option<Message>> { + Message::_read(r) + } + fn _read(r: &mut dyn BufRead) -> io::Result<Option<Message>> { + let text = match read_msg_text(r)? { + None => return Ok(None), + Some(text) => text, + }; + let msg = serde_json::from_str(&text)?; + Ok(Some(msg)) + } + pub fn write(self, w: &mut impl Write) -> io::Result<()> { + self._write(w) + } + fn _write(self, w: &mut dyn Write) -> io::Result<()> { + #[derive(Serialize)] + struct JsonRpc { + jsonrpc: &'static str, + #[serde(flatten)] + msg: Message, + } + let text = serde_json::to_string(&JsonRpc { jsonrpc: "2.0", msg: self })?; + write_msg_text(w, &text) + } +} + +impl Response { + pub fn new_ok<R: Serialize>(id: RequestId, result: R) -> Response { + Response { id, result: Some(serde_json::to_value(result).unwrap()), error: None } + } + pub fn new_err(id: RequestId, code: i32, message: String) -> Response { + let error = ResponseError { code, message, data: None }; + Response { id, result: None, error: Some(error) } + } +} + +impl Request { + pub fn new<P: Serialize>(id: RequestId, method: String, params: P) -> Request { + Request { id, method, params: serde_json::to_value(params).unwrap() } + } + pub fn extract<P: DeserializeOwned>( + self, + method: &str, + ) -> Result<(RequestId, P), ExtractError<Request>> { + if self.method != method { + return Err(ExtractError::MethodMismatch(self)); + } + match serde_json::from_value(self.params) { + Ok(params) => Ok((self.id, params)), + Err(error) => Err(ExtractError::JsonError { method: self.method, error }), + } + } + + pub(crate) fn is_shutdown(&self) -> bool { + self.method == "shutdown" + } + pub(crate) fn is_initialize(&self) -> bool { + self.method == "initialize" + } +} + +impl Notification { + pub fn new(method: String, params: impl Serialize) -> Notification { + Notification { method, params: serde_json::to_value(params).unwrap() } + } + pub fn extract<P: DeserializeOwned>( + self, + method: &str, + ) -> Result<P, ExtractError<Notification>> { + if self.method != method { + return Err(ExtractError::MethodMismatch(self)); + } + match serde_json::from_value(self.params) { + Ok(params) => Ok(params), + Err(error) => Err(ExtractError::JsonError { method: self.method, error }), + } + } + pub(crate) fn is_exit(&self) -> bool { + self.method == "exit" + } + pub(crate) fn is_initialized(&self) -> bool { + self.method == "initialized" + } +} + +fn read_msg_text(inp: &mut dyn BufRead) -> io::Result<Option<String>> { + fn invalid_data(error: impl Into<Box<dyn std::error::Error + Send + Sync>>) -> io::Error { + io::Error::new(io::ErrorKind::InvalidData, error) + } + macro_rules! invalid_data { + ($($tt:tt)*) => (invalid_data(format!($($tt)*))) + } + + let mut size = None; + let mut buf = String::new(); + loop { + buf.clear(); + if inp.read_line(&mut buf)? == 0 { + return Ok(None); + } + if !buf.ends_with("\r\n") { + return Err(invalid_data!("malformed header: {:?}", buf)); + } + let buf = &buf[..buf.len() - 2]; + if buf.is_empty() { + break; + } + let mut parts = buf.splitn(2, ": "); + let header_name = parts.next().unwrap(); + let header_value = + parts.next().ok_or_else(|| invalid_data!("malformed header: {:?}", buf))?; + if header_name.eq_ignore_ascii_case("Content-Length") { + size = Some(header_value.parse::<usize>().map_err(invalid_data)?); + } + } + let size: usize = size.ok_or_else(|| invalid_data!("no Content-Length"))?; + let mut buf = buf.into_bytes(); + buf.resize(size, 0); + inp.read_exact(&mut buf)?; + let buf = String::from_utf8(buf).map_err(invalid_data)?; + log::debug!("< {}", buf); + Ok(Some(buf)) +} + +fn write_msg_text(out: &mut dyn Write, msg: &str) -> io::Result<()> { + log::debug!("> {}", msg); + write!(out, "Content-Length: {}\r\n\r\n", msg.len())?; + out.write_all(msg.as_bytes())?; + out.flush()?; + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::{Message, Notification, Request, RequestId}; + + #[test] + fn shutdown_with_explicit_null() { + let text = "{\"jsonrpc\": \"2.0\",\"id\": 3,\"method\": \"shutdown\", \"params\": null }"; + let msg: Message = serde_json::from_str(text).unwrap(); + + assert!( + matches!(msg, Message::Request(req) if req.id == 3.into() && req.method == "shutdown") + ); + } + + #[test] + fn shutdown_with_no_params() { + let text = "{\"jsonrpc\": \"2.0\",\"id\": 3,\"method\": \"shutdown\"}"; + let msg: Message = serde_json::from_str(text).unwrap(); + + assert!( + matches!(msg, Message::Request(req) if req.id == 3.into() && req.method == "shutdown") + ); + } + + #[test] + fn notification_with_explicit_null() { + let text = "{\"jsonrpc\": \"2.0\",\"method\": \"exit\", \"params\": null }"; + let msg: Message = serde_json::from_str(text).unwrap(); + + assert!(matches!(msg, Message::Notification(not) if not.method == "exit")); + } + + #[test] + fn notification_with_no_params() { + let text = "{\"jsonrpc\": \"2.0\",\"method\": \"exit\"}"; + let msg: Message = serde_json::from_str(text).unwrap(); + + assert!(matches!(msg, Message::Notification(not) if not.method == "exit")); + } + + #[test] + fn serialize_request_with_null_params() { + let msg = Message::Request(Request { + id: RequestId::from(3), + method: "shutdown".into(), + params: serde_json::Value::Null, + }); + let serialized = serde_json::to_string(&msg).unwrap(); + + assert_eq!("{\"id\":3,\"method\":\"shutdown\"}", serialized); + } + + #[test] + fn serialize_notification_with_null_params() { + let msg = Message::Notification(Notification { + method: "exit".into(), + params: serde_json::Value::Null, + }); + let serialized = serde_json::to_string(&msg).unwrap(); + + assert_eq!("{\"method\":\"exit\"}", serialized); + } +} diff --git a/vendor/lsp-server/src/req_queue.rs b/vendor/lsp-server/src/req_queue.rs new file mode 100644 index 000000000..e5f19be20 --- /dev/null +++ b/vendor/lsp-server/src/req_queue.rs @@ -0,0 +1,69 @@ +use std::collections::HashMap; + +use serde::Serialize; + +use crate::{ErrorCode, Request, RequestId, Response, ResponseError}; + +/// Manages the set of pending requests, both incoming and outgoing. +#[derive(Debug)] +pub struct ReqQueue<I, O> { + pub incoming: Incoming<I>, + pub outgoing: Outgoing<O>, +} + +impl<I, O> Default for ReqQueue<I, O> { + fn default() -> ReqQueue<I, O> { + ReqQueue { + incoming: Incoming { pending: HashMap::default() }, + outgoing: Outgoing { next_id: 0, pending: HashMap::default() }, + } + } +} + +#[derive(Debug)] +pub struct Incoming<I> { + pending: HashMap<RequestId, I>, +} + +#[derive(Debug)] +pub struct Outgoing<O> { + next_id: i32, + pending: HashMap<RequestId, O>, +} + +impl<I> Incoming<I> { + pub fn register(&mut self, id: RequestId, data: I) { + self.pending.insert(id, data); + } + + pub fn cancel(&mut self, id: RequestId) -> Option<Response> { + let _data = self.complete(id.clone())?; + let error = ResponseError { + code: ErrorCode::RequestCanceled as i32, + message: "canceled by client".to_string(), + data: None, + }; + Some(Response { id, result: None, error: Some(error) }) + } + + pub fn complete(&mut self, id: RequestId) -> Option<I> { + self.pending.remove(&id) + } + + pub fn is_completed(&self, id: &RequestId) -> bool { + !self.pending.contains_key(id) + } +} + +impl<O> Outgoing<O> { + pub fn register<P: Serialize>(&mut self, method: String, params: P, data: O) -> Request { + let id = RequestId::from(self.next_id); + self.pending.insert(id.clone(), data); + self.next_id += 1; + Request::new(id, method, params) + } + + pub fn complete(&mut self, id: RequestId) -> Option<O> { + self.pending.remove(&id) + } +} diff --git a/vendor/lsp-server/src/socket.rs b/vendor/lsp-server/src/socket.rs new file mode 100644 index 000000000..36d728456 --- /dev/null +++ b/vendor/lsp-server/src/socket.rs @@ -0,0 +1,46 @@ +use std::{ + io::{self, BufReader}, + net::TcpStream, + thread, +}; + +use crossbeam_channel::{bounded, Receiver, Sender}; + +use crate::{ + stdio::{make_io_threads, IoThreads}, + Message, +}; + +pub(crate) fn socket_transport( + stream: TcpStream, +) -> (Sender<Message>, Receiver<Message>, IoThreads) { + let (reader_receiver, reader) = make_reader(stream.try_clone().unwrap()); + let (writer_sender, writer) = make_write(stream); + let io_threads = make_io_threads(reader, writer); + (writer_sender, reader_receiver, io_threads) +} + +fn make_reader(stream: TcpStream) -> (Receiver<Message>, thread::JoinHandle<io::Result<()>>) { + let (reader_sender, reader_receiver) = bounded::<Message>(0); + let reader = thread::spawn(move || { + let mut buf_read = BufReader::new(stream); + while let Some(msg) = Message::read(&mut buf_read).unwrap() { + let is_exit = matches!(&msg, Message::Notification(n) if n.is_exit()); + reader_sender.send(msg).unwrap(); + if is_exit { + break; + } + } + Ok(()) + }); + (reader_receiver, reader) +} + +fn make_write(mut stream: TcpStream) -> (Sender<Message>, thread::JoinHandle<io::Result<()>>) { + let (writer_sender, writer_receiver) = bounded::<Message>(0); + let writer = thread::spawn(move || { + writer_receiver.into_iter().try_for_each(|it| it.write(&mut stream)).unwrap(); + Ok(()) + }); + (writer_sender, writer) +} diff --git a/vendor/lsp-server/src/stdio.rs b/vendor/lsp-server/src/stdio.rs new file mode 100644 index 000000000..e487b9b46 --- /dev/null +++ b/vendor/lsp-server/src/stdio.rs @@ -0,0 +1,71 @@ +use std::{ + io::{self, stdin, stdout}, + thread, +}; + +use log::debug; + +use crossbeam_channel::{bounded, Receiver, Sender}; + +use crate::Message; + +/// Creates an LSP connection via stdio. +pub(crate) fn stdio_transport() -> (Sender<Message>, Receiver<Message>, IoThreads) { + let (writer_sender, writer_receiver) = bounded::<Message>(0); + let writer = thread::spawn(move || { + let stdout = stdout(); + let mut stdout = stdout.lock(); + writer_receiver.into_iter().try_for_each(|it| it.write(&mut stdout))?; + Ok(()) + }); + let (reader_sender, reader_receiver) = bounded::<Message>(0); + let reader = thread::spawn(move || { + let stdin = stdin(); + let mut stdin = stdin.lock(); + while let Some(msg) = Message::read(&mut stdin)? { + let is_exit = matches!(&msg, Message::Notification(n) if n.is_exit()); + + debug!("sending message {:#?}", msg); + reader_sender.send(msg).expect("receiver was dropped, failed to send a message"); + + if is_exit { + break; + } + } + Ok(()) + }); + let threads = IoThreads { reader, writer }; + (writer_sender, reader_receiver, threads) +} + +// Creates an IoThreads +pub(crate) fn make_io_threads( + reader: thread::JoinHandle<io::Result<()>>, + writer: thread::JoinHandle<io::Result<()>>, +) -> IoThreads { + IoThreads { reader, writer } +} + +pub struct IoThreads { + reader: thread::JoinHandle<io::Result<()>>, + writer: thread::JoinHandle<io::Result<()>>, +} + +impl IoThreads { + pub fn join(self) -> io::Result<()> { + match self.reader.join() { + Ok(r) => r?, + Err(err) => { + println!("reader panicked!"); + std::panic::panic_any(err) + } + } + match self.writer.join() { + Ok(r) => r, + Err(err) => { + println!("writer panicked!"); + std::panic::panic_any(err); + } + } + } +} |