1use std::collections::{HashSet};
2use std::time::Duration;
3
4use async_broadcast::Sender as BroadcastSender;
5use async_channel::Sender as ChannelSender;
6use base64::{engine::general_purpose::STANDARD, Engine as _};
7use futures::{select, FutureExt, StreamExt};
8use futures::stream;
9use futures_timer::Delay;
10use yrs::updates::encoder::Encode;
11use yrs::{Doc, MapRef, ReadTxn, Transact, UpdateEvent};
12
13use crate::commands::handlers::handle_command;
14use crate::history::types::{ChatMessage, HistoryRequest};
15use crate::messages::processor::{handle_yjs_update, YjsUpdateMeta};
16use crate::network::{NetworkNode, PeerId, RoomId};
17use crate::network::handlers::handle_network_event;
18use crate::protocol::hello::{build_hello_message, HELLO_INTERVAL_SECS};
19use crate::stats::{build_connection_stats, ConnectionStats};
20use crate::storage::SqliteStorage;
21use crate::sync::scheduler::SyncScheduler;
22
/// Commands accepted by the coordinator event loop.
///
/// Values are created by the [`ChatCoordinatorHandle`] methods and sent over the
/// command channel. The loop handles `Shutdown` itself and forwards every other
/// variant to `handle_command`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CoordinatorCommand {
    /// Request to join the room named `room_name`.
    JoinRoom { room_name: String },
    /// Request to leave the current room.
    LeaveRoom,
    /// Request to send a chat message with the given text.
    SendMessage { content: String },
    /// Request to dial a peer at the given multiaddress string.
    DialPeer { multiaddr: String },
    /// Request to change the local nickname.
    SetNickname { nickname: String },
    /// Stops the coordinator event loop.
    Shutdown,
}
32
/// Events published on the coordinator's broadcast channel.
///
/// Within this file only `StatsUpdated` and `ErrorMessage` are emitted directly by
/// the event loop; the remaining variants are presumably emitted by the command,
/// network and Yjs handlers — confirm against those modules.
#[derive(Debug, Clone)]
pub enum CoordinatorEvent {
    /// Carries a list of chat messages.
    MessagesUpdated(Vec<ChatMessage>),
    /// Carries the id of a peer that connected.
    PeerConnected { peer_id: PeerId },
    /// Carries the id of a peer that disconnected.
    PeerDisconnected { peer_id: PeerId },
    /// Connection statistics, emitted by the event loop's periodic stats timer.
    /// Boxed so the payload is stored behind a pointer.
    StatsUpdated(Box<ConnectionStats>),
    /// Free-form system text.
    SystemMessage(String),
    /// Error text; the event loop emits this with the formatted error when
    /// `handle_command` fails.
    ErrorMessage(String),
    /// Carries the name of a room that was joined.
    RoomJoined { room_name: String },
    /// Carries the new nickname.
    NicknameChanged { nickname: String },
    /// Carries the current online flag.
    IsOnline(bool),
}
46
/// Cloneable handle used to send [`CoordinatorCommand`]s to the coordinator loop.
#[derive(Clone)]
pub struct ChatCoordinatorHandle {
    /// Sending half of the command channel whose receiving half is owned by
    /// [`CoordinatorLoop`].
    cmd_tx: ChannelSender<CoordinatorCommand>,
}
52
53impl ChatCoordinatorHandle {
54 pub async fn join_room(&self, room_name: String) {
55 let _ = self
56 .cmd_tx
57 .send(CoordinatorCommand::JoinRoom { room_name })
58 .await;
59 }
60
61 pub async fn leave_room(&self) {
62 let _ = self.cmd_tx.send(CoordinatorCommand::LeaveRoom).await;
63 }
64
65 pub async fn send_message(&self, content: String) {
66 let _ = self
67 .cmd_tx
68 .send(CoordinatorCommand::SendMessage { content })
69 .await;
70 }
71
72 pub async fn dial_peer(&self, multiaddr: String) {
73 let _ = self
74 .cmd_tx
75 .send(CoordinatorCommand::DialPeer { multiaddr })
76 .await;
77 }
78
79 pub async fn set_nickname(&self, nickname: String) {
80 let _ = self
81 .cmd_tx
82 .send(CoordinatorCommand::SetNickname { nickname })
83 .await;
84 }
85
86 pub async fn shutdown(&self) {
87 let _ = self.cmd_tx.send(CoordinatorCommand::Shutdown).await;
88 }
89}
90
/// Mutable state created by `CoordinatorLoop::run` and passed by `&mut` to the
/// command, network-event and Yjs-update handlers.
pub struct CoordinatorState {
    /// Starts as `None`; the hello timer only publishes while this is `Some`.
    pub room_name: Option<String>,
    /// Starts as `None`.
    pub room_topic: Option<String>,
    /// Starts as `None`.
    pub room_id: Option<RoomId>,
    /// Starts as `"anon"`.
    pub nickname: String,
    /// Read from storage at startup; falls back to the local peer id string when
    /// storage returns an error.
    pub stable_sender_id: String,
    /// Starts at 0. NOTE(review): presumably a message sequence counter — confirm.
    pub seq: i64,
    /// Starts empty.
    pub connected_peers: HashSet<PeerId>,
    /// Starts empty.
    pub seen_peers: HashSet<PeerId>,
    /// Starts as `false`; the hello timer only publishes while this is `true`.
    pub is_online: bool,
    /// Id reported by the network node at startup; passed to `build_hello_message`.
    pub local_peer_id: PeerId,
    /// Scheduler queried for the next sync deadline and for the peers that are due.
    pub sync: SyncScheduler,
    /// Starts empty; passed to the scheduler when popping due peers and reported in
    /// connection stats.
    pub room_peers: HashSet<PeerId>,
    /// Starts empty.
    pub messages: Vec<ChatMessage>,
    /// Starts empty.
    pub seen_message_ids: HashSet<String>,
}
107
/// Owns everything the coordinator event loop needs; consumed by `run`.
pub struct CoordinatorLoop<N: NetworkNode> {
    /// Network node that is bootstrapped, polled for events and used to send and
    /// publish messages.
    node: N,
    /// Storage handle passed to the command and Yjs-update handlers.
    storage: SqliteStorage,
    /// Receiving half of the command channel fed by `ChatCoordinatorHandle`.
    cmd_rx: async_channel::Receiver<CoordinatorCommand>,
    /// Broadcast sender used to publish `CoordinatorEvent`s.
    event_tx: BroadcastSender<CoordinatorEvent>,
    /// Yjs document.
    doc: Doc,
    /// The `"messages"` map obtained from `doc`.
    map: MapRef,
    /// Updates forwarded by the document's update observer, as
    /// `(update_bytes, is_remote, is_clear, is_load)`.
    yjs_rx: async_channel::Receiver<(Vec<u8>, bool, bool, bool)>,
    /// Delay in seconds between hello-timer ticks.
    hello_interval_secs: u64,
}
119
120impl<N: NetworkNode> CoordinatorLoop<N> {
121 pub async fn run(mut self) {
122 let local_peer_id = self.node.local_id();
123
124 let stable_sender_id = self
125 .storage
126 .get_stable_sender_id()
127 .unwrap_or_else(|_| local_peer_id.as_str().to_string());
128
129 let mut state = CoordinatorState {
130 room_name: None,
131 room_topic: None,
132 room_id: None,
133 nickname: "anon".into(),
134 stable_sender_id,
135 seq: 0,
136 connected_peers: HashSet::new(),
137 seen_peers: HashSet::new(),
138 is_online: false,
139 local_peer_id: local_peer_id.clone(),
140 sync: SyncScheduler::new(),
141 room_peers: HashSet::new(),
142 messages: Vec::new(),
143 seen_message_ids: HashSet::new(),
144 };
145
146 let node = &mut self.node;
147 let storage = &self.storage;
148 let doc = &self.doc;
149 let map = &self.map;
150 let event_tx = &self.event_tx;
151 let cmd_rx = &self.cmd_rx;
152 let yjs_rx = &self.yjs_rx;
153
154 let mut stats_interval = Box::pin(stream::unfold((), |()| async {
155 Delay::new(Duration::from_secs(5)).await;
156 Some(((), ()))
157 }));
158
159 let hello_interval_secs = self.hello_interval_secs;
161 let mut hello_interval = Box::pin(stream::unfold((), move |()| async move {
162 Delay::new(Duration::from_secs(hello_interval_secs)).await;
163 Some(((), ()))
164 }));
165
166 if let Err(e) = node.bootstrap() {
169 tracing::debug!("Initial bootstrap failed (no known peers yet): {}", e);
170 }
171
172 loop {
173 let sync_delay = state.sync.next_due_delay(web_time::Instant::now());
176
177 select! {
178 cmd = cmd_rx.recv().fuse() => {
179 match cmd {
180 Err(_) => break,
181 Ok(CoordinatorCommand::Shutdown) => break,
182 Ok(cmd) => {
183 if let Err(e) = handle_command(
184 cmd, node, storage, doc, map, &mut state, event_tx
185 ).await {
186 broadcast_nonblocking(event_tx, CoordinatorEvent::ErrorMessage(format!("{}", e)));
187 }
188 }
189 }
190 }
191
192 event = node.next_event().fuse() => {
193 if let Some(ev) = event {
194 handle_network_event(
195 ev, node, doc, &mut state, event_tx
196 ).await;
197 }
198 }
199
200 yjs_msg = yjs_rx.recv().fuse() => {
201 if let Ok((update_bytes, is_remote, is_clear, is_load)) = yjs_msg {
202 handle_yjs_update(
203 YjsUpdateMeta { update_bytes, is_remote, is_clear, is_load },
204 node, storage, doc, map, &mut state, event_tx
205 ).await;
206 }
207 }
208
209 _ = stats_interval.next().fuse() => {
210 let room_peer_strs: Vec<String> = state.room_peers.iter().map(|p| p.as_str().to_string()).collect();
211 let stats = build_connection_stats(node.raw_stats(), room_peer_strs);
212 broadcast_nonblocking(event_tx, CoordinatorEvent::StatsUpdated(Box::new(stats)));
213 }
214
215 _ = async {
216 if let Some(d) = sync_delay {
217 Delay::new(d).await;
218 } else {
219 futures::future::pending().await
220 }
221 }.fuse() => {
222 let now = web_time::Instant::now();
223 let due_peers = state.sync.pop_due(now, &state.room_peers);
224 for peer_id in due_peers {
225 let sv = doc.transact().state_vector();
226 let sv_bytes = sv.encode_v2();
227 let sv_base64 = STANDARD.encode(&sv_bytes);
228 let request = HistoryRequest::Sync {
229 state_vector_base64: sv_base64,
230 };
231 if let Ok(data) = serde_json::to_vec(&request) {
232 let _ = node.send_message(&peer_id, data);
233 }
234 }
235 }
236
237 _ = hello_interval.next().fuse() => {
238 if state.is_online
239 && state.room_name.is_some()
240 && let Some(hello_bytes) = build_hello_message(node, &state.local_peer_id)
241 && let Err(e) = node.publish_message(hello_bytes)
242 {
243 tracing::warn!("Failed to publish hello message: {}", e);
244 }
245 }
246 }
247 }
248
249 tracing::info!("Coordinator event loop exited");
250 }
251}
252
253pub fn build<N: NetworkNode>(
255 node: N,
256 storage: SqliteStorage,
257) -> (
258 ChatCoordinatorHandle,
259 async_broadcast::Receiver<CoordinatorEvent>,
260 CoordinatorLoop<N>,
261) {
262 build_with_hello_interval(node, storage, HELLO_INTERVAL_SECS)
263}
264
265pub fn build_with_hello_interval<N: NetworkNode>(
267 node: N,
268 storage: SqliteStorage,
269 hello_interval_secs: u64,
270) -> (
271 ChatCoordinatorHandle,
272 async_broadcast::Receiver<CoordinatorEvent>,
273 CoordinatorLoop<N>,
274) {
275 let (cmd_tx, cmd_rx) = async_channel::bounded(32);
276 let (event_tx, event_rx) = async_broadcast::broadcast(256);
277
278 let doc = Doc::new();
279 let map = doc.get_or_insert_map("messages");
280
281 let (yjs_tx, yjs_rx) = async_channel::unbounded::<(Vec<u8>, bool, bool, bool)>();
285 let yjs_sub = doc.observe_update_v2(move |txn, event: &UpdateEvent| {
286 let origin = txn.origin().map(|o| o.as_ref());
287 let is_remote = origin.map(|bytes| bytes == "remote".as_bytes()).unwrap_or(false);
288 let is_clear = origin.map(|bytes| bytes == "clear".as_bytes()).unwrap_or(false);
289 let is_load = origin.map(|bytes| bytes == "load".as_bytes()).unwrap_or(false);
290 let _ = yjs_tx.try_send((event.update.clone(), is_remote, is_clear, is_load));
291 });
292 if let Ok(sub) = yjs_sub {
293 std::mem::forget(sub);
294 }
295
296 let loop_fut = CoordinatorLoop {
297 node,
298 storage,
299 cmd_rx,
300 event_tx,
301 doc,
302 map,
303 yjs_rx,
304 hello_interval_secs,
305 };
306
307 (ChatCoordinatorHandle { cmd_tx }, event_rx, loop_fut)
308}
309
310pub(crate) fn broadcast_nonblocking(
313 event_tx: &BroadcastSender<CoordinatorEvent>,
314 event: CoordinatorEvent,
315) {
316 if let Err(e) = event_tx.try_broadcast(event) {
317 tracing::warn!("Failed to broadcast event: {}", e);
318 }
319}
320
#[cfg(test)]
mod tests {
    use super::*;
    use crate::network::mock::MockNetworkNode;
    use crate::network::{HelloMessage, NetworkEvent};
    use std::time::Duration;

