scuffle_bytes_util/
bit_write.rs1use std::io;
2
/// A writer that packs individual bits, most-significant bit first, into bytes
/// and forwards every completed byte to an inner [`io::Write`] sink.
///
/// Bits are buffered in a single pending byte until eight have been written.
/// A partial byte is only emitted by [`BitWriter::align`] or
/// [`BitWriter::finish`], which pad the remainder with zero bits.
#[derive(Debug)]
#[must_use]
pub struct BitWriter<W> {
    /// Number of bits already placed in `current_byte`; always in `0..=7`.
    bit_pos: u8,
    /// The partially filled byte. Positions not yet written are kept zero.
    current_byte: u8,
    /// The underlying byte sink.
    writer: W,
}

impl<W: Default> Default for BitWriter<W> {
    /// Creates a bit writer around `W::default()`.
    fn default() -> Self {
        Self::new(W::default())
    }
}

impl<W: io::Write> BitWriter<W> {
    /// Writes a single bit.
    ///
    /// When this bit completes a byte, that byte is written to the inner writer
    /// before the call returns.
    ///
    /// # Errors
    ///
    /// Returns any error from the inner writer. On error the writer's state is
    /// left exactly as it was before the call, so the bit can be written again.
    pub fn write_bit(&mut self, bit: bool) -> io::Result<()> {
        // `bit_pos` is always in 0..=7, so this shift cannot overflow.
        let mask = 1u8 << (7 - self.bit_pos);
        let byte = if bit {
            self.current_byte | mask
        } else {
            self.current_byte & !mask
        };

        if self.bit_pos == 7 {
            // This bit completes the byte. Update our state only after the
            // inner writer accepted it: committing first would leave a full
            // byte with an out-of-range bit position if the write failed.
            self.writer.write_all(&[byte])?;
            self.current_byte = 0;
            self.bit_pos = 0;
        } else {
            self.current_byte = byte;
            self.bit_pos += 1;
        }

        Ok(())
    }

    /// Writes the low `count` bits of `bits`, most-significant bit first.
    ///
    /// `count` is clamped to 64.
    ///
    /// # Errors
    ///
    /// Returns [`io::ErrorKind::InvalidData`] ("bits too large to write") if
    /// `bits` has any bit set at or above position `count`; nothing is written
    /// in that case. Otherwise returns any error from the inner writer, in
    /// which case the bits preceding the failing one have already been taken.
    pub fn write_bits(&mut self, bits: u64, count: u8) -> io::Result<()> {
        let count = count.min(64);

        // For count == 64 every u64 fits (and `>> 64` would overflow).
        if count < 64 && bits >> count != 0 {
            return Err(io::Error::new(io::ErrorKind::InvalidData, "bits too large to write"));
        }

        for shift in (0..count).rev() {
            self.write_bit((bits >> shift) & 1 == 1)?;
        }

        Ok(())
    }

    /// Pads any partial byte with zero bits, writes it out, and returns the
    /// inner writer.
    ///
    /// This does not call `flush` on the inner writer.
    ///
    /// # Errors
    ///
    /// Returns any error from the inner writer while writing the padded byte.
    pub fn finish(mut self) -> io::Result<W> {
        self.align()?;
        Ok(self.writer)
    }

    /// Pads the current byte with zero bits up to the next byte boundary,
    /// writing it to the inner writer. Does nothing if already aligned.
    ///
    /// # Errors
    ///
    /// Returns any error from the inner writer.
    pub fn align(&mut self) -> io::Result<()> {
        if !self.is_aligned() {
            self.write_bits(0, 8 - self.bit_pos)?;
        }

        Ok(())
    }
}

impl<W> BitWriter<W> {
    /// Creates a bit writer over `writer`, starting byte-aligned with no
    /// pending bits.
    pub const fn new(writer: W) -> Self {
        Self {
            bit_pos: 0,
            current_byte: 0,
            writer,
        }
    }

