1use std::collections::HashMap;
2use std::io::{self, Read, Write};
3use std::net::{TcpListener, TcpStream};
4use std::sync::{Arc, Condvar, Mutex};
5use std::thread;
6
/// Subscriber handler: receives the raw payload bytes of each message delivered on the
/// topic it was registered for. `Send` because it is invoked from the background
/// reader thread.
type Callback = Box<dyn Fn(&[u8]) + Send + 'static>;
8
/// Lifecycle of the single TCP connection behind a [`PubSub`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ConnectionState {
    /// A listener is bound but no peer has been accepted yet.
    WaitingForConnection,
    /// A peer is connected and frames can be exchanged.
    Connected,
    /// The connection ended, or could not be established; no more frames will flow.
    Closed,
}
15
16struct Inner {
17 port: u16,
18 stream: Mutex<Option<TcpStream>>,
19 subscribers: Mutex<HashMap<String, Vec<Callback>>>,
20 connection_state: Mutex<ConnectionState>,
21 connection_state_changed: Condvar,
22}
23
24#[derive(Clone)]
26pub struct PubSub {
27 inner: Arc<Inner>,
28}
29
30impl PubSub {
31 pub fn listen(url: &str) -> Result<Self, io::Error> {
43 let port = parse_url(url)?;
44 let listener = TcpListener::bind(("127.0.0.1", port))?;
45 let actual_port = listener.local_addr()?.port();
46
47 let pubsub = PubSub {
48 inner: Arc::new(Inner {
49 port: actual_port,
50 stream: Mutex::new(None),
51 subscribers: Mutex::new(HashMap::new()),
52 connection_state: Mutex::new(ConnectionState::WaitingForConnection),
53 connection_state_changed: Condvar::new(),
54 }),
55 };
56
57 let inner = pubsub.inner.clone();
58 thread::Builder::new()
59 .name("elixirkit-pubsub".into())
60 .spawn(move || {
61 let Ok((tcp_stream, _)) = listener.accept() else {
62 set_connection_state(&inner, ConnectionState::Closed);
63 return;
64 };
65 let _ = tcp_stream.set_nodelay(true);
66
67 let reader = match tcp_stream.try_clone() {
68 Ok(r) => r,
69 Err(_) => {
70 set_connection_state(&inner, ConnectionState::Closed);
71 return;
72 }
73 };
74 *inner.stream.lock().unwrap() = Some(tcp_stream);
75 set_connection_state(&inner, ConnectionState::Connected);
76
77 read_loop(&inner, reader);
78 *inner.stream.lock().unwrap() = None;
79 set_connection_state(&inner, ConnectionState::Closed);
80 })?;
81
82 Ok(pubsub)
83 }
84
85 #[doc(hidden)]
87 pub fn connect(url: &str) -> Result<Self, io::Error> {
88 let port = parse_url(url)?;
89 let tcp_stream = TcpStream::connect(("127.0.0.1", port))?;
90 let _ = tcp_stream.set_nodelay(true);
91
92 let reader = tcp_stream.try_clone()?;
93
94 let pubsub = PubSub {
95 inner: Arc::new(Inner {
96 port,
97 stream: Mutex::new(Some(tcp_stream)),
98 subscribers: Mutex::new(HashMap::new()),
99 connection_state: Mutex::new(ConnectionState::Connected),
100 connection_state_changed: Condvar::new(),
101 }),
102 };
103
104 let inner = pubsub.inner.clone();
105 thread::Builder::new()
106 .name("elixirkit-pubsub".into())
107 .spawn(move || {
108 read_loop(&inner, reader);
109 *inner.stream.lock().unwrap() = None;
110 set_connection_state(&inner, ConnectionState::Closed);
111 })?;
112
113 Ok(pubsub)
114 }
115
116 pub fn url(&self) -> String {
118 format!("tcp://127.0.0.1:{}", self.inner.port)
119 }
120
121 pub fn subscribe<F>(&self, topic: &str, callback: F)
123 where
124 F: Fn(&[u8]) + Send + 'static,
125 {
126 self.inner
127 .subscribers
128 .lock()
129 .unwrap()
130 .entry(topic.to_string())
131 .or_default()
132 .push(Box::new(callback));
133 }
134
135 pub fn broadcast(&self, topic: &str, message: &[u8]) -> io::Result<()> {
137 if topic.len() > 255 {
138 return Err(io::Error::new(
139 io::ErrorKind::InvalidInput,
140 "topic must be at most 255 bytes",
141 ));
142 }
143 let mut state = self.inner.connection_state.lock().unwrap();
144 while *state == ConnectionState::WaitingForConnection {
145 state = self.inner.connection_state_changed.wait(state).unwrap();
146 }
147 drop(state);
148
149 let guard = self.inner.stream.lock().unwrap();
150 match guard.as_ref() {
151 Some(stream) => write_message(stream, topic.as_bytes(), message),
152 None => Err(io::Error::new(io::ErrorKind::NotConnected, "not connected")),
153 }
154 }
155
156 #[doc(hidden)]
158 pub fn wait(&self) {
159 let mut state = self.inner.connection_state.lock().unwrap();
160 while *state != ConnectionState::Closed {
161 state = self.inner.connection_state_changed.wait(state).unwrap();
162 }
163 }
164}
165
166fn read_loop(inner: &Inner, mut reader: TcpStream) {
167 loop {
168 let mut len_buf = [0u8; 4];
169 if reader.read_exact(&mut len_buf).is_err() {
170 break;
171 }
172 let frame_len = u32::from_be_bytes(len_buf) as usize;
173
174 let mut frame = vec![0u8; frame_len];
175 if reader.read_exact(&mut frame).is_err() {
176 break;
177 }
178
179 if frame.is_empty() {
180 continue;
181 }
182 let topic_len = frame[0] as usize;
183 if frame.len() < 1 + topic_len {
184 continue;
185 }
186 let topic = &frame[1..1 + topic_len];
187 let payload = &frame[1 + topic_len..];
188
189 let topic_str = String::from_utf8_lossy(topic);
190
191 let subscribers = inner.subscribers.lock().unwrap();
192 if let Some(callbacks) = subscribers.get(topic_str.as_ref()) {
193 for cb in callbacks {
194 cb(payload);
195 }
196 }
197 }
198}
199
/// Extracts the port number from a URL of the form `tcp://127.0.0.1:{port}`.
///
/// # Errors
///
/// Returns `InvalidInput` unless `url` is exactly the `tcp://127.0.0.1:` prefix followed
/// by a valid `u16`.
fn parse_url(url: &str) -> Result<u16, io::Error> {
    const LOOPBACK_PREFIX: &str = "tcp://127.0.0.1:";

    match url.strip_prefix(LOOPBACK_PREFIX).map(str::parse::<u16>) {
        Some(Ok(port)) => Ok(port),
        Some(Err(_)) | None => Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            format!("expected tcp://127.0.0.1:{{port}}, got: {:?}", url),
        )),
    }
}
210
211fn set_connection_state(inner: &Inner, state: ConnectionState) {
212 *inner.connection_state.lock().unwrap() = state;
213 inner.connection_state_changed.notify_all();
214}
215
/// Writes one frame to `stream` and flushes it.
///
/// Frame layout: big-endian `u32` length of the rest of the frame, a one-byte topic
/// length, the topic bytes, then the payload.
///
/// The length prefix, topic length and topic are assembled into one small buffer, so a
/// message costs two writes instead of four. On a socket with `TCP_NODELAY` set, each
/// write may otherwise go out as its own segment. The payload is written directly so it
/// is never copied.
///
/// # Errors
///
/// Returns `InvalidInput` if `topic` is longer than 255 bytes or the frame length does not
/// fit in a `u32`. Both checks run before anything is written. Any error from writing or
/// flushing `stream` is also returned.
fn write_message<W: Write>(mut stream: W, topic: &[u8], payload: &[u8]) -> io::Result<()> {
    // A bare `as u8` would silently truncate and desynchronize the receiver's framing.
    let topic_len = u8::try_from(topic.len()).map_err(|_| {
        io::Error::new(io::ErrorKind::InvalidInput, "topic must be at most 255 bytes")
    })?;
    let inner_len = 1 + topic.len() + payload.len();
    let frame_len = u32::try_from(inner_len)
        .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "frame too large"))?;

    let mut header = Vec::with_capacity(4 + 1 + topic.len());
    header.extend_from_slice(&frame_len.to_be_bytes());
    header.push(topic_len);
    header.extend_from_slice(topic);

    stream.write_all(&header)?;
    stream.write_all(payload)?;
    stream.flush()
}