Skip to content

Commit

Permalink
Alternative API with only a single generic argument for BitTree.
Browse files Browse the repository at this point in the history
The single-argument that BitTree takes is 1 << NUM_BITS (2 ** NUM_BITS)
for the number of bits required in the tree.

This is due to restrictions on const generic expressions.
The validity of this argument is checked at compile-time with a macro
that confirms that the argument P passed is indeed 1 << N for
some N using usize::trailing_zeros to calculate floor(log_2(P)).

Thus, BitTree<const P: usize> is only valid for any P such that
P = 2 ** floor(log_2(P)), where P is the length of the probability array
of the BitTree. This maintains the invariant that P = 1 << N.
  • Loading branch information
chyyran committed Sep 3, 2022
1 parent 08b6794 commit e211b5d
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 32 deletions.
4 changes: 2 additions & 2 deletions src/decode/lzma.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,8 @@ pub(crate) struct DecoderState {
pub(crate) lzma_props: LzmaProperties,
unpacked_size: Option<u64>,
literal_probs: Vec2D<u16>,
pos_slot_decoder: [BitTree<6, { 1 << 6 }>; 4],
align_decoder: BitTree<4, { 1 << 4 }>,
pos_slot_decoder: [BitTree<{ 1 << 6 }>; 4],
align_decoder: BitTree<{ 1 << 4 }>,
pos_decoders: [u16; 115],
is_match: [u16; 192], // true = LZ, false = literal
is_rep: [u16; 12],
Expand Down
25 changes: 17 additions & 8 deletions src/decode/rangecoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,42 +152,51 @@ where
}

