Skip to content

Commit

Permalink
Adding aarch64 neon support.
Browse files Browse the repository at this point in the history
  • Loading branch information
bitshifter committed May 15, 2024
1 parent 29413fa commit ec1a942
Show file tree
Hide file tree
Showing 66 changed files with 11,906 additions and 5,158 deletions.
18 changes: 18 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,21 @@ jobs:

- run: ./build_and_test_wasm32_firefox.sh
- run: ./build_and_test_wasm32_chrome.sh

# The macos-latest runner is aarch64, so testing via `cross` shouldn't be necessary anymore.
# test-arm:
# name: Test Arm
# runs-on: ubuntu-latest
# strategy:
# fail-fast: false
# matrix:
# target:
# - aarch64-unknown-linux-gnu
# - arm-unknown-linux-gnueabi
# steps:
# - uses: actions/checkout@v4
# - run: rustup update --no-self-update stable
# - run: rustup default stable
# - run: rustup target add --toolchain stable ${{matrix.target}}
# - uses: taiki-e/install-action@cross
# - run: cross test --target ${{matrix.target}}
3 changes: 3 additions & 0 deletions codegen/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ fn main() -> anyhow::Result<()> {

let full_output_path = workdir.join(output_path);

let output_dir = full_output_path.parent().unwrap();
std::fs::create_dir_all(output_dir)?;

if check {
match std::fs::read_to_string(&full_output_path) {
Ok(original_str) => {
Expand Down
43 changes: 43 additions & 0 deletions codegen/src/outputs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ enum Target {
Scalar,
Sse2,
Wasm32,
Neon,
CoreSimd,
}

Expand Down Expand Up @@ -357,6 +358,7 @@ impl ContextBuilder {
self.0.insert("is_sse2", &(target == Target::Sse2));
self.0.insert("is_coresimd", &(target == Target::CoreSimd));
self.0.insert("is_wasm32", &(target == Target::Wasm32));
self.0.insert("is_neon", &(target == Target::Neon));
self.0.insert("is_scalar", &(target == Target::Scalar));
self
}
Expand All @@ -365,6 +367,10 @@ impl ContextBuilder {
self.with_target(Target::Sse2)
}

/// Consumes the builder and returns it configured for the NEON target.
///
/// Thin convenience wrapper that forwards `Target::Neon` to `with_target`.
pub fn target_neon(self) -> Self {
    let target = Target::Neon;
    self.with_target(target)
}

/// Consumes the builder and returns it configured for the wasm32 target.
///
/// Thin convenience wrapper that forwards `Target::Wasm32` to `with_target`.
pub fn target_wasm32(self) -> Self {
    let target = Target::Wasm32;
    self.with_target(target)
}
Expand Down Expand Up @@ -422,6 +428,10 @@ pub fn build_output_pairs() -> HashMap<&'static str, tera::Context> {
"src/swizzles/scalar/vec3a_impl.rs",
ContextBuilder::new_vec3a_swizzle_impl().build(),
),
(
"src/swizzles/neon/vec3a_impl.rs",
ContextBuilder::new_vec3a_swizzle_impl().build(),
),
(
"src/swizzles/sse2/vec3a_impl.rs",
ContextBuilder::new_vec3a_swizzle_impl()
Expand All @@ -444,6 +454,10 @@ pub fn build_output_pairs() -> HashMap<&'static str, tera::Context> {
"src/swizzles/scalar/vec4_impl.rs",
ContextBuilder::new_vec4_swizzle_impl().build(),
),
(
"src/swizzles/neon/vec4_impl.rs",
ContextBuilder::new_vec4_swizzle_impl().build(),
),
(
"src/swizzles/sse2/vec4_impl.rs",
ContextBuilder::new_vec4_swizzle_impl()
Expand Down Expand Up @@ -574,6 +588,10 @@ pub fn build_output_pairs() -> HashMap<&'static str, tera::Context> {
"src/bool/wasm32/bvec3a.rs",
ContextBuilder::new_bvec3a().target_wasm32().build(),
),
(
"src/bool/neon/bvec3a.rs",
ContextBuilder::new_bvec3a().target_neon().build(),
),
(
"src/bool/coresimd/bvec3a.rs",
ContextBuilder::new_bvec3a().target_coresimd().build(),
Expand All @@ -590,6 +608,10 @@ pub fn build_output_pairs() -> HashMap<&'static str, tera::Context> {
"src/bool/wasm32/bvec4a.rs",
ContextBuilder::new_bvec4a().target_wasm32().build(),
),
(
"src/bool/neon/bvec4a.rs",
ContextBuilder::new_bvec4a().target_neon().build(),
),
(
"src/bool/coresimd/bvec4a.rs",
ContextBuilder::new_bvec4a().target_coresimd().build(),
Expand All @@ -600,6 +622,10 @@ pub fn build_output_pairs() -> HashMap<&'static str, tera::Context> {
"src/f32/scalar/vec3a.rs",
ContextBuilder::new_vec3a().build(),
),
(
"src/f32/neon/vec3a.rs",
ContextBuilder::new_vec3a().target_neon().build(),
),
(
"src/f32/sse2/vec3a.rs",
ContextBuilder::new_vec3a().target_sse2().build(),
Expand All @@ -613,6 +639,10 @@ pub fn build_output_pairs() -> HashMap<&'static str, tera::Context> {
ContextBuilder::new_vec3a().target_coresimd().build(),
),
("src/f32/scalar/vec4.rs", ContextBuilder::new_vec4().build()),
(
"src/f32/neon/vec4.rs",
ContextBuilder::new_vec4().target_neon().build(),
),
(
"src/f32/sse2/vec4.rs",
ContextBuilder::new_vec4().target_sse2().build(),
Expand Down Expand Up @@ -647,6 +677,10 @@ pub fn build_output_pairs() -> HashMap<&'static str, tera::Context> {
("src/u64/u64vec3.rs", ContextBuilder::new_u64vec3().build()),
("src/u64/u64vec4.rs", ContextBuilder::new_u64vec4().build()),
("src/f32/scalar/quat.rs", ContextBuilder::new_quat().build()),
(
"src/f32/neon/quat.rs",
ContextBuilder::new_quat().target_neon().build(),
),
(
"src/f32/sse2/quat.rs",
ContextBuilder::new_quat().target_sse2().build(),
Expand All @@ -661,6 +695,7 @@ pub fn build_output_pairs() -> HashMap<&'static str, tera::Context> {
),
("src/f64/dquat.rs", ContextBuilder::new_dquat().build()),
("src/f32/scalar/mat2.rs", ContextBuilder::new_mat2().build()),
("src/f32/neon/mat2.rs", ContextBuilder::new_mat2().build()),
(
"src/f32/sse2/mat2.rs",
ContextBuilder::new_mat2().target_sse2().build(),
Expand All @@ -679,6 +714,10 @@ pub fn build_output_pairs() -> HashMap<&'static str, tera::Context> {
"src/f32/scalar/mat3a.rs",
ContextBuilder::new_mat3a().build(),
),
(
"src/f32/neon/mat3a.rs",
ContextBuilder::new_mat3a().target_neon().build(),
),
(
"src/f32/sse2/mat3a.rs",
ContextBuilder::new_mat3a().target_sse2().build(),
Expand All @@ -692,6 +731,10 @@ pub fn build_output_pairs() -> HashMap<&'static str, tera::Context> {
ContextBuilder::new_mat3a().target_coresimd().build(),
),
("src/f32/scalar/mat4.rs", ContextBuilder::new_mat4().build()),
(
"src/f32/neon/mat4.rs",
ContextBuilder::new_mat4().target_neon().build(),
),
(
"src/f32/sse2/mat4.rs",
ContextBuilder::new_mat4().target_sse2().build(),
Expand Down
96 changes: 96 additions & 0 deletions codegen/templates/mat.rs.tera
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
{% import "coresimd.rs.tera" as coresimd %}
{% import "neon.rs.tera" as neon %}
{% import "sse2.rs.tera" as sse2 %}
{% import "wasm32.rs.tera" as wasm32 %}

Expand Down Expand Up @@ -90,6 +91,8 @@ use crate::{
wasm32::*,
{% elif is_coresimd %}
coresimd::*,
{% elif is_neon %}
neon::*,
{% endif %}
{% endif %}
{{ scalar_t }}::math,
Expand All @@ -109,6 +112,8 @@ use core::arch::x86_64::*;
use core::arch::wasm32::*;
{% elif is_coresimd %}
use core::simd::*;
{% elif is_neon and self_t != "Mat2" %}
use core::arch::aarch64::*;
{% endif %}

{% if self_t == "Mat2" and is_sse2 %}
Expand Down Expand Up @@ -1085,6 +1090,26 @@ impl {{ self_t }} {
[0, 2, 6, 6]
)),
}
{% elif self_t == "Mat3A" and is_neon %}
let x = self.x_axis.0;
let y = self.y_axis.0;
let z = self.z_axis.0;
unsafe {
let tmp0 = vreinterpretq_f32_u64(vsetq_lane_u64(
vgetq_lane_u64(vreinterpretq_u64_f32(y), 0),
vreinterpretq_u64_f32(x),
1,
));
let tmp1 = vreinterpretq_f32_u64(vzip2q_u64(
vreinterpretq_u64_f32(x),
vreinterpretq_u64_f32(y),
));
Mat3A::from_cols(
Vec3A::from(vsetq_lane_f32(vgetq_lane_f32(z, 0), vuzp1q_f32(tmp0, z), 3)),
Vec3A::from(vuzp2q_f32(tmp0, vdupq_laneq_f32(z, 1))),
Vec3A::from(vsetq_lane_f32(vgetq_lane_f32(z, 2), vuzp1q_f32(tmp1, z), 2)),
)
}
{% elif self_t == "Mat4" and is_sse2 %}
unsafe {
// Based on https://github.com/microsoft/DirectXMath `XMMatrixTranspose`
Expand Down Expand Up @@ -1285,6 +1310,75 @@ impl {{ self_t }} {
let detcof = addres * f32x4::from_array([1.0, -1.0, 1.0, -1.0]);

dot4(self.x_axis.0, detcof)
{#
// The neon implementation below is slower than scalar, so it is left commented out.
// {% elif self_t == "Mat4" and is_neon %}
// unsafe {
// let swizz2110 = |x| {
// let x = vuzp1q_f32(x, vdupq_laneq_f32(x, 1));
// vextq_f32(x, x, 1)
// };
// let swizz3323 = |x| {
// let xy = vgetq_lane_f32(x, 3);
// vsetq_lane_f32(xy, vsetq_lane_f32(xy, x, 0), 1)
// };
// let swizz2100 = |x| {
// let y = vuzp1q_f32(x, x);
// vuzp1q_f32(vextq_f32(x, y, 3), y)
// };
// let swizz0021 = |x| vtrn1q_f32(x, vzip1q_f32(x, x));
// // let swizz6723 = |x, y| {
// // vsetq_lane_f64(vgetq_lane_f64(y, 1), 0)
// // };
// let swizz2323 = |x| vreinterpretq_f32_f64(vdupq_laneq_f64(vreinterpretq_f64_f32(x), 1));
// let swizz0012 = |x| vzip1q_f32(x, vuzp1q_f32(x, x));
// let swizz1000 = |x| vsetq_lane_f32(vgetq_lane_f32(x, 1), vdupq_laneq_f32(x, 0), 0);
// let swizz1344 = |x, y| vuzp2q_f32(x, vdupq_laneq_f32(y, 0));
// let swizz0113 = |x| vsetq_lane_f32(vgetq_lane_f32(x, 1), x, 2);
// let swizz2211 = |x| {
// let x = vsetq_lane_f32(vgetq_lane_f32(x, 1), x, 3);
// vzip2q_f32(x, x)
// };
// let swizz2245 = |x, y| vextq_f32(vtrn1q_f32(x, x), y, 1);
// let swizz0233 = |x| vuzp1q_f32(x, vdupq_laneq_f32(x, 3));
// let swizz3332 = |x| vsetq_lane_f32(vgetq_lane_f32(x, 2), vdupq_laneq_f32(x, 3), 3);

// // Based on https://github.com/g-truc/glm `glm_mat4_determinant`
// let swp2a = swizz2110(self.z_axis.0);
// let swp3a = swizz3323(self.w_axis.0);
// let swp2b = swizz3323(self.z_axis.0);
// let swp3b = swizz2110(self.w_axis.0);
// let swp2c = swizz2100(self.z_axis.0);
// let swp3c = swizz0021(self.w_axis.0);

// let mula = vmulq_f32(swp2a, swp3a);
// let mulb = vmulq_f32(swp2b, swp3b);
// let mulc = vmulq_f32(swp2c, swp3c);
// let sube = vsubq_f32(mula, mulb);
// let subf = vsubq_f32(swizz2323(mulc), mulc);

// let subfaca = swizz0012(sube);
// let swpfaca = swizz1000(self.y_axis.0);
// let mulfaca = vmulq_f32(swpfaca, subfaca);

// let subtmpb = swizz1344(sube, subf);
// let subfacb = swizz0113(subtmpb);
// let swpfacb = swizz2211(self.y_axis.0);
// let mulfacb = vmulq_f32(swpfacb, subfacb);

// let subres = vsubq_f32(mulfaca, mulfacb);
// let subtmpc = swizz2245(sube, subf);
// let subfacc = swizz0233(subtmpc);
// let swpfacc = swizz3332(self.y_axis.0);
// let mulfacc = vmulq_f32(swpfacc, subfacc);

// let addres = vaddq_f32(subres, mulfacc);
// const COF: float32x4_t = Vec4::new(1.0, -1.0, 1.0, -1.0).0;
// let detcof = vmulq_f32(addres, COF);

// dot4(self.x_axis.0, detcof)
// }
#}
{% elif dim == 2 %}
self.x_axis.x * self.y_axis.y - self.x_axis.y * self.y_axis.x
{% elif dim == 3 %}
Expand Down Expand Up @@ -1362,6 +1456,8 @@ impl {{ self_t }} {
{{ wasm32::impl_mat4_inverse() }}
{% elif self_t == "Mat4" and is_coresimd %}
{{ coresimd::impl_mat4_inverse() }}
{% elif self_t == "Mat4" and is_neon %}
{{ neon::impl_mat4_inverse() }}
{% elif dim == 2 %}
let inv_det = {
let det = self.determinant();
Expand Down

0 comments on commit ec1a942

Please sign in to comment.