Skip to content

Commit

Permalink
Use const generics to remove BitTree heap allocations
Browse files Browse the repository at this point in the history
  • Loading branch information
chyyran committed Aug 9, 2022
1 parent a010cc0 commit 02ecc33
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 97 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ env_logger = { version = "^0.8.3", optional = true }

[dev-dependencies]
rust-lzma = "0.5"
seq-macro = "0.3"

[features]
enable_logging = ["env_logger", "log"]
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
[![Documentation](https://docs.rs/lzma-rs/badge.svg)](https://docs.rs/lzma-rs)
[![Safety Dance](https://img.shields.io/badge/unsafe-forbidden-success.svg)](https://github.com/rust-secure-code/safety-dance/)
![Build Status](https://github.com/gendx/lzma-rs/workflows/Build%20and%20run%20tests/badge.svg)
[![Minimum rust 1.50](https://img.shields.io/badge/rust-1.50%2B-orange.svg)](https://github.com/rust-lang/rust/blob/master/RELEASES.md#version-1500-2021-02-11)
[![Minimum rust 1.51](https://img.shields.io/badge/rust-1.51%2B-orange.svg)](https://github.com/rust-lang/rust/blob/master/RELEASES.md#version-1510-2021-03-25)

This project is a decoder for LZMA and its variants written in pure Rust, with focus on clarity.
It already supports LZMA, LZMA2 and a subset of the `.xz` file format.
Expand Down
29 changes: 17 additions & 12 deletions src/decode/lzma.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::decode::lzbuffer::{LzBuffer, LzCircularBuffer};
use crate::decode::rangecoder::{BitTree, LenDecoder, RangeDecoder};
use crate::decode::rangecoder::{bittree_probs_len, BitTree, LenDecoder, RangeDecoder};
use crate::decompress::{Options, UnpackedSize};
use crate::error;
use crate::util::vec2d::Vec2D;
Expand Down Expand Up @@ -167,8 +167,8 @@ pub(crate) struct DecoderState {
pub(crate) lzma_props: LzmaProperties,
unpacked_size: Option<u64>,
literal_probs: Vec2D<u16>,
pos_slot_decoder: [BitTree; 4],
align_decoder: BitTree,
pos_slot_decoder: [BitTree<6, { bittree_probs_len::<6>() }>; 4],
align_decoder: BitTree<4, { bittree_probs_len::<4>() }>,
pos_decoders: [u16; 115],
is_match: [u16; 192], // true = LZ, false = literal
is_rep: [u16; 12],
Expand All @@ -191,12 +191,12 @@ impl DecoderState {
unpacked_size,
literal_probs: Vec2D::init(0x400, (1 << (lzma_props.lc + lzma_props.lp), 0x300)),
pos_slot_decoder: [
BitTree::new(6),
BitTree::new(6),
BitTree::new(6),
BitTree::new(6),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
],
align_decoder: BitTree::new(4),
align_decoder: BitTree::new(),
pos_decoders: [0x400; 115],
is_match: [0x400; 192],
is_rep: [0x400; 12],
Expand All @@ -222,8 +222,13 @@ impl DecoderState {
}

self.lzma_props = new_props;
self.pos_slot_decoder.iter_mut().for_each(|t| t.reset());
self.align_decoder.reset();
self.pos_slot_decoder = [
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
];
self.align_decoder = BitTree::new();
self.pos_decoders = [0x400; 115];
self.is_match = [0x400; 192];
self.is_rep = [0x400; 12];
Expand All @@ -233,8 +238,8 @@ impl DecoderState {
self.is_rep_0long = [0x400; 192];
self.state = 0;
self.rep = [0; 4];
self.len_decoder.reset();
self.rep_len_decoder.reset();
self.len_decoder = LenDecoder::new();
self.rep_len_decoder = LenDecoder::new();
}

pub fn set_unpacked_size(&mut self, unpacked_size: Option<u64>) {
Expand Down
121 changes: 63 additions & 58 deletions src/decode/rangecoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,18 +150,35 @@ where
}
}

// TODO: parametrize by constant and use [u16; 1 << num_bits] as soon as Rust supports this
/// Compile-time assertion macro.
///
/// Both forms rely on the same trick: `!(cond) as u8` is `0` when `cond` holds and `1` when
/// it does not, so `0 - 1` underflows `u8` during constant evaluation and the build fails.
///
/// * `const_assert!(A: usize, B: usize => cond)` checks a condition over const generic
///   parameters. The listed parameters are re-declared on a private helper struct, and the
///   macro expands to an expression that reads its associated constant (value `0`), which is
///   what forces the check to be evaluated for each concrete instantiation.
/// * `const_assert!(cond)` checks a plain constant condition and expands to an item.
macro_rules! const_assert {
    ($($list:ident : $ty:ty),* => $expr:expr) => {{
        struct Assert<$(const $list: $ty,)*>;
        impl<$(const $list: $ty,)*> Assert<$($list,)*> {
            // Underflows (compile error) when `$expr` is false.
            const OK: u8 = 0 - !($expr) as u8;
        }
        Assert::<$($list,)*>::OK
    }};
    ($expr:expr) => {
        // Unnamed constant: items produced by `macro_rules!` are not hygienic, so a named
        // constant would clash as soon as the macro is used twice in the same scope.
        const _: u8 = 0 - !($expr) as u8;
    };
}

/// Returns `1 << NUM_BITS`, the length of the probability array of a bit tree that is
/// parameterized by `NUM_BITS`.
///
/// Being a `const fn`, it can be used inside a const generic argument, e.g.
/// `BitTree<6, { bittree_probs_len::<6>() }>`.
pub const fn bittree_probs_len<const NUM_BITS: usize>() -> usize {
    let one: usize = 1;
    one << NUM_BITS
}

#[derive(Debug, Clone)]
pub struct BitTree {
num_bits: usize,
probs: Vec<u16>,
pub struct BitTree<const NUM_BITS: usize, const PROBS_ARRAY_LEN: usize> {
probs: [u16; PROBS_ARRAY_LEN],
}

impl BitTree {
pub fn new(num_bits: usize) -> Self {
impl<const NUM_BITS: usize, const PROBS_ARRAY_LEN: usize> BitTree<NUM_BITS, PROBS_ARRAY_LEN> {
pub fn new() -> Self {
const_assert!(NUM_BITS: usize, PROBS_ARRAY_LEN: usize => PROBS_ARRAY_LEN == bittree_probs_len::<NUM_BITS>());
BitTree {
num_bits,
probs: vec![0x400; 1 << num_bits],
probs: [0x400; PROBS_ARRAY_LEN],
}
}

Expand All @@ -170,29 +187,25 @@ impl BitTree {
rangecoder: &mut RangeDecoder<R>,
update: bool,
) -> io::Result<u32> {
rangecoder.parse_bit_tree(self.num_bits, self.probs.as_mut_slice(), update)
rangecoder.parse_bit_tree(NUM_BITS, &mut self.probs, update)
}

pub fn parse_reverse<R: io::BufRead>(
&mut self,
rangecoder: &mut RangeDecoder<R>,
update: bool,
) -> io::Result<u32> {
rangecoder.parse_reverse_bit_tree(self.num_bits, self.probs.as_mut_slice(), 0, update)
}

pub fn reset(&mut self) {
self.probs.fill(0x400);
rangecoder.parse_reverse_bit_tree(NUM_BITS, &mut self.probs, 0, update)
}
}

#[derive(Debug)]
pub struct LenDecoder {
choice: u16,
choice2: u16,
low_coder: [BitTree; 16],
mid_coder: [BitTree; 16],
high_coder: BitTree,
low_coder: [BitTree<3, { bittree_probs_len::<3>() }>; 16],
mid_coder: [BitTree<3, { bittree_probs_len::<3>() }>; 16],
high_coder: BitTree<8, { bittree_probs_len::<8>() }>,
}

impl LenDecoder {
Expand All @@ -201,42 +214,42 @@ impl LenDecoder {
choice: 0x400,
choice2: 0x400,
low_coder: [
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
],
mid_coder: [
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
],
high_coder: BitTree::new(8),
high_coder: BitTree::new(),
}
}

Expand All @@ -254,12 +267,4 @@ impl LenDecoder {
Ok(self.high_coder.parse(rangecoder, update)? as usize + 16)
}
}

/// Puts the length decoder back into its initial state: every low, mid and high bit tree
/// is reset and both choice probabilities are set back to `0x400`.
pub fn reset(&mut self) {
    let short_trees = self.low_coder.iter_mut().chain(self.mid_coder.iter_mut());
    for tree in short_trees {
        tree.reset();
    }
    self.high_coder.reset();
    self.choice = 0x400;
    self.choice2 = 0x400;
}
}
61 changes: 35 additions & 26 deletions src/encode/rangecoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ mod test {
use super::*;
use crate::decode::rangecoder::{LenDecoder, RangeDecoder};
use crate::{decode, encode};
use seq_macro::seq;
use std::io::BufReader;

fn encode_decode(prob_init: u16, bits: &[bool]) {
Expand Down Expand Up @@ -253,19 +254,19 @@ mod test {
encode_decode(0x400, &[true; 10000]);
}

fn encode_decode_bittree(num_bits: usize, values: &[u32]) {
fn encode_decode_bittree<const NUM_BITS: usize, const PROBS_LEN: usize>(values: &[u32]) {
let mut buf: Vec<u8> = Vec::new();

let mut encoder = RangeEncoder::new(&mut buf);
let mut tree = encode::rangecoder::BitTree::new(num_bits);
let mut tree = encode::rangecoder::BitTree::new(NUM_BITS);
for &v in values {
tree.encode(&mut encoder, v).unwrap();
}
encoder.finish().unwrap();

let mut bufread = BufReader::new(buf.as_slice());
let mut decoder = RangeDecoder::new(&mut bufread).unwrap();
let mut tree = decode::rangecoder::BitTree::new(num_bits);
let mut tree = decode::rangecoder::BitTree::<NUM_BITS, PROBS_LEN>::new();
for &v in values {
assert_eq!(tree.parse(&mut decoder, true).unwrap(), v);
}
Expand All @@ -274,40 +275,45 @@ mod test {

#[test]
fn test_encode_decode_bittree_zeros() {
for num_bits in 0..16 {
encode_decode_bittree(num_bits, &[0; 10000]);
}
seq!(NUM_BITS in 0..16 {
encode_decode_bittree::<NUM_BITS, {decode::rangecoder::bittree_probs_len::<NUM_BITS>()}>
(&[0; 10000]);
});
}

#[test]
fn test_encode_decode_bittree_ones() {
for num_bits in 0..16 {
encode_decode_bittree(num_bits, &[(1 << num_bits) - 1; 10000]);
}
seq!(NUM_BITS in 0..16 {
encode_decode_bittree::<NUM_BITS, {decode::rangecoder::bittree_probs_len::<NUM_BITS>()}>
(&[(1 << NUM_BITS) - 1; 10000]);
});
}

#[test]
fn test_encode_decode_bittree_all() {
for num_bits in 0..16 {
let max = 1 << num_bits;
seq!(NUM_BITS in 0..16 {
let max = 1 << NUM_BITS;
let values: Vec<u32> = (0..max).collect();
encode_decode_bittree(num_bits, &values);
}
encode_decode_bittree::<NUM_BITS, {decode::rangecoder::bittree_probs_len::<NUM_BITS>()}>
(&values);
});
}

fn encode_decode_reverse_bittree(num_bits: usize, values: &[u32]) {
fn encode_decode_reverse_bittree<const NUM_BITS: usize, const PROBS_LEN: usize>(
values: &[u32],
) {
let mut buf: Vec<u8> = Vec::new();

let mut encoder = RangeEncoder::new(&mut buf);
let mut tree = encode::rangecoder::BitTree::new(num_bits);
let mut tree = encode::rangecoder::BitTree::new(NUM_BITS);
for &v in values {
tree.encode_reverse(&mut encoder, v).unwrap();
}
encoder.finish().unwrap();

let mut bufread = BufReader::new(buf.as_slice());
let mut decoder = RangeDecoder::new(&mut bufread).unwrap();
let mut tree = decode::rangecoder::BitTree::new(num_bits);
let mut tree = decode::rangecoder::BitTree::<NUM_BITS, PROBS_LEN>::new();
for &v in values {
assert_eq!(tree.parse_reverse(&mut decoder, true).unwrap(), v);
}
Expand All @@ -316,25 +322,28 @@ mod test {

#[test]
fn test_encode_decode_reverse_bittree_zeros() {
for num_bits in 0..16 {
encode_decode_reverse_bittree(num_bits, &[0; 10000]);
}
seq!(NUM_BITS in 0..16 {
encode_decode_reverse_bittree::<NUM_BITS, {decode::rangecoder::bittree_probs_len::<NUM_BITS>()}>
(&[0; 10000]);
});
}

#[test]
fn test_encode_decode_reverse_bittree_ones() {
for num_bits in 0..16 {
encode_decode_reverse_bittree(num_bits, &[(1 << num_bits) - 1; 10000]);
}
seq!(NUM_BITS in 0..16 {
encode_decode_reverse_bittree::<NUM_BITS, {decode::rangecoder::bittree_probs_len::<NUM_BITS>()}>
(&[(1 << NUM_BITS) - 1; 10000]);
});
}

#[test]
fn test_encode_decode_reverse_bittree_all() {
for num_bits in 0..16 {
let max = 1 << num_bits;
seq!(NUM_BITS in 0..16 {
let max = 1 << NUM_BITS;
let values: Vec<u32> = (0..max).collect();
encode_decode_reverse_bittree(num_bits, &values);
}
encode_decode_reverse_bittree::<NUM_BITS, {decode::rangecoder::bittree_probs_len::<NUM_BITS>()}>
(&values);
});
}

fn encode_decode_length(pos_state: usize, values: &[u32]) {
Expand Down

0 comments on commit 02ecc33

Please sign in to comment.