chat_core/
coordinator.rs

1use std::collections::{HashSet};
2use std::time::Duration;
3
4use async_broadcast::Sender as BroadcastSender;
5use async_channel::Sender as ChannelSender;
6use base64::{engine::general_purpose::STANDARD, Engine as _};
7use futures::{select, FutureExt, StreamExt};
8use futures::stream;
9use futures_timer::Delay;
10use yrs::updates::encoder::Encode;
11use yrs::{Doc, MapRef, ReadTxn, Transact, UpdateEvent};
12
13use crate::commands::handlers::handle_command;
14use crate::history::types::{ChatMessage, HistoryRequest};
15use crate::messages::processor::{handle_yjs_update, YjsUpdateMeta};
16use crate::network::{NetworkNode, PeerId, RoomId};
17use crate::network::handlers::handle_network_event;
18use crate::protocol::hello::{build_hello_message, HELLO_INTERVAL_SECS};
19use crate::stats::{build_connection_stats, ConnectionStats};
20use crate::storage::SqliteStorage;
21use crate::sync::scheduler::SyncScheduler;
22
/// Commands sent from the UI (or headless API) into the coordinator.
///
/// Every command except [`CoordinatorCommand::Shutdown`] is forwarded by the
/// coordinator loop to `handle_command`; `Shutdown` terminates the loop itself.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CoordinatorCommand {
    /// Request to join the room called `room_name`.
    JoinRoom { room_name: String },
    /// Request to leave the current room.
    LeaveRoom,
    /// Request to send a chat message with the given text.
    SendMessage { content: String },
    /// Request to dial a peer at the given multiaddr string.
    DialPeer { multiaddr: String },
    /// Request to change the local nickname.
    SetNickname { nickname: String },
    /// Stop the coordinator event loop.
    Shutdown,
}
32
/// Events broadcast from the coordinator to UI subscribers.
///
/// In the code visible in this file only `StatsUpdated` and `ErrorMessage`
/// are emitted directly by the run loop; the other variants are produced by
/// handlers that live in other modules.
#[derive(Debug, Clone)]
pub enum CoordinatorEvent {
    /// Carries a list of chat messages for the UI.
    /// NOTE(review): presumably the full current list rather than a delta — confirm against the emitter.
    MessagesUpdated(Vec<ChatMessage>),
    /// NOTE(review): presumably emitted when a connection to `peer_id` is established — confirm.
    PeerConnected { peer_id: PeerId },
    /// NOTE(review): presumably emitted when the connection to `peer_id` is lost — confirm.
    PeerDisconnected { peer_id: PeerId },
    /// Connection statistics snapshot; the run loop broadcasts one on its
    /// periodic stats timer.
    StatsUpdated(Box<ConnectionStats>),
    /// Informational text intended for display.
    SystemMessage(String),
    /// Error text intended for display; the run loop emits this with the
    /// formatted error when a command handler returns `Err`.
    ErrorMessage(String),
    /// NOTE(review): presumably emitted once a join request for `room_name` has been processed — confirm.
    RoomJoined { room_name: String },
    /// NOTE(review): presumably emitted after the local nickname changes — confirm.
    NicknameChanged { nickname: String },
    /// Online/offline indicator.
    /// NOTE(review): presumably mirrors `CoordinatorState::is_online` — confirm.
    IsOnline(bool),
}
46
47/// A cloneable handle used by the UI to drive the coordinator.
48#[derive(Clone)]
49pub struct ChatCoordinatorHandle {
50    cmd_tx: ChannelSender<CoordinatorCommand>,
51}
52
53impl ChatCoordinatorHandle {
54    pub async fn join_room(&self, room_name: String) {
55        let _ = self
56            .cmd_tx
57            .send(CoordinatorCommand::JoinRoom { room_name })
58            .await;
59    }
60
61    pub async fn leave_room(&self) {
62        let _ = self.cmd_tx.send(CoordinatorCommand::LeaveRoom).await;
63    }
64
65    pub async fn send_message(&self, content: String) {
66        let _ = self
67            .cmd_tx
68            .send(CoordinatorCommand::SendMessage { content })
69            .await;
70    }
71
72    pub async fn dial_peer(&self, multiaddr: String) {
73        let _ = self
74            .cmd_tx
75            .send(CoordinatorCommand::DialPeer { multiaddr })
76            .await;
77    }
78
79    pub async fn set_nickname(&self, nickname: String) {
80        let _ = self
81            .cmd_tx
82            .send(CoordinatorCommand::SetNickname { nickname })
83            .await;
84    }
85
86    pub async fn shutdown(&self) {
87        let _ = self.cmd_tx.send(CoordinatorCommand::Shutdown).await;
88    }
89}
90
/// Mutable state shared by the command, network and Yjs-update handlers.
///
/// The run loop creates a single instance and passes it by `&mut` to each
/// handler it invokes.
pub struct CoordinatorState {
    /// Name of the current room; starts as `None`. Hello messages are only
    /// published while this is `Some`.
    pub room_name: Option<String>,
    /// Starts as `None`.
    /// NOTE(review): presumably the pub/sub topic derived from the room name — confirm.
    pub room_topic: Option<String>,
    /// Starts as `None`.
    /// NOTE(review): presumably the network-level identifier of the joined room — confirm.
    pub room_id: Option<RoomId>,
    /// Local nickname; initialised to `"anon"`.
    pub nickname: String,
    /// Sender id read from storage at startup, falling back to the local
    /// peer id string when the storage lookup fails.
    pub stable_sender_id: String,
    /// Counter initialised to 0.
    /// NOTE(review): presumably a per-sender message sequence number — confirm.
    pub seq: i64,
    /// Starts empty.
    /// NOTE(review): presumably peers with a live connection — confirm.
    pub connected_peers: HashSet<PeerId>,
    /// Starts empty.
    /// NOTE(review): presumably every peer observed so far — confirm.
    pub seen_peers: HashSet<PeerId>,
    /// Starts as `false`; hello messages are only published while `true`.
    pub is_online: bool,
    /// Identifier of the local node, taken from `NetworkNode::local_id`.
    pub local_peer_id: PeerId,
    /// Scheduler that decides when state-vector sync requests are due.
    pub sync: SyncScheduler,
    /// Peers in the current room; reported in connection stats and passed to
    /// the sync scheduler when collecting due peers.
    pub room_peers: HashSet<PeerId>,
    /// Starts empty.
    /// NOTE(review): presumably the in-memory chat history for the current room — confirm.
    pub messages: Vec<ChatMessage>,
    /// Starts empty.
    /// NOTE(review): presumably used to de-duplicate incoming messages by id — confirm.
    pub seen_message_ids: HashSet<String>,
}
107
108/// The coordinator event loop future. Call `.run().await` to drive it.
109pub struct CoordinatorLoop<N: NetworkNode> {
110    node: N,
111    storage: SqliteStorage,
112    cmd_rx: async_channel::Receiver<CoordinatorCommand>,
113    event_tx: BroadcastSender<CoordinatorEvent>,
114    doc: Doc,
115    map: MapRef,
116    yjs_rx: async_channel::Receiver<(Vec<u8>, bool, bool, bool)>,
117    hello_interval_secs: u64,
118}
119
120impl<N: NetworkNode> CoordinatorLoop<N> {
121    pub async fn run(mut self) {
122        let local_peer_id = self.node.local_id();
123
124        let stable_sender_id = self
125            .storage
126            .get_stable_sender_id()
127            .unwrap_or_else(|_| local_peer_id.as_str().to_string());
128
129        let mut state = CoordinatorState {
130            room_name: None,
131            room_topic: None,
132            room_id: None,
133            nickname: "anon".into(),
134            stable_sender_id,
135            seq: 0,
136            connected_peers: HashSet::new(),
137            seen_peers: HashSet::new(),
138            is_online: false,
139            local_peer_id: local_peer_id.clone(),
140            sync: SyncScheduler::new(),
141            room_peers: HashSet::new(),
142            messages: Vec::new(),
143            seen_message_ids: HashSet::new(),
144        };
145
146        let node = &mut self.node;
147        let storage = &self.storage;
148        let doc = &self.doc;
149        let map = &self.map;
150        let event_tx = &self.event_tx;
151        let cmd_rx = &self.cmd_rx;
152        let yjs_rx = &self.yjs_rx;
153
154        let mut stats_interval = Box::pin(stream::unfold((), |()| async {
155            Delay::new(Duration::from_secs(5)).await;
156            Some(((), ()))
157        }));
158
159        // Periodic hello messages to advertise our dialable addresses.
160        let hello_interval_secs = self.hello_interval_secs;
161        let mut hello_interval = Box::pin(stream::unfold((), move |()| async move {
162            Delay::new(Duration::from_secs(hello_interval_secs)).await;
163            Some(((), ()))
164        }));
165
166        // Trigger an initial bootstrap so client-mode peers populate their
167        // routing table quickly instead of waiting 60s.
168        if let Err(e) = node.bootstrap() {
169            tracing::debug!("Initial bootstrap failed (no known peers yet): {}", e);
170        }
171
172        loop {
173            // Compute delay until the next scheduled state-vector sync.
174            // The timer will fire exactly when the earliest sync is due.
175            let sync_delay = state.sync.next_due_delay(web_time::Instant::now());
176
177            select! {
178                cmd = cmd_rx.recv().fuse() => {
179                    match cmd {
180                        Err(_) => break,
181                        Ok(CoordinatorCommand::Shutdown) => break,
182                        Ok(cmd) => {
183                            if let Err(e) = handle_command(
184                                cmd, node, storage, doc, map, &mut state, event_tx
185                            ).await {
186                                broadcast_nonblocking(event_tx, CoordinatorEvent::ErrorMessage(format!("{}", e)));
187                            }
188                        }
189                    }
190                }
191
192                event = node.next_event().fuse() => {
193                    if let Some(ev) = event {
194                        handle_network_event(
195                            ev, node, doc, &mut state, event_tx
196                        ).await;
197                    }
198                }
199
200                yjs_msg = yjs_rx.recv().fuse() => {
201                    if let Ok((update_bytes, is_remote, is_clear, is_load)) = yjs_msg {
202                        handle_yjs_update(
203                            YjsUpdateMeta { update_bytes, is_remote, is_clear, is_load },
204                            node, storage, doc, map, &mut state, event_tx
205                        ).await;
206                    }
207                }
208
209                _ = stats_interval.next().fuse() => {
210                    let room_peer_strs: Vec<String> = state.room_peers.iter().map(|p| p.as_str().to_string()).collect();
211                    let stats = build_connection_stats(node.raw_stats(), room_peer_strs);
212                    broadcast_nonblocking(event_tx, CoordinatorEvent::StatsUpdated(Box::new(stats)));
213                }
214
215                _ = async {
216                    if let Some(d) = sync_delay {
217                        Delay::new(d).await;
218                    } else {
219                        futures::future::pending().await
220                    }
221                }.fuse() => {
222                    let now = web_time::Instant::now();
223                    let due_peers = state.sync.pop_due(now, &state.room_peers);
224                    for peer_id in due_peers {
225                        let sv = doc.transact().state_vector();
226                        let sv_bytes = sv.encode_v2();
227                        let sv_base64 = STANDARD.encode(&sv_bytes);
228                        let request = HistoryRequest::Sync {
229                            state_vector_base64: sv_base64,
230                        };
231                        if let Ok(data) = serde_json::to_vec(&request) {
232                            let _ = node.send_message(&peer_id, data);
233                        }
234                    }
235                }
236
237                _ = hello_interval.next().fuse() => {
238                    if state.is_online
239                        && state.room_name.is_some()
240                        && let Some(hello_bytes) = build_hello_message(node, &state.local_peer_id)
241                        && let Err(e) = node.publish_message(hello_bytes)
242                    {
243                        tracing::warn!("Failed to publish hello message: {}", e);
244                    }
245                }
246            }
247        }
248
249        tracing::info!("Coordinator event loop exited");
250    }
251}
252
253/// Build the coordinator and return a handle, event receiver, and the loop future.
254pub fn build<N: NetworkNode>(
255    node: N,
256    storage: SqliteStorage,
257) -> (
258    ChatCoordinatorHandle,
259    async_broadcast::Receiver<CoordinatorEvent>,
260    CoordinatorLoop<N>,
261) {
262    build_with_hello_interval(node, storage, HELLO_INTERVAL_SECS)
263}
264
265/// Build the coordinator with a configurable hello interval (useful for tests).
266pub fn build_with_hello_interval<N: NetworkNode>(
267    node: N,
268    storage: SqliteStorage,
269    hello_interval_secs: u64,
270) -> (
271    ChatCoordinatorHandle,
272    async_broadcast::Receiver<CoordinatorEvent>,
273    CoordinatorLoop<N>,
274) {
275    let (cmd_tx, cmd_rx) = async_channel::bounded(32);
276    let (event_tx, event_rx) = async_broadcast::broadcast(256);
277
278    let doc = Doc::new();
279    let map = doc.get_or_insert_map("messages");
280
281    // Bridge yrs update observer into an async channel.
282    // The subscription handle is not Send, so we create it here (synchronous)
283    // and leak it, passing only the Send receiver into the async task.
284    let (yjs_tx, yjs_rx) = async_channel::unbounded::<(Vec<u8>, bool, bool, bool)>();
285    let yjs_sub = doc.observe_update_v2(move |txn, event: &UpdateEvent| {
286        let origin = txn.origin().map(|o| o.as_ref());
287        let is_remote = origin.map(|bytes| bytes == "remote".as_bytes()).unwrap_or(false);
288        let is_clear = origin.map(|bytes| bytes == "clear".as_bytes()).unwrap_or(false);
289        let is_load = origin.map(|bytes| bytes == "load".as_bytes()).unwrap_or(false);
290        let _ = yjs_tx.try_send((event.update.clone(), is_remote, is_clear, is_load));
291    });
292    if let Ok(sub) = yjs_sub {
293        std::mem::forget(sub);
294    }
295
296    let loop_fut = CoordinatorLoop {
297        node,
298        storage,
299        cmd_rx,
300        event_tx,
301        doc,
302        map,
303        yjs_rx,
304        hello_interval_secs,
305    };
306
307    (ChatCoordinatorHandle { cmd_tx }, event_rx, loop_fut)
308}
309
310/// Broadcast an event without blocking. If the channel is full (UI is busy),
311/// log a warning and drop the event so the coordinator loop stays responsive.
312pub(crate) fn broadcast_nonblocking(
313    event_tx: &BroadcastSender<CoordinatorEvent>,
314    event: CoordinatorEvent,
315) {
316    if let Err(e) = event_tx.try_broadcast(event) {
317        tracing::warn!("Failed to broadcast event: {}", e);
318    }
319}
320
#[cfg(test)]
mod tests {
    use super::*;
    use crate::network::mock::MockNetworkNode;
    use crate::network::{HelloMessage, NetworkEvent};
    use std::time::Duration;

