scuffle_rtmp/chunk/reader.rs

1//! Types and functions for reading RTMP chunks.
2
3use std::cmp::min;
4use std::collections::HashMap;
5use std::io::{self, Cursor, Seek, SeekFrom};
6
7use byteorder::{BigEndian, LittleEndian, ReadBytesExt};
8use bytes::BytesMut;
9use num_traits::FromPrimitive;
10
11use super::error::ChunkReadError;
12use super::{Chunk, ChunkBasicHeader, ChunkMessageHeader, ChunkType, INIT_CHUNK_SIZE, MAX_CHUNK_SIZE};
13use crate::messages::MessageType;
14
// Limits on the state kept while reassembling chunks. A well-behaved peer
// should never come close to them; they only exist to cap how much memory a
// malicious chunk stream can make us hold on to.

/// Largest message length we accept and the largest amount of data buffered
/// for a single partially received message (10 MiB).
const MAX_PARTIAL_CHUNK_SIZE: usize = 10 << 20;

/// Maximum number of distinct chunk streams whose last message header is remembered.
const MAX_PREVIOUS_CHUNK_HEADERS: usize = 100;

/// Maximum number of messages that may be partially received at the same time.
const MAX_PARTIAL_CHUNK_COUNT: usize = 4;
21
22/// A chunk reader.
23///
24/// This is used to read chunks from a stream.
25pub struct ChunkReader {
26    /// According to the spec chunk streams are identified by the chunk stream
27    /// ID. In this case that is our key.
28    /// We then have a chunk header (since some chunks refer to the previous
29    /// chunk header)
30    previous_chunk_headers: HashMap<u32, ChunkMessageHeader>,
31
32    /// Technically according to the spec, we can have multiple message streams
33    /// in a single chunk stream. Because of this the key of this map is a tuple
34    /// (chunk stream id, message stream id).
35    partial_chunks: HashMap<(u32, u32), BytesMut>,
36
37    /// This is the max chunk size that the client has specified.
38    /// By default this is 128 bytes.
39    max_chunk_size: usize,
40}
41
42impl Default for ChunkReader {
43    fn default() -> Self {
44        Self {
45            previous_chunk_headers: HashMap::with_capacity(MAX_PREVIOUS_CHUNK_HEADERS),
46            partial_chunks: HashMap::with_capacity(MAX_PARTIAL_CHUNK_COUNT),
47            max_chunk_size: INIT_CHUNK_SIZE,
48        }
49    }
50}
51
52impl ChunkReader {
53    /// Call when a client requests a chunk size change.
54    ///
55    /// Returns `false` if the chunk size is out of bounds.
56    /// The connection should be closed in this case.
57    pub fn update_max_chunk_size(&mut self, chunk_size: usize) -> bool {
58        // We need to make sure that the chunk size is within the allowed range.
59        // Returning false here should close the connection.
60        if !(INIT_CHUNK_SIZE..=MAX_CHUNK_SIZE).contains(&chunk_size) {
61            false
62        } else {
63            self.max_chunk_size = chunk_size;
64            true
65        }
66    }
67
68    /// This function is used to read a chunk from the buffer.
69    ///
70    /// Returns:
71    /// - `Ok(None)` if the buffer does not contain enough data to read a full chunk.
72    /// - `Ok(Some(Chunk))` if a full chunk is read.
73    /// - `Err(ChunkReadError)` if there is an error decoding a chunk. The connection should be closed.
74    ///
75    /// # See also
76    ///
77    /// - [`Chunk`]
78    /// - [`ChunkReadError`]
79    pub fn read_chunk(&mut self, buffer: &mut BytesMut) -> Result<Option<Chunk>, crate::error::RtmpError> {
80        // We do this in a loop because we may have multiple chunks in the buffer,
81        // And those chunks may be partial chunks thus we need to keep reading until we
82        // have a full chunk or we run out of data.
83        loop {
84            // The cursor is an advanced cursor that is a reference to the buffer.
85            // This means the cursor does not advance the reader's position.
86            // Thus allowing us to backtrack if we need to read more data.
87            let mut cursor = std::io::Cursor::new(buffer.as_ref());
88
89            let header = match self.read_header(&mut cursor) {
90                Ok(header) => header,
91                Err(None) => {
92                    // Returning none here means that the buffer is empty and we need to wait for
93                    // more data.
94                    return Ok(None);
95                }
96                Err(Some(err)) => {
97                    // This is an error that we can't recover from, so we return it.
98                    // The connection will be closed.
99                    return Err(crate::error::RtmpError::Io(err));
100                }
101            };
102
103            let message_header = match self.read_message_header(&header, &mut cursor) {
104                Ok(message_header) => message_header,
105                Err(None) => {
106                    // Returning none here means that the buffer is empty and we need to wait for
107                    // more data.
108                    return Ok(None);
109                }
110                Err(Some(err)) => {
111                    // This is an error that we can't recover from, so we return it.
112                    // The connection will be closed.
113                    return Err(err);
114                }
115            };
116
117            let (payload_range_start, payload_range_end) =
118                match self.get_payload_range(&header, &message_header, &mut cursor) {
119                    Ok(data) => data,
120                    Err(None) => {
121                        // Returning none here means that the buffer is empty and we need to wait
122                        // for more data.
123                        return Ok(None);
124                    }
125                    Err(Some(err)) => {
126                        // This is an error that we can't recover from, so we return it.
127                        // The connection will be closed.
128                        return Err(err);
129                    }
130                };
131
132            // Since we were reading from an advanced cursor, our reads did not actually
133            // advance the reader's position. We need to manually advance the reader's
134            // position to the cursor's position.
135            let position = cursor.position() as usize;
136            if position > buffer.len() {
137                // In some cases we dont have enough data yet to read the chunk.
138                // We return Ok(None) here and the loop will continue.
139                return Ok(None);
140            }
141
142            let data = buffer.split_to(position);
143
144            // We freeze the chunk data and slice it to get the payload.
145            // Data before the slice is the header data, and data after the slice is the
146            // next chunk We don't need to keep the header data, because we already decoded
147            // it into struct form. The payload_range_end should be the same as the cursor's
148            // position.
149            let payload = data.freeze().slice(payload_range_start..payload_range_end);
150
151            // We need to check here if the chunk header is already stored in our map.
152            // This isnt a spec check but it is a check to make sure that we dont have too
153            // many previous chunk headers stored in memory.
154            let count = if self.previous_chunk_headers.contains_key(&header.chunk_stream_id) {
155                self.previous_chunk_headers.len()
156            } else {
157                self.previous_chunk_headers.len() + 1
158            };
159
160            // If this is hit, then we have too many previous chunk headers stored in
161            // memory. And the client is probably trying to DoS us.
162            // We return an error and the connection will be closed.
163            if count > MAX_PREVIOUS_CHUNK_HEADERS {
164                return Err(crate::error::RtmpError::ChunkRead(
165                    ChunkReadError::TooManyPreviousChunkHeaders,
166                ));
167            }
168
169            // We insert the chunk header into our map.
170            self.previous_chunk_headers
171                .insert(header.chunk_stream_id, message_header.clone());
172
173            // It is possible in theory to get a chunk message that requires us to change
174            // the max chunk size. However the size of that message is smaller than the
175            // default max chunk size. Therefore we can ignore this case.
176            // Since if we get such a message we will read it and the payload.len() will be
177            // equal to the message length. and thus we will return the chunk.
178
179            // Check if the payload is the same as the message length.
180            // If this is true we have a full chunk and we can return it.
181            if payload.len() == message_header.msg_length as usize {
182                return Ok(Some(Chunk {
183                    basic_header: header,
184                    message_header,
185                    payload,
186                }));
187            } else {
188                // Otherwise we generate a key using the chunk stream id and the message stream
189                // id. We then get the partial chunk from the map using the key.
190                let key = (header.chunk_stream_id, message_header.msg_stream_id);
191                let partial_chunk = match self.partial_chunks.get_mut(&key) {
192                    Some(partial_chunk) => partial_chunk,
193                    None => {
194                        // If it does not exists we create a new one.
195                        // If we have too many partial chunks we return an error.
196                        // Since the client is probably trying to DoS us.
197                        // The connection will be closed.
198                        if self.partial_chunks.len() >= MAX_PARTIAL_CHUNK_COUNT {
199                            return Err(crate::error::RtmpError::ChunkRead(ChunkReadError::TooManyPartialChunks));
200                        }
201
202                        // Insert a new empty BytesMut into the map.
203                        self.partial_chunks.insert(key, BytesMut::new());
204                        // Get the partial chunk we just inserted.
205                        self.partial_chunks.get_mut(&key).expect("we just inserted it")
206                    }
207                };
208
209                // We extend the partial chunk with the payload.
210                // And get the new length of the partial chunk.
211                let length = {
212                    // If the length of a single chunk is larger than the max partial chunk size
213                    // we return an error. The client is probably trying to DoS us.
214                    if partial_chunk.len() + payload.len() > MAX_PARTIAL_CHUNK_SIZE {
215                        return Err(crate::error::RtmpError::ChunkRead(ChunkReadError::PartialChunkTooLarge(
216                            partial_chunk.len() + payload.len(),
217                        )));
218                    }
219
220                    // Extend the partial chunk with the payload.
221                    partial_chunk.extend_from_slice(&payload[..]);
222
223                    // Return the new length of the partial chunk.
224                    partial_chunk.len()
225                };
226
227                // If we have a full chunk we return it.
228                if length == message_header.msg_length as usize {
229                    return Ok(Some(Chunk {
230                        basic_header: header,
231                        message_header,
232                        payload: self.partial_chunks.remove(&key).unwrap().freeze(),
233                    }));
234                }
235
236                // If we don't have a full chunk we just let the loop continue.
237                // Usually this will result in returning Ok(None) from one of
238                // the above checks. However there is a edge case that we have
239                // enough data in our buffer to read the next chunk and the
240                // client is waiting for us to send a response. Meaning if we
241                // just return Ok(None) here We would deadlock the connection,
242                // and it will eventually timeout. So we need to loop again here
243                // to check if we have enough data to read the next chunk.
244            }
245        }
246    }
247
248    /// Internal function used to read the basic chunk header.
249    fn read_header(&self, cursor: &mut Cursor<&[u8]>) -> Result<ChunkBasicHeader, Option<io::Error>> {
250        // The first byte of the basic header is the format of the chunk and the stream
251        // id. Mapping the error to none means that this isn't a real error but we dont
252        // have enough data.
253        let byte = cursor.read_u8().eof_to_none()?;
254        // The format is the first 2 bits of the byte. We shift the byte 6 bits to the
255        // right to get the format.
256        let format = (byte >> 6) & 0b00000011;
257
258        // We do not check that the format is valid.
259        // It should not be possible to get an invalid chunk type
260        // because, we bitshift the byte 6 bits to the right. Leaving 2 bits which can
261        // only be 0, 1 or 2 or 3 which is the only valid chunk types.
262        let format = ChunkType::from_u8(format).expect("unreachable");
263
264        // We then parse the chunk stream id.
265        let chunk_stream_id = match (byte & 0b00111111) as u32 {
266            // If the chunk stream id is 0 we read the next byte and add 64 to it.
267            0 => {
268                let first_byte = cursor.read_u8().eof_to_none()?;
269
270                64 + first_byte as u32
271            }
272            // If it is 1 we read the next 2 bytes and add 64 to it and multiply the 2nd byte by
273            // 256.
274            1 => {
275                let first_byte = cursor.read_u8().eof_to_none()?;
276                let second_byte = cursor.read_u8().eof_to_none()?;
277
278                64 + first_byte as u32 + second_byte as u32 * 256
279            }
280            // Any other value means that the chunk stream id is the value of the byte.
281            csid => csid,
282        };
283
284        // We then read the message header.
285        let header = ChunkBasicHeader { chunk_stream_id, format };
286
287        Ok(header)
288    }
289
290    /// Internal function used to read the message header.
291    fn read_message_header(
292        &self,
293        header: &ChunkBasicHeader,
294        cursor: &mut Cursor<&[u8]>,
295    ) -> Result<ChunkMessageHeader, Option<crate::error::RtmpError>> {
296        // Each format has a different message header length.
297        match header.format {
298            // Type0 headers have the most information and can be compared to keyframes in video.
299            // They do not reference any previous chunks. They contain the full message header.
300            ChunkType::Type0 => {
301                // The first 3 bytes are the timestamp.
302                let timestamp = cursor
303                    .read_u24::<BigEndian>()
304                    .eof_to_none()
305                    .map_err(|e| e.map(crate::error::RtmpError::Io))?;
306                // Followed by a 3 byte message length. (this is the length of the entire
307                // payload not just this chunk)
308                let msg_length = cursor
309                    .read_u24::<BigEndian>()
310                    .eof_to_none()
311                    .map_err(|e| e.map(crate::error::RtmpError::Io))?;
312                if msg_length as usize > MAX_PARTIAL_CHUNK_SIZE {
313                    return Err(Some(crate::error::RtmpError::ChunkRead(
314                        ChunkReadError::PartialChunkTooLarge(msg_length as usize),
315                    )));
316                }
317
318                // We then have a 1 byte message type id.
319                let msg_type_id = cursor
320                    .read_u8()
321                    .eof_to_none()
322                    .map_err(|e| e.map(crate::error::RtmpError::Io))?;
323                let msg_type_id = MessageType::from(msg_type_id);
324
325                // We then read the message stream id. (According to spec this is stored in
326                // LittleEndian, no idea why.)
327                let msg_stream_id = cursor
328                    .read_u32::<LittleEndian>()
329                    .eof_to_none()
330                    .map_err(|e| e.map(crate::error::RtmpError::Io))?;
331
332                // Sometimes the timestamp is larger than 3 bytes.
333                // If the timestamp is 0xFFFFFF we read the next 4 bytes as the timestamp.
334                // I am not exactly sure why they did it this way.
335                // Why not just use 3 bytes for the timestamp, and if the 3 bytes are set to
336                // 0xFFFFFF just read 1 additional byte and then shift it 24 bits.
337                // Like if timestamp == 0xFFFFFF { timestamp |= cursor.read_u8() << 24; }
338                // This would save 3 bytes in the header and would be more
339                // efficient but I guess the Spec writers are smarter than me.
340                let (timestamp, was_extended_timestamp) = if timestamp == 0xFFFFFF {
341                    (
342                        cursor
343                            .read_u32::<BigEndian>()
344                            .eof_to_none()
345                            .map_err(|e| e.map(crate::error::RtmpError::Io))?,
346                        true,
347                    )
348                } else {
349                    (timestamp, false)
350                };
351
352                Ok(ChunkMessageHeader {
353                    timestamp,
354                    msg_length,
355                    msg_type_id,
356                    msg_stream_id,
357                    was_extended_timestamp,
358                })
359            }
360            // For ChunkType 1 we have a delta timestamp, message length and message type id.
361            // The message stream id is the same as the previous chunk.
362            ChunkType::Type1 => {
363                // The first 3 bytes are the delta timestamp.
364                let timestamp_delta = cursor
365                    .read_u24::<BigEndian>()
366                    .eof_to_none()
367                    .map_err(|e| e.map(crate::error::RtmpError::Io))?;
368                // Followed by a 3 byte message length. (this is the length of the entire
369                // payload not just this chunk)
370                let msg_length = cursor
371                    .read_u24::<BigEndian>()
372                    .eof_to_none()
373                    .map_err(|e| e.map(crate::error::RtmpError::Io))?;
374                if msg_length as usize > MAX_PARTIAL_CHUNK_SIZE {
375                    return Err(Some(crate::error::RtmpError::ChunkRead(
376                        ChunkReadError::PartialChunkTooLarge(msg_length as usize),
377                    )));
378                }
379
380                // We then have a 1 byte message type id.
381                let msg_type_id = cursor
382                    .read_u8()
383                    .eof_to_none()
384                    .map_err(|e| e.map(crate::error::RtmpError::Io))?;
385                let msg_type_id = MessageType::from(msg_type_id);
386
387                // Again as mentioned above we sometimes have a delta timestamp larger than 3
388                // bytes.
389                let (timestamp_delta, was_extended_timestamp) = if timestamp_delta == 0xFFFFFF {
390                    (
391                        cursor
392                            .read_u32::<BigEndian>()
393                            .eof_to_none()
394                            .map_err(|e| e.map(crate::error::RtmpError::Io))?,
395                        true,
396                    )
397                } else {
398                    (timestamp_delta, false)
399                };
400
401                // We get the previous chunk header.
402                // If the previous chunk header is not found we return an error. (this is a real
403                // error)
404                let previous_header =
405                    self.previous_chunk_headers
406                        .get(&header.chunk_stream_id)
407                        .ok_or(crate::error::RtmpError::ChunkRead(
408                            ChunkReadError::MissingPreviousChunkHeader(header.chunk_stream_id),
409                        ))?;
410
411                // We calculate the timestamp by adding the delta timestamp to the previous
412                // timestamp. We need to make sure this does not overflow.
413                let timestamp = previous_header.timestamp.checked_add(timestamp_delta).unwrap_or_else(|| {
414                    tracing::warn!(
415						"Timestamp overflow detected. Previous timestamp: {}, delta timestamp: {}, using previous timestamp.",
416						previous_header.timestamp,
417						timestamp_delta
418					);
419
420                    previous_header.timestamp
421                });
422
423                Ok(ChunkMessageHeader {
424                    timestamp,
425                    msg_length,
426                    msg_type_id,
427                    was_extended_timestamp,
428                    // The message stream id is the same as the previous chunk.
429                    msg_stream_id: previous_header.msg_stream_id,
430                })
431            }
432            // ChunkType2 headers only have a delta timestamp.
433            // The message length, message type id and message stream id are the same as the
434            // previous chunk.
435            ChunkType::Type2 => {
436                // We read the delta timestamp.
437                let timestamp_delta = cursor
438                    .read_u24::<BigEndian>()
439                    .eof_to_none()
440                    .map_err(|e| e.map(crate::error::RtmpError::Io))?;
441
442                // Again if the delta timestamp is larger than 3 bytes we read the next 4 bytes
443                // as the timestamp.
444                let (timestamp_delta, was_extended_timestamp) = if timestamp_delta == 0xFFFFFF {
445                    (
446                        cursor
447                            .read_u32::<BigEndian>()
448                            .eof_to_none()
449                            .map_err(|e| e.map(crate::error::RtmpError::Io))?,
450                        true,
451                    )
452                } else {
453                    (timestamp_delta, false)
454                };
455
456                // We get the previous chunk header.
457                // If the previous chunk header is not found we return an error. (this is a real
458                // error)
459                let previous_header =
460                    self.previous_chunk_headers
461                        .get(&header.chunk_stream_id)
462                        .ok_or(crate::error::RtmpError::ChunkRead(
463                            ChunkReadError::MissingPreviousChunkHeader(header.chunk_stream_id),
464                        ))?;
465
466                // We calculate the timestamp by adding the delta timestamp to the previous
467                // timestamp.
468                let timestamp = previous_header.timestamp + timestamp_delta;
469
470                Ok(ChunkMessageHeader {
471                    timestamp,
472                    msg_length: previous_header.msg_length,
473                    msg_type_id: previous_header.msg_type_id,
474                    msg_stream_id: previous_header.msg_stream_id,
475                    was_extended_timestamp,
476                })
477            }
478            // ChunkType3 headers are the same as the previous chunk header.
479            ChunkType::Type3 => {
480                // We get the previous chunk header.
481                // If the previous chunk header is not found we return an error. (this is a real
482                // error)
483                let previous_header = self
484                    .previous_chunk_headers
485                    .get(&header.chunk_stream_id)
486                    .ok_or(crate::error::RtmpError::ChunkRead(
487                        ChunkReadError::MissingPreviousChunkHeader(header.chunk_stream_id),
488                    ))?
489                    .clone();
490
491                // Now this is truely stupid.
492                // If the PREVIOUS HEADER is extended then we now waste an additional 4 bytes to
493                // read the timestamp. Why not just read the timestamp in the previous header if
494                // it is extended? I guess the spec writers had some reason and its obviously
495                // way above my knowledge.
496                if previous_header.was_extended_timestamp {
497                    // Not a real error, we just dont have enough data.
498                    // We dont have to store this value since it is the same as the previous header.
499                    cursor
500                        .read_u32::<BigEndian>()
501                        .eof_to_none()
502                        .map_err(|e| e.map(crate::error::RtmpError::Io))?;
503                }
504
505                Ok(previous_header)
506            }
507        }
508    }
509
510    /// Internal function to get the payload range of a chunk.
511    fn get_payload_range(
512        &self,
513        header: &ChunkBasicHeader,
514        message_header: &ChunkMessageHeader,
515        cursor: &mut Cursor<&'_ [u8]>,
516    ) -> Result<(usize, usize), Option<crate::error::RtmpError>> {
517        // We find out if the chunk is a partial chunk (and if we have already read some
518        // of it).
519        let key = (header.chunk_stream_id, message_header.msg_stream_id);
520
521        // Check how much we still need to read (if we have already read some of the
522        // chunk)
523        let remaining_read_length =
524            message_header.msg_length as usize - self.partial_chunks.get(&key).map(|data| data.len()).unwrap_or(0);
525
526        // We get the min between our max chunk size and the remaining read length.
527        // This is the amount of bytes we need to read.
528        let need_read_length = min(remaining_read_length, self.max_chunk_size);
529
530        // We get the current position in the cursor.
531        let pos = cursor.position() as usize;
532
533        // We seek forward to where the payload starts.
534        cursor
535            .seek(SeekFrom::Current(need_read_length as i64))
536            .eof_to_none()
537            .map_err(|e| e.map(crate::error::RtmpError::Io))?;
538
539        // We then return the range of the payload.
540        // Which would be the pos to the pos + need_read_length.
541        Ok((pos, pos + need_read_length))
542    }
543}
544
/// Extension for `io::Result` that separates "ran out of buffered input" from
/// genuine I/O failures.
trait IoResultExt<T> {
    /// Maps an `UnexpectedEof` error to `Err(None)` (more data is needed) and any
    /// other error to `Err(Some(error))`. `Ok` values pass through unchanged.
    fn eof_to_none(self) -> Result<T, Option<io::Error>>;
}

