1use std::io;
2
3use scuffle_bytes_util::BitReader;
4use utils::read_leb128;
5
6pub mod seq;
7mod utils;
8
9#[derive(Debug, Clone, PartialEq, Eq, Copy)]
12pub struct ObuHeader {
13 pub obu_type: ObuType,
17 pub size: Option<u64>,
21 pub extension_header: Option<ObuExtensionHeader>,
23}
24
/// The optional extension part of an OBU header.
///
/// Only present when the extension flag of the header is set.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ObuExtensionHeader {
    /// Temporal layer identifier; the parser fills it from a 3-bit field.
    pub temporal_id: u8,
    /// Spatial layer identifier; the parser fills it from a 2-bit field.
    pub spatial_id: u8,
}
34
35impl ObuHeader {
36 pub fn parse(cursor: &mut impl io::Read) -> io::Result<Self> {
38 let mut bit_reader = BitReader::new(cursor);
39 let forbidden_bit = bit_reader.read_bit()?;
40 if forbidden_bit {
41 return Err(io::Error::new(io::ErrorKind::InvalidData, "obu_forbidden_bit is not 0"));
42 }
43
44 let obu_type = bit_reader.read_bits(4)?;
45 let extension_flag = bit_reader.read_bit()?;
46 let has_size_field = bit_reader.read_bit()?;
47
48 bit_reader.read_bit()?; let extension_header = if extension_flag {
51 let temporal_id = bit_reader.read_bits(3)?;
52 let spatial_id = bit_reader.read_bits(2)?;
53 bit_reader.read_bits(3)?; Some(ObuExtensionHeader {
55 temporal_id: temporal_id as u8,
56 spatial_id: spatial_id as u8,
57 })
58 } else {
59 None
60 };
61
62 let size = if has_size_field {
63 Some(read_leb128(&mut bit_reader)?)
65 } else {
66 None
67 };
68
69 if !bit_reader.is_aligned() {
70 return Err(io::Error::new(io::ErrorKind::InvalidData, "bit reader is not aligned"));
71 }
72
73 Ok(ObuHeader {
74 obu_type: ObuType::from(obu_type as u8),
75 size,
76 extension_header,
77 })
78 }
79}
80
/// The kind of an OBU, identified by a small numeric code.
///
/// Use the [`From`] conversions to go between this enum and the raw `u8` code.
///
/// Note that [`ObuType::Reserved`] can be constructed with any `u8`, including a code that
/// has a named variant (for example `Reserved(1)`). Such a value converts to the same `u8`,
/// but converting that `u8` back yields the named variant, not `Reserved`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ObuType {
    /// Code `1`.
    SequenceHeader,
    /// Code `2`.
    TemporalDelimiter,
    /// Code `3`.
    FrameHeader,
    /// Code `4`.
    TileGroup,
    /// Code `5`.
    Metadata,
    /// Code `6`.
    Frame,
    /// Code `7`.
    RedundantFrameHeader,
    /// Code `8`.
    TileList,
    /// Code `15`.
    Padding,
    /// Any code without a named variant; the code is carried along unchanged.
    Reserved(u8),
}

impl From<u8> for ObuType {
    /// Maps a raw code to its named variant, falling back to [`ObuType::Reserved`] for every
    /// code that has none.
    fn from(code: u8) -> Self {
        match code {
            1 => Self::SequenceHeader,
            2 => Self::TemporalDelimiter,
            3 => Self::FrameHeader,
            4 => Self::TileGroup,
            5 => Self::Metadata,
            6 => Self::Frame,
            7 => Self::RedundantFrameHeader,
            8 => Self::TileList,
            15 => Self::Padding,
            other => Self::Reserved(other),
        }
    }
}