    /// Open an in-memory SQLite database with the schema initialised.
    fn in_memory_storage() -> SqliteStorage {
        let conn = rusqlite::Connection::open_in_memory().unwrap();
        let storage = SqliteStorage::new(conn);
        storage.init_schema().unwrap();
        storage
    }

    /// A hello broadcast carrying a circuit address should make the
    /// coordinator dial that address.
    #[test]
    fn hello_message_triggers_auto_dial() {
        let rt = tokio::runtime::Runtime::new().unwrap();

        let mut mock = MockNetworkNode::new("local-peer".to_string());
        mock.set_dial_addrs(vec![]);
        let dial_log = mock.dial_log.clone();
        let injector = mock.event_tx.clone();
        let storage = in_memory_storage();

        // Long hello interval so the periodic hello never fires during the test.
        let (handle, _events, loop_fut) = build_with_hello_interval(mock, storage, 3600);
        rt.spawn(async move {
            loop_fut.run().await;
        });

        rt.block_on(async {
            handle.join_room("test-room".to_string()).await;
            tokio::time::sleep(Duration::from_millis(100)).await;

            let hello = HelloMessage {
                peer_id: "remote-peer".to_string(),
                circuit_address: Some(
                    "/dns4/relay.example.com/tcp/443/wss/p2p/relay-peer/p2p-circuit/p2p/remote-peer"
                        .to_string(),
                ),
                web_rtc_address: None,
            };
            let hello_bytes = serde_json::to_vec(&hello).unwrap();
            let event = NetworkEvent::BroadcastReceived {
                peer_id: PeerId::from("remote-peer"),
                data: hello_bytes,
            };
            injector.send(event).await.ok();
            tokio::time::sleep(Duration::from_millis(100)).await;

            // Keep the guard in its own scope so the std mutex is released
            // before the next `.await`.
            {
                let dials = dial_log.lock().unwrap();
                assert!(
                    dials.iter().any(|a| a.as_str().contains("relay.example.com")),
                    "Expected auto-dial of circuit address from hello message, got: {:?}",
                    *dials
                );
            }

            handle.shutdown().await;
        });
    }