impl<T> IoResultExt<T> for io::Result<T> {
    fn eof_to_none(self) -> Result<T, Option<io::Error>> {
        match self {
            Ok(value) => Ok(value),
            Err(err) if err.kind() == io::ErrorKind::UnexpectedEof => Err(None),
            Err(err) => Err(Some(err)),
        }
    }
}
560
561#[cfg(test)]
562#[cfg_attr(all(test, coverage_nightly), coverage(off))]
563mod tests {
564    use byteorder::WriteBytesExt;
565    use bytes::{BufMut, BytesMut};
566
567    use super::*;
568
569    #[test]
570    fn test_reader_error_display() {
571        let error = ChunkReadError::MissingPreviousChunkHeader(123);
572        assert_eq!(format!("{error}"), "missing previous chunk header: 123");
573
574        let error = ChunkReadError::TooManyPartialChunks;
575        assert_eq!(format!("{error}"), "too many partial chunks");
576
577        let error = ChunkReadError::TooManyPreviousChunkHeaders;
578        assert_eq!(format!("{error}"), "too many previous chunk headers");
579
580        let error = ChunkReadError::PartialChunkTooLarge(100);
581        assert_eq!(format!("{error}"), "partial chunk too large: 100");
582    }
583
584    #[test]
585    fn test_reader_chunk_size_out_of_bounds() {
586        let mut reader = ChunkReader::default();
587        assert!(!reader.update_max_chunk_size(MAX_CHUNK_SIZE + 1));
588    }
589
590    #[test]
591    fn test_incomplete_header() {
592        let mut buf = BytesMut::new();
593        buf.extend_from_slice(&[0b00_000000]);
594
595        let reader = ChunkReader::default();
596        let err = reader.read_header(&mut Cursor::new(&buf));
597        assert!(matches!(err, Err(None)));
598    }
599
600    #[test]
601    fn test_reader_chunk_type0_single_sized() {
602        let mut buf = BytesMut::new();
603
604        #[rustfmt::skip]
605        buf.extend_from_slice(&[
606            3, // chunk type 0, chunk stream id 3
607            0x00, 0x00, 0x00, // timestamp
608            0x00, 0x00, 0x80, // message length (128) (max chunk size is set to 128)
609            0x09, // message type id (video)
610            0x00, 0x01, 0x00, 0x00, // message stream id
611        ]);
612
613        for i in 0..128 {
614            (&mut buf).writer().write_u8(i as u8).unwrap();
615        }
616
617        let mut unpacker = ChunkReader::default();
618        let chunk = unpacker.read_chunk(&mut buf).expect("read chunk").expect("chunk");
619        assert_eq!(chunk.basic_header.chunk_stream_id, 3);
620        assert_eq!(chunk.message_header.msg_type_id.0, 0x09);
621        assert_eq!(chunk.message_header.timestamp, 0);
622        assert_eq!(chunk.message_header.msg_length, 128);
623        assert_eq!(chunk.message_header.msg_stream_id, 0x0100); // since it's little endian, it's 0x0100
624        assert_eq!(chunk.payload.len(), 128);
625    }
626
627    #[test]
628    fn test_reader_chunk_type0_double_sized() {
629        let mut buf = BytesMut::new();
630        #[rustfmt::skip]
631        buf.extend_from_slice(&[
632            3, // chunk type 0, chunk stream id 3
633            0x00, 0x00, 0x00, // timestamp
634            0x00, 0x01, 0x00, // message length (256) (max chunk size is set to 128)
635            0x09, // message type id (video)
636            0x00, 0x01, 0x00, 0x00, // message stream id
637        ]);
638
639        for i in 0..128 {
640            (&mut buf).writer().write_u8(i as u8).unwrap();
641        }
642
643        let mut unpacker = ChunkReader::default();
644
645        let chunk = buf.as_ref().to_vec();
646
647        // We should not have enough data to read the chunk
648        // But the chunk is valid, so we should not get an error
649        assert!(unpacker.read_chunk(&mut buf).expect("read chunk").is_none());
650
651        // We just feed the same data again in this test to see if the Unpacker merges
652        // the chunks Which it should do
653        buf.extend_from_slice(&chunk);
654
655        let chunk = unpacker.read_chunk(&mut buf).expect("read chunk").expect("chunk");
656
657        assert_eq!(chunk.basic_header.chunk_stream_id, 3);
658        assert_eq!(chunk.message_header.msg_type_id.0, 0x09);
659        assert_eq!(chunk.message_header.timestamp, 0);
660        assert_eq!(chunk.message_header.msg_length, 256);
661        assert_eq!(chunk.message_header.msg_stream_id, 0x0100); // since it's little endian, it's 0x0100
662        assert_eq!(chunk.payload.len(), 256);
663    }
664
665    #[test]
666    fn test_reader_chunk_mutli_streams() {
667        let mut buf = BytesMut::new();
668
669        #[rustfmt::skip]
670        buf.extend_from_slice(&[
671            3, // chunk type 0, chunk stream id 3
672            0x00, 0x00, 0x00, // timestamp
673            0x00, 0x01, 0x00, // message length (256) (max chunk size is set to 128)
674            0x09, // message type id (video)
675            0x00, 0x01, 0x00, 0x00, // message stream id
676        ]);
677
678        for _ in 0..128 {
679            (&mut buf).writer().write_u8(3).unwrap();
680        }
681
682        #[rustfmt::skip]
683        buf.extend_from_slice(&[
684            4, // chunk type 0, chunk stream id 4 (different stream)
685            0x00, 0x00, 0x00, // timestamp
686            0x00, 0x01, 0x00, // message length (256) (max chunk size is set to 128)
687            0x08, // message type id (audio)
688            0x00, 0x03, 0x00, 0x00, // message stream id
689        ]);
690
691        for _ in 0..128 {
692            (&mut buf).writer().write_u8(4).unwrap();
693        }
694
695        let mut unpacker = ChunkReader::default();
696
697        // We wrote 2 chunks but neither of them are complete
698        assert!(unpacker.read_chunk(&mut buf).expect("read chunk").is_none());
699
700        #[rustfmt::skip]
701        buf.extend_from_slice(&[
702            (3 << 6) | 4, // chunk type 3, chunk stream id 4
703        ]);
704
705        for _ in 0..128 {
706            (&mut buf).writer().write_u8(3).unwrap();
707        }
708
709        // Even though we wrote chunk 3 first, chunk 4 should be read first since it's a
710        // different stream
711        let chunk = unpacker.read_chunk(&mut buf).expect("read chunk").expect("chunk");
712
713        assert_eq!(chunk.basic_header.chunk_stream_id, 4);
714        assert_eq!(chunk.message_header.msg_type_id.0, 0x08);
715        assert_eq!(chunk.message_header.timestamp, 0);
716        assert_eq!(chunk.message_header.msg_length, 256);
717        assert_eq!(chunk.message_header.msg_stream_id, 0x0300); // since it's little endian, it's 0x0100
718        assert_eq!(chunk.payload.len(), 256);
719        for i in 0..128 {
720            assert_eq!(chunk.payload[i], 4);
721        }
722
723        // No chunk is ready yet
724        assert!(unpacker.read_chunk(&mut buf).expect("read chunk").is_none());
725
726        #[rustfmt::skip]
727        buf.extend_from_slice(&[
728            (3 << 6) | 3, // chunk type 3, chunk stream id 3
729        ]);
730
731        for _ in 0..128 {
732            (&mut buf).writer().write_u8(3).unwrap();
733        }
734
735        let chunk = unpacker.read_chunk(&mut buf).expect("read chunk").expect("chunk");
736
737        assert_eq!(chunk.basic_header.chunk_stream_id, 3);
738        assert_eq!(chunk.message_header.msg_type_id.0, 0x09);
739        assert_eq!(chunk.message_header.timestamp, 0);
740        assert_eq!(chunk.message_header.msg_length, 256);
741        assert_eq!(chunk.message_header.msg_stream_id, 0x0100); // since it's little endian, it's 0x0100
742        assert_eq!(chunk.payload.len(), 256);
743        for i in 0..128 {
744            assert_eq!(chunk.payload[i], 3);
745        }
746    }
747
748    #[test]
749    fn test_reader_extended_timestamp() {
750        let mut buf = BytesMut::new();
751
752        #[rustfmt::skip]
753        buf.extend_from_slice(&[
754            3, // chunk type 0, chunk stream id 3
755            0xFF, 0xFF, 0xFF, // timestamp
756            0x00, 0x02, 0x00, // message length (384) (max chunk size is set to 128)
757            0x09, // message type id (video)
758            0x00, 0x01, 0x00, 0x00, // message stream id
759            0x01, 0x00, 0x00, 0x00, // extended timestamp
760        ]);
761
762        for i in 0..128 {
763            (&mut buf).writer().write_u8(i as u8).unwrap();
764        }
765
766        let mut unpacker = ChunkReader::default();
767
768        // We should not have enough data to read the chunk
769        // But the chunk is valid, so we should not get an error
770        assert!(unpacker.read_chunk(&mut buf).expect("read chunk").is_none());
771
772        #[rustfmt::skip]
773        buf.extend_from_slice(&[
774            (1 << 6) | 3, // chunk type 1, chunk stream id 3
775            0xFF, 0xFF, 0xFF, // extended timestamp (again)
776            0x00, 0x02, 0x00, // message length (384) (max chunk size is set to 128)
777            0x09, // message type id (video)
778            // message stream id is not present since it's the same as the previous chunk
779            0x01, 0x00, 0x00, 0x00, // extended timestamp (again)
780        ]);
781
782        for i in 0..128 {
783            (&mut buf).writer().write_u8(i as u8).unwrap();
784        }
785
786        #[rustfmt::skip]
787        buf.extend_from_slice(&[
788            (2 << 6) | 3, // chunk type 3, chunk stream id 3
789            0x00, 0x00, 0x01, // not extended timestamp
790        ]);
791
792        for i in 0..128 {
793            (&mut buf).writer().write_u8(i as u8).unwrap();
794        }
795
796        #[rustfmt::skip]
797        buf.extend_from_slice(&[
798            (3 << 6) | 3, // chunk type 3, chunk stream id 3
799        ]);
800
801        for i in 0..128 {
802            (&mut buf).writer().write_u8(i as u8).unwrap();
803        }
804
805        let chunk = unpacker.read_chunk(&mut buf).expect("read chunk").expect("chunk");
806
807        assert_eq!(chunk.basic_header.chunk_stream_id, 3);
808        assert_eq!(chunk.message_header.msg_type_id.0, 0x09);
809        assert_eq!(chunk.message_header.timestamp, 0x02000001);
810        assert_eq!(chunk.message_header.msg_length, 512);
811        assert_eq!(chunk.message_header.msg_stream_id, 0x0100); // since it's little endian, it's 0x0100
812        assert_eq!(chunk.payload.len(), 512);
813    }
814
815    #[test]
816    fn test_reader_extended_timestamp_ext() {
817        let mut buf = BytesMut::new();
818
819        #[rustfmt::skip]
820        buf.extend_from_slice(&[
821            3, // chunk type 0, chunk stream id 3
822            0xFF, 0xFF, 0xFF, // timestamp
823            0x00, 0x01, 0x00, // message length (256) (max chunk size is set to 128)
824            0x09, // message type id (video)
825            0x00, 0x01, 0x00, 0x00, // message stream id
826            0x01, 0x00, 0x00, 0x00, // extended timestamp
827        ]);
828
829        for i in 0..128 {
830            (&mut buf).writer().write_u8(i as u8).unwrap();
831        }
832
833        let mut unpacker = ChunkReader::default();
834
835        // We should not have enough data to read the chunk
836        // But the chunk is valid, so we should not get an error
837        assert!(unpacker.read_chunk(&mut buf).expect("read chunk").is_none());
838
839        #[rustfmt::skip]
840        buf.extend_from_slice(&[
841            (3 << 6) | 3, // chunk type 1, chunk stream id 3
842            0x00, 0x00, 0x00, 0x00, // extended timestamp this value is ignored
843        ]);
844
845        for i in 0..128 {
846            (&mut buf).writer().write_u8(i as u8).unwrap();
847        }
848
849        for i in 0..128 {
850            (&mut buf).writer().write_u8(i as u8).unwrap();
851        }
852
853        let chunk = unpacker.read_chunk(&mut buf).expect("read chunk").expect("chunk");
854
855        assert_eq!(chunk.basic_header.chunk_stream_id, 3);
856        assert_eq!(chunk.message_header.msg_type_id.0, 0x09);
857        assert_eq!(chunk.message_header.timestamp, 0x01000000);
858        assert_eq!(chunk.message_header.msg_length, 256);
859        assert_eq!(chunk.message_header.msg_stream_id, 0x0100); // since it's little endian, it's 0x0100
860        assert_eq!(chunk.payload.len(), 256);
861    }
862
863    #[test]
864    fn test_read_extended_csid() {
865        let mut buf = BytesMut::new();
866
867        #[rustfmt::skip]
868        buf.extend_from_slice(&[
869            (0 << 6), // chunk type 0, chunk stream id 0
870            10,       // extended chunk stream id
871            0x00, 0x00, 0x00, // timestamp
872            0x00, 0x00, 0x00, // message length (256) (max chunk size is set to 128)
873            0x09, // message type id (video)
874            0x00, 0x01, 0x00, 0x00, // message stream id
875        ]);
876
877        let mut unpacker = ChunkReader::default();
878        let chunk = unpacker.read_chunk(&mut buf).expect("read chunk").expect("chunk");
879
880        assert_eq!(chunk.basic_header.chunk_stream_id, 64 + 10);
881    }
882
883    #[test]
884    fn test_read_extended_csid_ext2() {
885        let mut buf = BytesMut::new();
886
887        #[rustfmt::skip]
888        buf.extend_from_slice(&[
889            1,  // chunk type 0, chunk stream id 0
890            10, // extended chunk stream id
891            13, // extended chunk stream id 2
892            0x00, 0x00, 0x00, // timestamp
893            0x00, 0x00, 0x00, // message length (256) (max chunk size is set to 128)
894            0x09, // message type id (video)
895            0x00, 0x01, 0x00, 0x00, // message stream id
896        ]);
897
898        let mut unpacker = ChunkReader::default();
899
900        let chunk = unpacker.read_chunk(&mut buf).expect("read chunk").expect("chunk");
901
902        assert_eq!(chunk.basic_header.chunk_stream_id, 64 + 10 + 256 * 13);
903    }
904
905    #[test]
906    fn test_reader_error_no_previous_chunk() {
907        let mut buf = BytesMut::new();
908
909        // Write a chunk with type 3 but no previous chunk
910        #[rustfmt::skip]
911        buf.extend_from_slice(&[
912            (3 << 6) | 3, // chunk type 0, chunk stream id 3
913        ]);
914
915        let mut unpacker = ChunkReader::default();
916        let err = unpacker.read_chunk(&mut buf).unwrap_err();
917        match err {
918            crate::error::RtmpError::ChunkRead(ChunkReadError::MissingPreviousChunkHeader(3)) => {}
919            _ => panic!("Unexpected error: {err:?}"),
920        }
921    }
922
923    #[test]
924    fn test_reader_error_partial_chunk_too_large() {
925        let mut buf = BytesMut::new();
926
927        // Write a chunk that has a message size that is too large
928        #[rustfmt::skip]
929        buf.extend_from_slice(&[
930            3, // chunk type 0, chunk stream id 3
931            0xFF, 0xFF, 0xFF, // timestamp
932            0xFF, 0xFF, 0xFF, // message length (max chunk size is set to 128)
933            0x09, // message type id (video)
934            0x00, 0x01, 0x00, 0x00, // message stream id
935            0x01, 0x00, 0x00, 0x00, // extended timestamp
936        ]);
937
938        let mut unpacker = ChunkReader::default();
939
940        let err = unpacker.read_chunk(&mut buf).unwrap_err();
941        match err {
942            crate::error::RtmpError::ChunkRead(ChunkReadError::PartialChunkTooLarge(16777215)) => {}
943            _ => panic!("Unexpected error: {err:?}"),
944        }
945    }
946
947    #[test]
948    fn test_reader_error_too_many_partial_chunks() {
949        let mut buf = BytesMut::new();
950
951        let mut unpacker = ChunkReader::default();
952
953        for i in 0..4 {
954            // Write another chunk with a different chunk stream id
955            #[rustfmt::skip]
956            buf.extend_from_slice(&[
957                (i + 2), // chunk type 0 (partial), chunk stream id i
958                0xFF, 0xFF, 0xFF, // timestamp
959                0x00, 0x01, 0x00, // message length (max chunk size is set to 128)
960                0x09, // message type id (video)
961                0x00, 0x01, 0x00, 0x00, // message stream id
962                0x01, 0x00, 0x00, 0x00, // extended timestamp
963            ]);
964
965            for i in 0..128 {
966                (&mut buf).writer().write_u8(i as u8).unwrap();
967            }
968
969            // Read the chunk
970            assert!(
971                unpacker
972                    .read_chunk(&mut buf)
973                    .unwrap_or_else(|_| panic!("chunk failed {i}"))
974                    .is_none()
975            );
976        }
977
978        // Write another chunk with a different chunk stream id
979        #[rustfmt::skip]
980        buf.extend_from_slice(&[
981            12, // chunk type 0, chunk stream id 6
982            0xFF, 0xFF, 0xFF, // timestamp
983            0x00, 0x01, 0x00, // message length (max chunk size is set to 128)
984            0x09, // message type id (video)
985            0x00, 0x01, 0x00, 0x00, // message stream id
986            0x01, 0x00, 0x00, 0x00, // extended timestamp
987        ]);
988
989        for i in 0..128 {
990            (&mut buf).writer().write_u8(i as u8).unwrap();
991        }
992
993        let err = unpacker.read_chunk(&mut buf).unwrap_err();
994        match err {
995            crate::error::RtmpError::ChunkRead(ChunkReadError::TooManyPartialChunks) => {}
996            _ => panic!("Unexpected error: {err:?}"),
997        }
998    }
999
1000    #[test]
1001    fn test_reader_error_too_many_chunk_headers() {
1002        let mut buf = BytesMut::new();
1003
1004        let mut unpacker = ChunkReader::default();
1005
1006        for i in 0..100 {
1007            // Write another chunk with a different chunk stream id
1008            #[rustfmt::skip]
1009            buf.extend_from_slice(&[
1010                (0 << 6), // chunk type 0 (partial), chunk stream id 0
1011                i,        // chunk id
1012                0xFF, 0xFF, 0xFF, // timestamp
1013                0x00, 0x00, 0x00, // message length (max chunk size is set to 128)
1014                0x09, // message type id (video)
1015                0x00, 0x01, 0x00, 0x00, // message stream id
1016                0x01, 0x00, 0x00, 0x00, // extended timestamp
1017            ]);
1018
1019            // Read the chunk (should be a full chunk since the message length is 0)
1020            assert!(
1021                unpacker
1022                    .read_chunk(&mut buf)
1023                    .unwrap_or_else(|_| panic!("chunk failed {i}"))
1024                    .is_some()
1025            );
1026        }
1027
1028        // Write another chunk with a different chunk stream id
1029        #[rustfmt::skip]
1030        buf.extend_from_slice(&[
1031            12, // chunk type 0, chunk stream id 6
1032            0xFF, 0xFF, 0xFF, // timestamp
1033            0x00, 0x00, 0x00, // message length (max chunk size is set to 128)
1034            0x09, // message type id (video)
1035            0x00, 0x01, 0x00, 0x00, // message stream id
1036            0x01, 0x00, 0x00, 0x00, // extended timestamp
1037        ]);
1038
1039        let err = unpacker.read_chunk(&mut buf).unwrap_err();
1040        match err {
1041            crate::error::RtmpError::ChunkRead(ChunkReadError::TooManyPreviousChunkHeaders) => {}
1042            _ => panic!("Unexpected error: {err:?}"),
1043        }
1044    }
1045
1046    #[test]
1047    fn test_reader_larger_chunk_size() {
1048        let mut buf = BytesMut::new();
1049
1050        // Write a chunk that has a message size that is too large
1051        #[rustfmt::skip]
1052        buf.extend_from_slice(&[
1053            3, // chunk type 0, chunk stream id 3
1054            0x00, 0x00, 0xFF, // timestamp
1055            0x00, 0x0F, 0x00, // message length ()
1056            0x09, // message type id (video)
1057            0x01, 0x00, 0x00, 0x00, // message stream id
1058        ]);
1059
1060        for i in 0..3840 {
1061            (&mut buf).writer().write_u8(i as u8).unwrap();
1062        }
1063
1064        let mut unpacker = ChunkReader::default();
1065        unpacker.update_max_chunk_size(4096);
1066
1067        let chunk = unpacker.read_chunk(&mut buf).expect("failed").expect("chunk");
1068        assert_eq!(chunk.basic_header.chunk_stream_id, 3);
1069        assert_eq!(chunk.message_header.timestamp, 255);
1070        assert_eq!(chunk.message_header.msg_length, 3840);
1071        assert_eq!(chunk.message_header.msg_type_id.0, 0x09);
1072        assert_eq!(chunk.message_header.msg_stream_id, 1); // little endian
1073        assert_eq!(chunk.payload.len(), 3840);
1074
1075        for i in 0..3840 {
1076            assert_eq!(chunk.payload[i], i as u8);
1077        }
1078    }
1079}