    /// Creates a storage backed by an in-memory SQLite database with the schema
    /// initialised.
    fn in_memory_storage() -> SqliteStorage {
        let conn = rusqlite::Connection::open_in_memory().unwrap();
        let storage = SqliteStorage::new(conn);
        storage.init_schema().unwrap();
        storage
    }

    /// A hello broadcast from a remote peer must make the coordinator dial the
    /// advertised circuit address.
    #[test]
    fn hello_message_triggers_auto_dial() {
        let rt = tokio::runtime::Runtime::new().unwrap();

        let mut mock = MockNetworkNode::new("local-peer".to_string());
        mock.set_dial_addrs(vec![]);
        let dial_log = mock.dial_log.clone();
        let injector = mock.event_tx.clone();
        let storage = in_memory_storage();

        // A one-hour hello interval keeps the hello timer out of this test.
        let (handle, _events, loop_fut) = build_with_hello_interval(mock, storage, 3600);
        rt.spawn(async move {
            loop_fut.run().await;
        });

        rt.block_on(async {
            handle.join_room("test-room".to_string()).await;
            tokio::time::sleep(Duration::from_millis(100)).await;

            let hello = HelloMessage {
                peer_id: "remote-peer".to_string(),
                circuit_address: Some(
                    "/dns4/relay.example.com/tcp/443/wss/p2p/relay-peer/p2p-circuit/p2p/remote-peer"
                        .to_string(),
                ),
                web_rtc_address: None,
            };
            let hello_bytes = serde_json::to_vec(&hello).unwrap();
            let event = NetworkEvent::BroadcastReceived {
                peer_id: PeerId::from("remote-peer"),
                data: hello_bytes,
            };
            injector.send(event).await.ok();
            tokio::time::sleep(Duration::from_millis(100)).await;

            {
                // Scoped so the std mutex guard is released before the next await.
                let dials = dial_log.lock().unwrap();
                assert!(
                    dials.iter().any(|a| a.as_str().contains("relay.example.com")),
                    "Expected auto-dial of circuit address from hello message, got: {:?}",
                    *dials
                );
            }

            handle.shutdown().await;
        });
    }