    /// With a one-second hello interval, the coordinator should publish a
    /// hello message advertising the node's circuit address.
    #[test]
    fn hello_interval_publishes_hello_message() {
        let rt = tokio::runtime::Runtime::new().unwrap();

        let mut mock = MockNetworkNode::new("local-peer".to_string());
        let circuit = "/dns4/relay.example.com/tcp/443/wss/p2p/relay-peer/p2p-circuit/p2p/local-peer"
            .to_string();
        mock.set_dial_addrs(vec![circuit.clone()]);
        let published = mock.published_messages.clone();
        let storage = in_memory_storage();

        let (handle, _events, loop_fut) = build_with_hello_interval(mock, storage, 1);
        rt.spawn(async move {
            loop_fut.run().await;
        });

        rt.block_on(async {
            handle.join_room("test-room".to_string()).await;
            // Wait for the hello interval to fire (1s) plus margin.
            tokio::time::sleep(Duration::from_secs(2)).await;

            // Keep the guard in its own scope so the std mutex is released
            // before the next `.await`.
            {
                let msgs = published.lock().unwrap();
                assert!(
                    !msgs.is_empty(),
                    "Expected at least one hello message published"
                );

                let hello: HelloMessage = serde_json::from_slice(&msgs[0]).unwrap();
                assert_eq!(hello.peer_id, "local-peer");
                assert_eq!(hello.circuit_address, Some(circuit));
                assert_eq!(hello.web_rtc_address, None);
            }

            handle.shutdown().await;
        });
    }
}