Skip to content

Commit

Permalink
Use const generics to remove BitTree heap allocations
Browse files Browse the repository at this point in the history
The single-argument that BitTree takes is 1 << NUM_BITS (2 ** NUM_BITS)
for the number of bits required in the tree.

This is due to restrictions on const generic expressions.
The validity of this argument is checked at compile-time with a macro
that confirms that the argument P passed is indeed 1 << N for
some N using usize::trailing_zeros to calculate floor(log_2(P)).

Thus, BitTree<const P: usize> is only valid for any P such that
P = 2 ** floor(log_2(P)), where P is the length of the probability array
of the BitTree. This maintains the invariant that P = 1 << N.
  • Loading branch information
chyyran authored and gendx committed Mar 16, 2023
1 parent da82bd1 commit 81cee7d
Show file tree
Hide file tree
Showing 7 changed files with 197 additions and 115 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
- stable
- beta
- nightly
- 1.50.0 # MSRV
- 1.57.0 # MSRV
fail-fast: false
runs-on: ${{ matrix.os }}
env:
Expand Down
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ env_logger = { version = "0.9.0", optional = true }

[dev-dependencies]
rust-lzma = "0.5"
seq-macro = "0.3"

[features]
enable_logging = ["env_logger", "log"]
Expand Down
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
[![Documentation](https://docs.rs/lzma-rs/badge.svg)](https://docs.rs/lzma-rs)
[![Safety Dance](https://img.shields.io/badge/unsafe-forbidden-success.svg)](https://github.com/rust-secure-code/safety-dance/)
![Build Status](https://github.com/gendx/lzma-rs/workflows/Build%20and%20run%20tests/badge.svg)
[![Minimum rust 1.50](https://img.shields.io/badge/rust-1.50%2B-orange.svg)](https://github.com/rust-lang/rust/blob/master/RELEASES.md#version-1500-2021-02-11)
[![Codecov](https://codecov.io/gh/gendx/lzma-rs/branch/master/graph/badge.svg?token=HVo74E0wzh)](https://codecov.io/gh/gendx/lzma-rs)
[![Minimum rust 1.57](https://img.shields.io/badge/rust-1.57%2B-orange.svg)](https://github.com/rust-lang/rust/blob/master/RELEASES.md#version-1510-2021-03-25)

This project is a decoder for LZMA and its variants written in pure Rust, with focus on clarity.
It already supports LZMA, LZMA2 and a subset of the `.xz` file format.
Expand Down
27 changes: 16 additions & 11 deletions src/decode/lzma.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,8 @@ pub(crate) struct DecoderState {
pub(crate) lzma_props: LzmaProperties,
unpacked_size: Option<u64>,
literal_probs: Vec2D<u16>,
pos_slot_decoder: [BitTree; 4],
align_decoder: BitTree,
pos_slot_decoder: [BitTree<{ 1 << 6 }>; 4],
align_decoder: BitTree<{ 1 << 4 }>,
pos_decoders: [u16; 115],
is_match: [u16; 192], // true = LZ, false = literal
is_rep: [u16; 12],
Expand All @@ -191,12 +191,12 @@ impl DecoderState {
unpacked_size,
literal_probs: Vec2D::init(0x400, (1 << (lzma_props.lc + lzma_props.lp), 0x300)),
pos_slot_decoder: [
BitTree::new(6),
BitTree::new(6),
BitTree::new(6),
BitTree::new(6),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
],
align_decoder: BitTree::new(4),
align_decoder: BitTree::new(),
pos_decoders: [0x400; 115],
is_match: [0x400; 192],
is_rep: [0x400; 12],
Expand All @@ -222,11 +222,16 @@ impl DecoderState {
}

self.lzma_props = new_props;
self.pos_slot_decoder.iter_mut().for_each(|t| t.reset());
self.align_decoder.reset();
// For stack-allocated arrays, it was found to be faster to re-create new arrays
// dropping the existing one, rather than using `fill` to reset the contents to zero.
// Heap-based arrays use fill to keep their allocation rather than reallocate.
self.pos_slot_decoder = [
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
];
self.align_decoder = BitTree::new();
self.pos_decoders = [0x400; 115];
self.is_match = [0x400; 192];
self.is_rep = [0x400; 12];
Expand All @@ -236,8 +241,8 @@ impl DecoderState {
self.is_rep_0long = [0x400; 192];
self.state = 0;
self.rep = [0; 4];
self.len_decoder.reset();
self.rep_len_decoder.reset();
self.len_decoder = LenDecoder::new();
self.rep_len_decoder = LenDecoder::new();
}

pub fn set_unpacked_size(&mut self, unpacked_size: Option<u64>) {
Expand Down
120 changes: 62 additions & 58 deletions src/decode/rangecoder.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::decode::util;
use crate::error;
use crate::util::const_assert;
use byteorder::{BigEndian, ReadBytesExt};
use std::io;

Expand Down Expand Up @@ -150,49 +151,60 @@ where
}
}

// TODO: parametrize by constant and use [u16; 1 << num_bits] as soon as Rust supports this
#[derive(Debug, Clone)]
pub struct BitTree {
num_bits: usize,
probs: Vec<u16>,
pub struct BitTree<const PROBS_ARRAY_LEN: usize> {
probs: [u16; PROBS_ARRAY_LEN],
}

impl BitTree {
pub fn new(num_bits: usize) -> Self {
impl<const PROBS_ARRAY_LEN: usize> BitTree<PROBS_ARRAY_LEN> {
pub fn new() -> Self {
// The validity of PROBS_ARRAY_LEN is checked at compile-time with a macro
// that confirms that the argument P passed is indeed 1 << N for
// some N using usize::trailing_zeros to calculate floor(log_2(P)).
//
// Thus, BitTree<const P: usize> is only valid for any P such that
// P = 2 ** floor(log_2(P)), where P is the length of the probability array
// of the BitTree. This maintains the invariant that P = 1 << N.
//
// This precondition must be checked for any way to construct a new, valid instance of BitTree.
// Here it is checked for BitTree::new(), but if another function is added that returns a
// new instance of BitTree, this assertion must be checked there as well.
const_assert!("BitTree's PROBS_ARRAY_LEN parameter must be a power of 2",
PROBS_ARRAY_LEN: usize => (1 << (PROBS_ARRAY_LEN.trailing_zeros() as usize)) == PROBS_ARRAY_LEN);
BitTree {
num_bits,
probs: vec![0x400; 1 << num_bits],
probs: [0x400; PROBS_ARRAY_LEN],
}
}

// NUM_BITS is derived from PROBS_ARRAY_LEN because of the lack of
// generic const expressions. Where PROBS_ARRAY_LEN is a power of 2,
// NUM_BITS can be derived by the number of trailing zeroes.
const NUM_BITS: usize = PROBS_ARRAY_LEN.trailing_zeros() as usize;

pub fn parse<R: io::BufRead>(
&mut self,
rangecoder: &mut RangeDecoder<R>,
update: bool,
) -> io::Result<u32> {
rangecoder.parse_bit_tree(self.num_bits, self.probs.as_mut_slice(), update)
rangecoder.parse_bit_tree(Self::NUM_BITS, &mut self.probs, update)
}

pub fn parse_reverse<R: io::BufRead>(
&mut self,
rangecoder: &mut RangeDecoder<R>,
update: bool,
) -> io::Result<u32> {
rangecoder.parse_reverse_bit_tree(self.num_bits, self.probs.as_mut_slice(), 0, update)
}

pub fn reset(&mut self) {
self.probs.fill(0x400);
rangecoder.parse_reverse_bit_tree(Self::NUM_BITS, &mut self.probs, 0, update)
}
}

#[derive(Debug)]
pub struct LenDecoder {
choice: u16,
choice2: u16,
low_coder: [BitTree; 16],
mid_coder: [BitTree; 16],
high_coder: BitTree,
low_coder: [BitTree<{ 1 << 3 }>; 16],
mid_coder: [BitTree<{ 1 << 3 }>; 16],
high_coder: BitTree<{ 1 << 8 }>,
}

impl LenDecoder {
Expand All @@ -201,42 +213,42 @@ impl LenDecoder {
choice: 0x400,
choice2: 0x400,
low_coder: [
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
],
mid_coder: [
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(3),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
BitTree::new(),
],
high_coder: BitTree::new(8),
high_coder: BitTree::new(),
}
}

Expand All @@ -254,12 +266,4 @@ impl LenDecoder {
Ok(self.high_coder.parse(rangecoder, update)? as usize + 16)
}
}

pub fn reset(&mut self) {
self.choice = 0x400;
self.choice2 = 0x400;
self.low_coder.iter_mut().for_each(|t| t.reset());
self.mid_coder.iter_mut().for_each(|t| t.reset());
self.high_coder.reset();
}
}

0 comments on commit 81cee7d

Please sign in to comment.