Skip to content

Commit

Permalink
modules/zstd: Add proc for generating symbol order for FSE lookup
Browse files Browse the repository at this point in the history
Internal-tag: [#57353]
Signed-off-by: Robert Winkler <rwinkler@antmicro.com>
  • Loading branch information
rw1nkler committed Apr 4, 2024
1 parent 2d52296 commit 6c3b033
Show file tree
Hide file tree
Showing 3 changed files with 171 additions and 0 deletions.
14 changes: 14 additions & 0 deletions xls/modules/zstd/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1171,3 +1171,17 @@ xls_dslx_test(
dslx_test_args = {"compare": "none"},
library = ":fse_ml_proba_provider_dslx",
)

xls_dslx_library(
name = "fse_table_iterator_dslx",
srcs = ["fse_table_iterator.x"],
deps = [
":fse_common_dslx",
],
)

xls_dslx_test(
name = "fse_table_iterator_dslx_test",
dslx_test_args = {"compare": "none"},
library = ":fse_table_iterator_dslx",
)
18 changes: 18 additions & 0 deletions xls/modules/zstd/build_defs.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
load("//xls/build_rules:xls_dslx_rules.bzl", "xls_dslx_library")
load("@bazel_skylib//rules:expand_template.bzl", "expand_template")

def dslx_library_from_template(name, template, substitutions, **kwargs):
expanded_file = name + ".x"

expand_template(
name = name + "_dslx",
template = template,
substitutions = substitutions,
out = expanded_file
)

xls_dslx_library(
name = name,
srcs = [expanded_file],
**kwargs
)
139 changes: 139 additions & 0 deletions xls/modules/zstd/fse_table_iterator.x
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
// Copyright 2024 The XLS Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import std;
import xls.modules.zstd.fse_common as fse;

pub struct FSETableCreatorCtrl {
accuracy_log: u32,
negative_proba_count: u32
}

type Reset = bool;
type Index = u32;
type Ctrl = FSETableCreatorCtrl;

enum Status : u1 {
CONFIGURE = 0,
SEND = 1,
}

struct State {
status: Status,
ctrl: Ctrl,
cnt: u32,
pos: u32
}

proc FSETableIterator {
ctrl_r: chan<Ctrl> in;
idx_s: chan<Index> out;

config(
ctrl_r: chan<Ctrl> in,
idx_s: chan<Index> out
) { (ctrl_r, idx_s) }

init { zero!<State>() }

next(tok0: token, state: State) {
const ZERO_STATE = zero!<State>();
const ZERO_IDX_OPTION = (false, u32:0);

let do_recv_ctrl = state.status == Status::CONFIGURE;
let (tok1, ctrl) = recv_if(tok0, ctrl_r, do_recv_ctrl, zero!<Ctrl>());
if do_recv_ctrl {
trace_fmt!("[IO]: Received ctrl: {}", ctrl);
} else { };

let (idx_option, new_state) = match (state.status) {
Status::CONFIGURE => {
trace_fmt!("[STATE]: CONFIGURE");
((true, u32:0), State { ctrl, status: Status::SEND, ..ZERO_STATE })
},
Status::SEND => {
trace_fmt!("[STATE]: SEND");

let size = u32:1 << state.ctrl.accuracy_log;
let high_threshold = size - state.ctrl.negative_proba_count;
let step = (size >> 1) + (size >> 3) + u32:3;
let mask = size - u32:1;

trace_fmt!("[ITERATOR]: size: {}", size);
trace_fmt!("[ITERATOR]: high_threshold: {}", high_threshold);
trace_fmt!("[ITERATOR]: step: {}", step);
trace_fmt!("[ITERATOR]: mask: {}", mask);

let pos = (state.pos + step) & mask;

let valid = pos < high_threshold;
let next_cnt = state.cnt + u32:1;
let last = (valid && (next_cnt == high_threshold - u32:1));

if last {
((true, pos), ZERO_STATE)
} else if valid {
((true, pos), State { cnt: next_cnt, pos, ..state })
} else {
(ZERO_IDX_OPTION, State { cnt: state.cnt, pos, ..state })
}
},
_ => fail!("incorrect_state", (ZERO_IDX_OPTION, ZERO_STATE)),
};

let (do_send_idx, idx) = idx_option;
let tok2 = send_if(tok1, idx_s, do_send_idx, idx);
if do_send_idx { trace_fmt!("[IO]: Send index: {}", idx); } else { };

new_state
}
}

const TEST_EXPECTRED_IDX = u32[27]:[
u32:0, u32:23, u32:14, u32:5, u32:19, u32:10, u32:1, u32:24, u32:15, u32:6, u32:20, u32:11,
u32:2, u32:25, u32:16, u32:7, u32:21, u32:12, u32:3, u32:26, u32:17, u32:8, u32:22, u32:13,
u32:4, u32:18, u32:9,
];

#[test_proc]
proc FSETableIteratorTest {
terminator: chan<bool> out;
ctrl_s: chan<Ctrl> out;
idx_r: chan<Index> in;

config(terminator: chan<bool> out) {
let (ctrl_s, ctrl_r) = chan<Ctrl>;
let (idx_s, idx_r) = chan<Index>;

spawn FSETableIterator(ctrl_r, idx_s);
(terminator, ctrl_s, idx_r)
}

init { }

next(tok: token, state: ()) {
let tok = send(tok, ctrl_s, Ctrl {
accuracy_log: u32:5,
negative_proba_count: u32:5
});

let tok = for (exp_idx, tok): (Index, token) in TEST_EXPECTRED_IDX {
let (tok, idx) = recv(tok, idx_r);
assert_eq(idx, exp_idx);
(tok)
}(tok);

send(tok, terminator, true);
}
}

0 comments on commit 6c3b033

Please sign in to comment.