    /// Returns how many bits have been written into the current, incomplete
    /// byte (in `0..=7`).
    #[inline(always)]
    #[must_use]
    pub const fn bit_pos(&self) -> u8 {
        self.bit_pos
    }

    /// Returns `true` when the writer sits on a byte boundary with no pending
    /// bits.
    #[inline(always)]
    #[must_use]
    pub const fn is_aligned(&self) -> bool {
        self.bit_pos == 0
    }

    /// Returns a reference to the inner writer. Bits of an incomplete byte are
    /// not yet visible through it.
    #[inline(always)]
    #[must_use]
    pub const fn get_ref(&self) -> &W {
        &self.writer
    }
}

impl<W: io::Write> io::Write for BitWriter<W> {
    /// Writes whole bytes starting at the current bit position.
    ///
    /// When aligned, this forwards directly to the inner writer. Otherwise each
    /// input byte is split across the pending byte and the next one. An input
    /// byte counts as written only once the output byte it completes has been
    /// accepted. If a write fails mid-buffer, this returns `Ok(n)` for the `n`
    /// bytes already consumed, as the `io::Write` contract requires.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        if self.is_aligned() {
            return self.writer.write(buf);
        }

        // Not aligned, so `shift` is in 1..=7 and `8 - shift` cannot underflow.
        let shift = self.bit_pos;
        for (written, &byte) in buf.iter().enumerate() {
            // The high `8 - shift` bits of `byte` complete the pending byte.
            let out = self.current_byte | (byte >> shift);
            if let Err(err) = self.writer.write_all(&[out]) {
                // Nothing of `byte` has been committed yet, so reporting
                // `written` bytes is exact.
                return if written == 0 { Err(err) } else { Ok(written) };
            }
            // The low `shift` bits of `byte` start the next pending byte.
            self.current_byte = byte << (8 - shift);
        }

        Ok(buf.len())
    }

    /// Flushes the inner writer. Bits of an incomplete byte are not emitted;
    /// call [`BitWriter::align`] first if they must be written.
    fn flush(&mut self) -> io::Result<()> {
        self.writer.flush()
    }
}
121
#[cfg(test)]
#[cfg_attr(all(test, coverage_nightly), coverage(off))]
mod tests {
    use io::Write;

    use super::*;

    #[test]
    fn test_bit_writer() {
        let mut bit_writer = BitWriter::<Vec<u8>>::default();

        // A full byte leaves the writer aligned.
        bit_writer.write_bits(0b11111111, 8).unwrap();
        assert_eq!(bit_writer.bit_pos(), 0);
        assert!(bit_writer.is_aligned());

        // Half a byte, then `align` pads the rest with zeros.
        bit_writer.write_bits(0b0000, 4).unwrap();
        assert_eq!(bit_writer.bit_pos(), 4);
        assert!(!bit_writer.is_aligned());
        bit_writer.align().unwrap();
        assert_eq!(bit_writer.bit_pos(), 0);
        assert!(bit_writer.is_aligned());

        bit_writer.write_bits(0b1010, 4).unwrap();
        assert_eq!(bit_writer.bit_pos(), 4);
        assert!(!bit_writer.is_aligned());

        // 4 + 12 bits = two full bytes: 0b1010_1010, 0b1010_1010.
        bit_writer.write_bits(0b101010101010, 12).unwrap();
        assert_eq!(bit_writer.bit_pos(), 0);
        assert!(bit_writer.is_aligned());

        bit_writer.write_bit(true).unwrap();
        assert_eq!(bit_writer.bit_pos(), 1);
        assert!(!bit_writer.is_aligned());

        // 0b10000 needs 5 bits, so a 4-bit write must be rejected.
        let err = bit_writer.write_bits(0b10000, 4).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert_eq!(err.to_string(), "bits too large to write");

        // `finish` pads the trailing single `1` bit to 0b1000_0000.
        assert_eq!(
            bit_writer.finish().unwrap(),
            vec![0b11111111, 0b00000000, 0b10101010, 0b10101010, 0b10000000]
        );
    }

