network_libp2p/
protocol.rs1use std::future::Future;
2
3use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
4use libp2p::request_response::Codec;
5use std::io;
6
7#[derive(Debug, Clone, Default)]
13pub struct RawBytesCodec;
14
15impl Codec for RawBytesCodec {
16 type Protocol = &'static str;
17 type Request = Vec<u8>;
18 type Response = Vec<u8>;
19
20 fn read_request<T>(
21 &mut self,
22 _protocol: &Self::Protocol,
23 io: &mut T,
24 ) -> impl Future<Output = io::Result<Self::Request>> + Send
25 where
26 T: AsyncRead + Unpin + Send,
27 {
28 read_bytes_frame(io)
29 }
30
31 fn read_response<T>(
32 &mut self,
33 _protocol: &Self::Protocol,
34 io: &mut T,
35 ) -> impl Future<Output = io::Result<Self::Response>> + Send
36 where
37 T: AsyncRead + Unpin + Send,
38 {
39 read_bytes_frame(io)
40 }
41
42 fn write_request<T>(
43 &mut self,
44 _protocol: &Self::Protocol,
45 io: &mut T,
46 req: Self::Request,
47 ) -> impl Future<Output = io::Result<()>> + Send
48 where
49 T: AsyncWrite + Unpin + Send,
50 {
51 async move {
52 write_bytes_frame(io, &req).await
53 }
54 }
55
56 fn write_response<T>(
57 &mut self,
58 _protocol: &Self::Protocol,
59 io: &mut T,
60 resp: Self::Response,
61 ) -> impl Future<Output = io::Result<()>> + Send
62 where
63 T: AsyncWrite + Unpin + Send,
64 {
65 async move {
66 write_bytes_frame(io, &resp).await
67 }
68 }
69}
70
71async fn read_bytes_frame<T>(io: &mut T) -> io::Result<Vec<u8>>
72where
73 T: AsyncRead + Unpin + Send,
74{
75 let mut length_bytes = [0u8; 4];
76 io.read_exact(&mut length_bytes).await?;
77 let length = u32::from_be_bytes(length_bytes) as usize;
78
79 if length > 16 * 1024 * 1024 {
80 return Err(io::Error::new(
81 io::ErrorKind::InvalidData,
82 "frame exceeds 16 MiB",
83 ));
84 }
85
86 let mut buf = vec![0u8; length];
87 io.read_exact(&mut buf).await?;
88 Ok(buf)
89}
90
91async fn write_bytes_frame<T>(io: &mut T, data: &[u8]) -> io::Result<()>
92where
93 T: AsyncWrite + Unpin + Send,
94{
95 let length = data.len() as u32;
96 io.write_all(&length.to_be_bytes()).await?;
97 io.write_all(data).await?;
98 io.flush().await?;
99 Ok(())
100}
101
102pub fn history_behaviour() -> libp2p::request_response::Behaviour<RawBytesCodec> {
104 libp2p::request_response::Behaviour::new(
105 [(chat_core::config::HISTORY_PROTOCOL, libp2p::request_response::ProtocolSupport::Full)],
106 libp2p::request_response::Config::default(),
107 )
108}
109
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;
    use futures::io::Cursor;

    #[test]
    fn roundtrip_raw_bytes() {
        let req = b"hello world".to_vec();
        let mut buf = Vec::new();
        let mut codec = RawBytesCodec;
        block_on(codec.write_request(&"/test", &mut buf, req.clone())).unwrap();

        let mut cursor = Cursor::new(&buf);
        let decoded: Vec<u8> = block_on(codec.read_request(&"/test", &mut cursor)).unwrap();
        assert_eq!(decoded, req);
    }

    #[test]
    fn roundtrip_response_including_empty_payload() {
        let mut codec = RawBytesCodec;
        for payload in [Vec::new(), b"history".to_vec()] {
            let mut buf = Vec::new();
            block_on(codec.write_response(&"/test", &mut buf, payload.clone())).unwrap();

            let mut cursor = Cursor::new(&buf);
            let decoded: Vec<u8> =
                block_on(codec.read_response(&"/test", &mut cursor)).unwrap();
            assert_eq!(decoded, payload);
        }
    }

    #[test]
    fn frame_includes_big_endian_length_prefix() {
        let mut buf = Vec::new();
        let mut codec = RawBytesCodec;
        block_on(codec.write_request(&"/test", &mut buf, b"abc".to_vec())).unwrap();

        // Exact wire image: 4-byte big-endian length, then the payload.
        assert_eq!(buf, [0, 0, 0, 3, b'a', b'b', b'c']);
    }

    #[test]
    fn rejects_oversized_frame() {
        let mut codec = RawBytesCodec;
        let just_over_limit: u32 = 16 * 1024 * 1024 + 1;
        for header in [[0xff; 4], just_over_limit.to_be_bytes()] {
            let mut cursor = Cursor::new(header);
            let err = block_on(codec.read_request(&"/test", &mut cursor)).unwrap_err();
            // Must be the size check, not an EOF from the missing payload.
            assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        }
    }

    #[test]
    fn rejects_truncated_header() {
        let mut codec = RawBytesCodec;
        let mut cursor = Cursor::new(vec![0u8, 0]);
        let err = block_on(codec.read_request(&"/test", &mut cursor)).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    #[test]
    fn rejects_truncated_payload() {
        // Header announces 5 bytes but only 2 follow.
        let mut wire = 5u32.to_be_bytes().to_vec();
        wire.extend_from_slice(b"hi");
        let mut cursor = Cursor::new(wire);
        let mut codec = RawBytesCodec;
        let err = block_on(codec.read_request(&"/test", &mut cursor)).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }
}