elixirkit/
pubsub.rs

1use std::collections::HashMap;
2use std::io::{self, Read, Write};
3use std::net::{TcpListener, TcpStream};
4use std::sync::{Arc, Condvar, Mutex};
5use std::thread;
6
/// A boxed subscriber callback: it receives the raw payload bytes of a
/// message and must be `Send` so it can be invoked from the reader thread.
type Callback = Box<dyn Fn(&[u8]) + Send + 'static>;
8
/// Lifecycle of the single TCP connection backing a `PubSub` handle.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ConnectionState {
    /// Initial state after `listen`: the peer has not connected yet.
    WaitingForConnection,
    /// A stream is available for reading and writing.
    Connected,
    /// Accepting failed or the read loop has ended.
    Closed,
}
15
16struct Inner {
17    port: u16,
18    stream: Mutex<Option<TcpStream>>,
19    subscribers: Mutex<HashMap<String, Vec<Callback>>>,
20    connection_state: Mutex<ConnectionState>,
21    connection_state_changed: Condvar,
22}
23
24/// A handle to the PubSub connection.
25#[derive(Clone)]
26pub struct PubSub {
27    inner: Arc<Inner>,
28}
29
30impl PubSub {
31    /// Listens on the given URL and returns a [`PubSub`] handle.
32    ///
33    /// The URL must be in the format `tcp://127.0.0.1:{port}`, where port
34    /// can be `0` to let the OS assign an available port.
35    ///
36    /// # Examples
37    ///
38    /// ```no_run
39    /// let pubsub = elixirkit::PubSub::listen("tcp://127.0.0.1:0")
40    ///     .expect("failed to listen");
41    /// ```
42    pub fn listen(url: &str) -> Result<Self, io::Error> {
43        let port = parse_url(url)?;
44        let listener = TcpListener::bind(("127.0.0.1", port))?;
45        let actual_port = listener.local_addr()?.port();
46
47        let pubsub = PubSub {
48            inner: Arc::new(Inner {
49                port: actual_port,
50                stream: Mutex::new(None),
51                subscribers: Mutex::new(HashMap::new()),
52                connection_state: Mutex::new(ConnectionState::WaitingForConnection),
53                connection_state_changed: Condvar::new(),
54            }),
55        };
56
57        let inner = pubsub.inner.clone();
58        thread::Builder::new()
59            .name("elixirkit-pubsub".into())
60            .spawn(move || {
61                let Ok((tcp_stream, _)) = listener.accept() else {
62                    set_connection_state(&inner, ConnectionState::Closed);
63                    return;
64                };
65                let _ = tcp_stream.set_nodelay(true);
66
67                let reader = match tcp_stream.try_clone() {
68                    Ok(r) => r,
69                    Err(_) => {
70                        set_connection_state(&inner, ConnectionState::Closed);
71                        return;
72                    }
73                };
74                *inner.stream.lock().unwrap() = Some(tcp_stream);
75                set_connection_state(&inner, ConnectionState::Connected);
76
77                read_loop(&inner, reader);
78                *inner.stream.lock().unwrap() = None;
79                set_connection_state(&inner, ConnectionState::Closed);
80            })?;
81
82        Ok(pubsub)
83    }
84
85    // TODO: not documented, used just for testing for now.
86    #[doc(hidden)]
87    pub fn connect(url: &str) -> Result<Self, io::Error> {
88        let port = parse_url(url)?;
89        let tcp_stream = TcpStream::connect(("127.0.0.1", port))?;
90        let _ = tcp_stream.set_nodelay(true);
91
92        let reader = tcp_stream.try_clone()?;
93
94        let pubsub = PubSub {
95            inner: Arc::new(Inner {
96                port,
97                stream: Mutex::new(Some(tcp_stream)),
98                subscribers: Mutex::new(HashMap::new()),
99                connection_state: Mutex::new(ConnectionState::Connected),
100                connection_state_changed: Condvar::new(),
101            }),
102        };
103
104        let inner = pubsub.inner.clone();
105        thread::Builder::new()
106            .name("elixirkit-pubsub".into())
107            .spawn(move || {
108                read_loop(&inner, reader);
109                *inner.stream.lock().unwrap() = None;
110                set_connection_state(&inner, ConnectionState::Closed);
111            })?;
112
113        Ok(pubsub)
114    }
115
116    /// Returns the URL for this PubSub connection.
117    pub fn url(&self) -> String {
118        format!("tcp://127.0.0.1:{}", self.inner.port)
119    }
120
121    /// Subscribes to messages on the given topic from the Elixir side.
122    pub fn subscribe<F>(&self, topic: &str, callback: F)
123    where
124        F: Fn(&[u8]) + Send + 'static,
125    {
126        self.inner
127            .subscribers
128            .lock()
129            .unwrap()
130            .entry(topic.to_string())
131            .or_default()
132            .push(Box::new(callback));
133    }
134
135    /// Broadcasts a message on the given topic to the Elixir side.
136    pub fn broadcast(&self, topic: &str, message: &[u8]) -> io::Result<()> {
137        if topic.len() > 255 {
138            return Err(io::Error::new(
139                io::ErrorKind::InvalidInput,
140                "topic must be at most 255 bytes",
141            ));
142        }
143        let mut state = self.inner.connection_state.lock().unwrap();
144        while *state == ConnectionState::WaitingForConnection {
145            state = self.inner.connection_state_changed.wait(state).unwrap();
146        }
147        drop(state);
148
149        let guard = self.inner.stream.lock().unwrap();
150        match guard.as_ref() {
151            Some(stream) => write_message(stream, topic.as_bytes(), message),
152            None => Err(io::Error::new(io::ErrorKind::NotConnected, "not connected")),
153        }
154    }
155
156    // TODO: not documented, used just for testing for now.
157    #[doc(hidden)]
158    pub fn wait(&self) {
159        let mut state = self.inner.connection_state.lock().unwrap();
160        while *state != ConnectionState::Closed {
161            state = self.inner.connection_state_changed.wait(state).unwrap();
162        }
163    }
164}
165
166fn read_loop(inner: &Inner, mut reader: TcpStream) {
167    loop {
168        let mut len_buf = [0u8; 4];
169        if reader.read_exact(&mut len_buf).is_err() {
170            break;
171        }
172        let frame_len = u32::from_be_bytes(len_buf) as usize;
173
174        let mut frame = vec![0u8; frame_len];
175        if reader.read_exact(&mut frame).is_err() {
176            break;
177        }
178
179        if frame.is_empty() {
180            continue;
181        }
182        let topic_len = frame[0] as usize;
183        if frame.len() < 1 + topic_len {
184            continue;
185        }
186        let topic = &frame[1..1 + topic_len];
187        let payload = &frame[1 + topic_len..];
188
189        let topic_str = String::from_utf8_lossy(topic);
190
191        let subscribers = inner.subscribers.lock().unwrap();
192        if let Some(callbacks) = subscribers.get(topic_str.as_ref()) {
193            for cb in callbacks {
194                cb(payload);
195            }
196        }
197    }
198}
199
/// Extracts the port from a URL of the form `tcp://127.0.0.1:{port}`.
///
/// # Errors
///
/// Returns `InvalidInput` when the prefix does not match or the remainder is
/// not a valid `u16`.
fn parse_url(url: &str) -> Result<u16, io::Error> {
    const PREFIX: &str = "tcp://127.0.0.1:";

    match url.strip_prefix(PREFIX).map(str::parse::<u16>) {
        Some(Ok(port)) => Ok(port),
        // Wrong prefix and unparsable port are reported identically.
        _ => Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            format!("expected tcp://127.0.0.1:{{port}}, got: {:?}", url),
        )),
    }
}
210
211fn set_connection_state(inner: &Inner, state: ConnectionState) {
212    *inner.connection_state.lock().unwrap() = state;
213    inner.connection_state_changed.notify_all();
214}
215
/// Writes one framed message to `stream`.
///
/// Frame layout: a 4-byte big-endian length, followed by a 1-byte topic
/// length, the topic bytes, and the payload bytes. The length prefix counts
/// everything after itself.
///
/// The whole frame is assembled in memory and sent with a single `write_all`,
/// rather than as several small writes on a socket that may have
/// `TCP_NODELAY` set.
///
/// # Errors
///
/// Returns `InvalidInput` if `topic` is longer than 255 bytes (its length
/// must fit in the one-byte header field) or if the frame length does not
/// fit in a `u32`; otherwise returns any error from writing to the stream.
fn write_message(mut stream: &TcpStream, topic: &[u8], payload: &[u8]) -> io::Result<()> {
    // A checked conversion: an `as u8` cast would silently wrap and emit a
    // frame whose topic length disagrees with its contents.
    let topic_len = u8::try_from(topic.len()).map_err(|_| {
        io::Error::new(
            io::ErrorKind::InvalidInput,
            "topic must be at most 255 bytes",
        )
    })?;

    let inner_len = 1 + topic.len() + payload.len();
    let frame_len = u32::try_from(inner_len)
        .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "frame too large"))?;

    let mut frame = Vec::with_capacity(4 + inner_len);
    frame.extend_from_slice(&frame_len.to_be_bytes());
    frame.push(topic_len);
    frame.extend_from_slice(topic);
    frame.extend_from_slice(payload);

    stream.write_all(&frame)?;
    stream.flush()
}