    #[test]
    fn test_flush_buffer() {
        let mut bit_writer = BitWriter::<Vec<u8>>::default();

        // A completed byte reaches the inner writer immediately.
        bit_writer.write_bits(0b11111111, 8).unwrap();
        assert_eq!(bit_writer.bit_pos(), 0);
        assert!(bit_writer.is_aligned());
        assert_eq!(bit_writer.get_ref(), &[0b11111111], "underlying writer should have one byte");

        // A partial byte stays buffered.
        bit_writer.write_bits(0b0000, 4).unwrap();
        assert_eq!(bit_writer.bit_pos(), 4);
        assert!(!bit_writer.is_aligned());
        assert_eq!(bit_writer.get_ref(), &[0b11111111], "underlying writer should have one byte");

        bit_writer.write_bits(0b1010, 4).unwrap();
        assert_eq!(bit_writer.bit_pos(), 0);
        assert!(bit_writer.is_aligned());
        assert_eq!(
            bit_writer.get_ref(),
            &[0b11111111, 0b00001010],
            "underlying writer should have two bytes"
        );
    }

    #[test]
    fn test_io_write() {
        let mut inner = Vec::new();
        let mut bit_writer = BitWriter::new(&mut inner);

        bit_writer.write_bits(0b11111111, 8).unwrap();
        assert_eq!(bit_writer.bit_pos(), 0);
        assert!(bit_writer.is_aligned());
        assert_eq!(bit_writer.get_ref().as_slice(), &[255]);

        // Aligned `io::Write` calls pass straight through.
        bit_writer.write_all(&[1, 2, 3]).unwrap();
        assert_eq!(bit_writer.bit_pos(), 0);
        assert!(bit_writer.is_aligned());
        assert_eq!(bit_writer.get_ref().as_slice(), &[255, 1, 2, 3]);

        // Leave 5 pending bits: 1 + 1010.
        bit_writer.write_bit(true).unwrap();

        bit_writer.write_bits(0b1010, 4).unwrap();

        // Unaligned bytes are shifted across byte boundaries.
        bit_writer
            .write_all(&[0b11111111, 0b00000000, 0b11111111, 0b00000000])
            .unwrap();

        // 11010 + 111 | 11111 + 000 | 00000 + 111 | 11111 + 000; the last
        // 5 zero bits are still pending.
        assert_eq!(
            bit_writer.get_ref().as_slice(),
            &[255, 1, 2, 3, 0b11010111, 0b11111000, 0b00000111, 0b11111000]
        );

        // `finish` pads the 5 pending zero bits into a final zero byte.
        bit_writer.finish().unwrap();

        assert_eq!(
            inner,
            vec![255, 1, 2, 3, 0b11010111, 0b11111000, 0b00000111, 0b11111000, 0b00000000]
        );
    }

    #[test]
    fn test_flush() {
        let mut inner = Vec::new();
        let mut bit_writer = BitWriter::new(&mut inner);

        bit_writer.write_bits(0b10100000, 8).unwrap();

        bit_writer.flush().unwrap();

        assert_eq!(bit_writer.get_ref().as_slice(), &[0b10100000]);
        assert_eq!(bit_writer.bit_pos(), 0);
        assert!(bit_writer.is_aligned());
    }

    #[test]
    fn test_write_bits_count_edges() {
        let mut bit_writer = BitWriter::<Vec<u8>>::default();

        // A zero-width write accepts only zero and writes nothing.
        bit_writer.write_bits(0, 0).unwrap();
        assert!(bit_writer.is_aligned());
        let err = bit_writer.write_bits(1, 0).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);

        // 64 bits is the maximum width and accepts every u64 value.
        bit_writer.write_bits(u64::MAX, 64).unwrap();
        assert!(bit_writer.is_aligned());
        assert_eq!(bit_writer.finish().unwrap(), vec![0xFF; 8]);
    }
}