    /// With a one-second hello interval, at least one hello message describing the
    /// local peer must have been published after two seconds.
    #[test]
    fn hello_interval_publishes_hello_message() {
        let rt = tokio::runtime::Runtime::new().unwrap();

        let mut mock = MockNetworkNode::new("local-peer".to_string());
        let circuit = "/dns4/relay.example.com/tcp/443/wss/p2p/relay-peer/p2p-circuit/p2p/local-peer"
            .to_string();
        mock.set_dial_addrs(vec![circuit.clone()]);
        let published = mock.published_messages.clone();
        let storage = in_memory_storage();

        let (handle, _events, loop_fut) = build_with_hello_interval(mock, storage, 1);
        rt.spawn(async move {
            loop_fut.run().await;
        });

        rt.block_on(async {
            handle.join_room("test-room".to_string()).await;
            tokio::time::sleep(Duration::from_secs(2)).await;

            {
                // Scoped so the std mutex guard is released before the next await.
                let msgs = published.lock().unwrap();
                assert!(
                    !msgs.is_empty(),
                    "Expected at least one hello message published"
                );

                let hello: HelloMessage = serde_json::from_slice(&msgs[0]).unwrap();
                assert_eq!(hello.peer_id, "local-peer");
                assert_eq!(hello.circuit_address, Some(circuit));
                assert_eq!(hello.web_rtc_address, None);
            }

            handle.shutdown().await;
        });
    }
}