use std::fmt::{Display, Formatter};

/// Errors produced when reading from a [`ByteStream`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ByteStreamError {
    /// The requested bit range does not fit inside the stored data.
    OutOfRange,
}

impl Display for ByteStreamError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            ByteStreamError::OutOfRange => write!(f, "Requested values out of range"),
        }
    }
}

impl std::error::Error for ByteStreamError {}

/// Returns a byte whose `mask` least-significant bits are set.
///
/// Despite its name, the argument is a bit *count*: `0` yields `0x00`,
/// `3` yields `0x07`, and any value of `8` or more yields `0xff`.
pub const fn bit_mask(mask: u8) -> u8 {
    if mask >= 8 {
        0xff
    } else {
        (1u8 << mask) - 1
    }
}

/// An owned byte buffer that can be read at arbitrary bit offsets.
///
/// Bits are numbered least-significant first: bit `0` is the lowest bit of
/// byte `0`, bit `8` is the lowest bit of byte `1`, and so on.
#[derive(Default, Clone, Debug)]
pub struct ByteStream {
    data: Vec<u8>,
}

impl ByteStream {
    /// Extracts `bit_count` bits starting at bit `bit_ndx`.
    ///
    /// The bits are returned right-aligned in a little-endian byte vector of
    /// exactly `ceil(bit_count / 8)` bytes; unused high bits of the final
    /// byte are zero. Requesting zero bits yields an empty vector.
    ///
    /// # Errors
    ///
    /// Returns [`ByteStreamError::OutOfRange`] when `bit_ndx + bit_count`
    /// overflows `usize` or extends past the end of the stored data.
    pub fn get_bytes(&self, bit_ndx: usize, bit_count: usize) -> Result<Vec<u8>, ByteStreamError> {
        let first_byte = bit_ndx / 8;
        // Always in 0..8, so the narrowing cast cannot truncate.
        let shift = (bit_ndx % 8) as u32;

        let end_bit = bit_ndx
            .checked_add(bit_count)
            .ok_or(ByteStreamError::OutOfRange)?;
        // One past the last source byte that holds a requested bit.
        let end_byte = end_bit / 8 + usize::from(end_bit % 8 != 0);
        let window = self
            .data
            .get(first_byte..end_byte)
            .ok_or(ByteStreamError::OutOfRange)?;

        let out_len = bit_count / 8 + usize::from(bit_count % 8 != 0);
        let mut out: Vec<u8> = (0..out_len)
            .map(|i| {
                let low = window[i] >> shift;
                // The high part comes from the following byte, if the window
                // has one; with no shift nothing is borrowed from it.
                let high = if shift == 0 {
                    0
                } else {
                    window.get(i + 1).map_or(0, |&next| next << (8 - shift))
                };
                low | high
            })
            .collect();

        // Clear the bits of the final byte that lie beyond the request.
        let tail_bits = (bit_count % 8) as u8;
        if tail_bits != 0 {
            if let Some(last) = out.last_mut() {
                *last &= bit_mask(tail_bits);
            }
        }

        Ok(out)
    }

    /// Returns the number of bytes held by the stream.
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// Returns `true` when the stream holds no bytes.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }
}

impl From<&[u8]> for ByteStream {
    fn from(slice: &[u8]) -> Self {
        ByteStream::from(slice.to_vec())
    }
}

impl From<Vec<u8>> for ByteStream {
    fn from(vec: Vec<u8>) -> Self {
        ByteStream { data: vec }
    }
}

#[cfg(test)]
mod tests {
    use super::{ByteStream, ByteStreamError};

    #[test]
    fn test_get_bytes_no_shift() {
        let bytes: Vec<u8> = vec![0xff, 0x00, 0x55];
        let bit_stream = ByteStream::from(bytes.clone());
        let new_bytes = bit_stream.get_bytes(0, bytes.len() * 8).unwrap();
        assert_eq!(bytes, new_bytes);
    }

    #[test]
    fn test_get_bytes_with_shift_in_byte() {
        let bit_stream = ByteStream::from(vec![0x5f, 0x00, 0x55]);
        let new_bytes = bit_stream.get_bytes(4, 4).unwrap();
        assert_eq!(vec![0x05], new_bytes);
    }

    #[test]
    fn test_get_bytes_with_shift_across_bytes() {
        let bit_stream = ByteStream::from(vec![0xff, 0x55]);
        let new_bytes = bit_stream.get_bytes(4, 8).unwrap();
        assert_eq!(vec![0x5f], new_bytes);
    }

    #[test]
    fn test_get_bytes_with_shift_across_bytes_odd() {
        let bit_stream = ByteStream::from(vec![0xff, 0x55]);
        let new_bytes = bit_stream.get_bytes(7, 2).unwrap();
        assert_eq!(vec![0x03], new_bytes);
    }

    #[test]
    fn test_get_bytes_one_byte() {
        let bit_stream = ByteStream::from(vec![0xff, 0x55]);
        let new_bytes = bit_stream.get_bytes(8, 8).unwrap();
        assert_eq!(vec![0x55], new_bytes);
    }

    #[test]
    fn test_get_bytes_3_bits() {
        let bit_stream = ByteStream::from(vec![0x55]);
        let new_bytes = bit_stream.get_bytes(0, 3).unwrap();
        assert_eq!(vec![0x05], new_bytes);
    }

    #[test]
    fn test_get_bytes_keeps_high_zero_bytes() {
        let bit_stream = ByteStream::from(vec![0xff, 0x00, 0x00]);
        assert_eq!(vec![0xff, 0x00, 0x00], bit_stream.get_bytes(0, 24).unwrap());
        assert_eq!(vec![0xff, 0x00], bit_stream.get_bytes(0, 9).unwrap());
    }

    #[test]
    fn test_get_bytes_unaligned_range_that_fits() {
        let bit_stream = ByteStream::from(vec![0xff, 0x03]);
        assert_eq!(vec![0xff, 0x01], bit_stream.get_bytes(1, 9).unwrap());
    }

    #[test]
    fn test_get_bytes_out_of_range() {
        let bit_stream = ByteStream::from(vec![0xff]);
        assert_eq!(Err(ByteStreamError::OutOfRange), bit_stream.get_bytes(4, 8));
        assert_eq!(Err(ByteStreamError::OutOfRange), bit_stream.get_bytes(usize::MAX, 1));
    }
}