impl From<ObuType> for u8 {
    /// Returns the raw code of `obu_type`; [`ObuType::Reserved`] yields the code it carries.
    fn from(obu_type: ObuType) -> Self {
        match obu_type {
            ObuType::SequenceHeader => 1,
            ObuType::TemporalDelimiter => 2,
            ObuType::FrameHeader => 3,
            ObuType::TileGroup => 4,
            ObuType::Metadata => 5,
            ObuType::Frame => 6,
            ObuType::RedundantFrameHeader => 7,
            ObuType::TileList => 8,
            ObuType::Padding => 15,
            ObuType::Reserved(raw) => raw,
        }
    }
}
140
#[cfg(test)]
#[cfg_attr(all(coverage_nightly, test), coverage(off))]
mod tests {
    use bytes::Buf;

    use super::*;

    /// First byte `\n` (0x0a) encodes type 1 with the size flag set; the second byte is the size.
    #[test]
    fn test_obu_header_parse() {
        let mut input = std::io::Cursor::new(b"\n\x0f\0\0\0j\xef\xbf\xe1\xbc\x02\x19\x90\x10\x10\x10@");
        let parsed = ObuHeader::parse(&mut input).unwrap();
        insta::assert_debug_snapshot!(parsed, @r"
        ObuHeader {
            obu_type: SequenceHeader,
            size: Some(
                15,
            ),
            extension_header: None,
        }
        ");

        // Only the two header bytes are consumed; the payload stays in the cursor.
        assert_eq!(input.position(), 2);
        assert_eq!(input.remaining(), 15);
    }

    /// An all-zero byte has neither the size flag nor the extension flag set.
    #[test]
    fn test_obu_header_parse_no_size_field() {
        let mut input = std::io::Cursor::new(b"\x00");
        let parsed = ObuHeader::parse(&mut input).unwrap();
        insta::assert_debug_snapshot!(parsed, @r"
        ObuHeader {
            obu_type: Reserved(
                0,
            ),
            size: None,
            extension_header: None,
        }
        ");

        assert_eq!(input.position(), 1);
        assert_eq!(input.remaining(), 0);
    }

    /// The extension flag pulls in a second byte holding the temporal and spatial ids.
    #[test]
    fn test_obu_header_parse_extension_header() {
        let mut input = std::io::Cursor::new([0b00000100, 0b11010000]);
        let parsed = ObuHeader::parse(&mut input).unwrap();
        insta::assert_debug_snapshot!(parsed, @r"
        ObuHeader {
            obu_type: Reserved(
                0,
            ),
            size: None,
            extension_header: Some(
                ObuExtensionHeader {
                    temporal_id: 6,
                    spatial_id: 2,
                },
            ),
        }
        ");

        assert_eq!(input.position(), 2);
        assert_eq!(input.remaining(), 0);
    }

    /// A set top bit in the first byte must be rejected.
    #[test]
    fn test_obu_header_forbidden_bit_set() {
        let mut input = std::io::Cursor::new(b"\xff\x0f\0\0\0j\xef\xbf\xe1\xbc\x02\x19\x90\x10\x10\x10@");
        let err = ObuHeader::parse(&mut input).unwrap_err();
        insta::assert_debug_snapshot!(err, @r#"
        Custom {
            kind: InvalidData,
            error: "obu_forbidden_bit is not 0",
        }
        "#);
    }

    /// Every listed variant converts to its code and back again.
    #[test]
    fn test_obu_to_from_u8() {
        let cases: [(ObuType, u8); 11] = [
            (ObuType::SequenceHeader, 1),
            (ObuType::TemporalDelimiter, 2),
            (ObuType::FrameHeader, 3),
            (ObuType::TileGroup, 4),
            (ObuType::Metadata, 5),
            (ObuType::Frame, 6),
            (ObuType::RedundantFrameHeader, 7),
            (ObuType::TileList, 8),
            (ObuType::Padding, 15),
            (ObuType::Reserved(0), 0),
            (ObuType::Reserved(100), 100),
        ];

        for &(obu_type, raw) in cases.iter() {
            assert_eq!(u8::from(obu_type), raw);
            assert_eq!(ObuType::from(raw), obu_type);
        }
    }
}