#[derive(Debug, Clone)]
pub struct BitTree<const NUM_BITS: usize, const PROBS_ARRAY_LEN: usize> {
pub struct BitTree<const PROBS_ARRAY_LEN: usize> {
probs: [u16; PROBS_ARRAY_LEN],
}

impl<const NUM_BITS: usize, const PROBS_ARRAY_LEN: usize> BitTree<NUM_BITS, PROBS_ARRAY_LEN> {
impl<const PROBS_ARRAY_LEN: usize> BitTree<PROBS_ARRAY_LEN> {
pub fn new() -> Self {
const_assert!(NUM_BITS: usize, PROBS_ARRAY_LEN: usize => PROBS_ARRAY_LEN == 1 << NUM_BITS);
// The validity of PROBS_ARRAY_LEN is checked at compile-time with a macro
// that confirms that the argument P passed is indeed 1 << N for
// some N using usize::trailing_zeros to calculate floor(log_2(P)).
//
// Thus, BitTree<const P: usize> is only valid for any P such that
// P = 2 ** floor(log_2(P)), where P is the length of the probability array
// of the BitTree. This maintains the invariant that P = 1 << N.
const_assert!(PROBS_ARRAY_LEN: usize => (1 << (PROBS_ARRAY_LEN.trailing_zeros() as usize)) == PROBS_ARRAY_LEN);
BitTree {
probs: [0x400; PROBS_ARRAY_LEN],
}
}

const NUM_BITS: usize = PROBS_ARRAY_LEN.trailing_zeros() as usize;

pub fn parse<R: io::BufRead>(
&mut self,
rangecoder: &mut RangeDecoder<R>,
update: bool,
) -> io::Result<u32> {
rangecoder.parse_bit_tree(NUM_BITS, &mut self.probs, update)
rangecoder.parse_bit_tree(Self::NUM_BITS, &mut self.probs, update)
}

pub fn parse_reverse<R: io::BufRead>(
&mut self,
rangecoder: &mut RangeDecoder<R>,
update: bool,
) -> io::Result<u32> {
rangecoder.parse_reverse_bit_tree(NUM_BITS, &mut self.probs, 0, update)
rangecoder.parse_reverse_bit_tree(Self::NUM_BITS, &mut self.probs, 0, update)
}
}

#[derive(Debug)]
pub struct LenDecoder {
choice: u16,
choice2: u16,
low_coder: [BitTree<3, { 1 << 3 }>; 16],
mid_coder: [BitTree<3, { 1 << 3 }>; 16],
high_coder: BitTree<8, { 1 << 8 }>,
low_coder: [BitTree<{ 1 << 3 }>; 16],
mid_coder: [BitTree<{ 1 << 3 }>; 16],
high_coder: BitTree<{ 1 << 8 }>,
}

impl LenDecoder {
Expand Down
51 changes: 29 additions & 22 deletions src/encode/rangecoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,43 +145,52 @@ where

#[cfg(test)]
#[derive(Debug, Clone)]
pub struct BitTree<const NUM_BITS: usize, const PROBS_ARRAY_LEN: usize> {
pub struct BitTree<const PROBS_ARRAY_LEN: usize> {
probs: [u16; PROBS_ARRAY_LEN],
}

#[cfg(test)]
impl<const NUM_BITS: usize, const PROBS_ARRAY_LEN: usize> BitTree<NUM_BITS, PROBS_ARRAY_LEN> {
impl<const PROBS_ARRAY_LEN: usize> BitTree<PROBS_ARRAY_LEN> {
pub fn new() -> Self {
const_assert!(NUM_BITS: usize, PROBS_ARRAY_LEN: usize => PROBS_ARRAY_LEN == 1 << NUM_BITS);
// The validity of PROBS_ARRAY_LEN is checked at compile-time with a macro
// that confirms that the argument P passed is indeed 1 << N for
// some N using usize::trailing_zeros to calculate floor(log_2(P)).
//
// Thus, BitTree<const P: usize> is only valid for any P such that
// P = 2 ** floor(log_2(P)), where P is the length of the probability array
// of the BitTree. This maintains the invariant that P = 1 << N.
const_assert!(PROBS_ARRAY_LEN: usize => (1 << (PROBS_ARRAY_LEN.trailing_zeros() as usize)) == PROBS_ARRAY_LEN);
BitTree {
probs: [0x400; PROBS_ARRAY_LEN],
}
}

const NUM_BITS: usize = PROBS_ARRAY_LEN.trailing_zeros() as usize;

pub fn encode<W: io::Write>(
&mut self,
rangecoder: &mut RangeEncoder<W>,
value: u32,
) -> io::Result<()> {
rangecoder.encode_bit_tree(NUM_BITS, self.probs.as_mut_slice(), value)
rangecoder.encode_bit_tree(Self::NUM_BITS, self.probs.as_mut_slice(), value)
}

pub fn encode_reverse<W: io::Write>(
&mut self,
rangecoder: &mut RangeEncoder<W>,
value: u32,
) -> io::Result<()> {
rangecoder.encode_reverse_bit_tree(NUM_BITS, self.probs.as_mut_slice(), 0, value)
rangecoder.encode_reverse_bit_tree(Self::NUM_BITS, self.probs.as_mut_slice(), 0, value)
}
}

#[cfg(test)]
pub struct LenEncoder {
choice: u16,
choice2: u16,
low_coder: [BitTree<3, { 1 << 3 }>; 16],
mid_coder: [BitTree<3, { 1 << 3 }>; 16],
high_coder: BitTree<8, { 1 << 8 }>,
low_coder: [BitTree<{ 1 << 3 }>; 16],
mid_coder: [BitTree<{ 1 << 3 }>; 16],
high_coder: BitTree<{ 1 << 8 }>,
}

#[cfg(test)]
Expand Down Expand Up @@ -289,19 +298,19 @@ mod test {
encode_decode(0x400, &[true; 10000]);
}

fn encode_decode_bittree<const NUM_BITS: usize, const PROBS_LEN: usize>(values: &[u32]) {
fn encode_decode_bittree<const PROBS_LEN: usize>(values: &[u32]) {
let mut buf: Vec<u8> = Vec::new();

let mut encoder = RangeEncoder::new(&mut buf);
let mut tree = encode::rangecoder::BitTree::<NUM_BITS, PROBS_LEN>::new();
let mut tree = encode::rangecoder::BitTree::<PROBS_LEN>::new();
for &v in values {
tree.encode(&mut encoder, v).unwrap();
}
encoder.finish().unwrap();

let mut bufread = BufReader::new(buf.as_slice());
let mut decoder = RangeDecoder::new(&mut bufread).unwrap();
let mut tree = decode::rangecoder::BitTree::<NUM_BITS, PROBS_LEN>::new();
let mut tree = decode::rangecoder::BitTree::<PROBS_LEN>::new();
for &v in values {
assert_eq!(tree.parse(&mut decoder, true).unwrap(), v);
}
Expand All @@ -311,15 +320,15 @@ mod test {
#[test]
fn test_encode_decode_bittree_zeros() {
seq!(NUM_BITS in 0..16 {
encode_decode_bittree::<NUM_BITS, {1 << NUM_BITS}>
encode_decode_bittree::<{1 << NUM_BITS}>
(&[0; 10000]);
});
}

#[test]
fn test_encode_decode_bittree_ones() {
seq!(NUM_BITS in 0..16 {
encode_decode_bittree::<NUM_BITS, {1 << NUM_BITS}>
encode_decode_bittree::<{1 << NUM_BITS}>
(&[(1 << NUM_BITS) - 1; 10000]);
});
}
Expand All @@ -329,26 +338,24 @@ mod test {
seq!(NUM_BITS in 0..16 {
let max = 1 << NUM_BITS;
let values: Vec<u32> = (0..max).collect();
encode_decode_bittree::<NUM_BITS, {1 << NUM_BITS}>
encode_decode_bittree::<{1 << NUM_BITS}>
(&values);
});
}

fn encode_decode_reverse_bittree<const NUM_BITS: usize, const PROBS_LEN: usize>(
values: &[u32],
) {
fn encode_decode_reverse_bittree<const PROBS_LEN: usize>(values: &[u32]) {
let mut buf: Vec<u8> = Vec::new();

let mut encoder = RangeEncoder::new(&mut buf);
let mut tree = encode::rangecoder::BitTree::<NUM_BITS, PROBS_LEN>::new();
let mut tree = encode::rangecoder::BitTree::<PROBS_LEN>::new();
for &v in values {
tree.encode_reverse(&mut encoder, v).unwrap();
}
encoder.finish().unwrap();

let mut bufread = BufReader::new(buf.as_slice());
let mut decoder = RangeDecoder::new(&mut bufread).unwrap();
let mut tree = decode::rangecoder::BitTree::<NUM_BITS, PROBS_LEN>::new();
let mut tree = decode::rangecoder::BitTree::<PROBS_LEN>::new();
for &v in values {
assert_eq!(tree.parse_reverse(&mut decoder, true).unwrap(), v);
}
Expand All @@ -358,15 +365,15 @@ mod test {
#[test]
fn test_encode_decode_reverse_bittree_zeros() {
seq!(NUM_BITS in 0..16 {
encode_decode_reverse_bittree::<NUM_BITS, {1 << NUM_BITS}>
encode_decode_reverse_bittree::<{1 << NUM_BITS}>
(&[0; 10000]);
});
}

#[test]
fn test_encode_decode_reverse_bittree_ones() {
seq!(NUM_BITS in 0..16 {
encode_decode_reverse_bittree::<NUM_BITS, {1 << NUM_BITS}>
encode_decode_reverse_bittree::<{1 << NUM_BITS}>
(&[(1 << NUM_BITS) - 1; 10000]);
});
}
Expand All @@ -376,7 +383,7 @@ mod test {
seq!(NUM_BITS in 0..16 {
let max = 1 << NUM_BITS;
let values: Vec<u32> = (0..max).collect();
encode_decode_reverse_bittree::<NUM_BITS, {1 << NUM_BITS}>
encode_decode_reverse_bittree::<{1 << NUM_BITS}>
(&values);
});
}
Expand Down

0 comments on commit e211b5d

Please sign in to comment.