bign256: WideFieldElement #920

Open
makavity opened this issue Aug 3, 2023 · 11 comments

@makavity
Contributor

makavity commented Aug 3, 2023

Hello! To implement point 2 of section 6.2.3 of STB 34.101.66-2014, I need to construct a FieldElement from 48 bytes.
I took the wide-arithmetic implementation from the k256 crate:

wide64.rs
use super::{FieldElement, MODULUS_WORDS};
use elliptic_curve::{
    bigint::{Limb, U256, U512},
    subtle::{Choice, ConditionallySelectable},
};
use crate::arithmetic::field::MODULUS;

/// Constant representing the modulus
/// p = 2^{256} − 189
pub(crate) const MODULUS: U256 =
    U256::from_be_hex("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF43");
    
const MODULUS_WORDS: [Word; U256::LIMBS] = MODULUS.to_words();

const NEG_MODULUS: [u64; 4] = [
    !MODULUS_WORDS[0] + 1,
    !MODULUS_WORDS[1],
    !MODULUS_WORDS[2],
    !MODULUS_WORDS[3],
];

#[derive(Clone, Copy, Debug, Default)]
pub struct WideFieldElement(pub(super) U512);

impl WideFieldElement {
    pub const fn from_bytes(bytes: &[u8; 64]) -> Self {
        Self(U512::from_le_slice(bytes))
    }

    // #[inline(always)] // only used in Scalar::mul(), so won't cause binary bloat
    pub fn mul_wide(a: &FieldElement, b: &FieldElement) -> Self {
        let a = a.0.to_words();
        let b = b.0.to_words();

        // 160 bit accumulator.
        let c0 = 0;
        let c1 = 0;
        let c2 = 0;

        // l[0..7] = a[0..3] * b[0..3].
        let (c0, c1) = muladd_fast(a[0], b[0], c0, c1);
        let (l0, c0, c1) = (c0, c1, 0);
        let (c0, c1, c2) = muladd(a[0], b[1], c0, c1, c2);
        let (c0, c1, c2) = muladd(a[1], b[0], c0, c1, c2);
        let (l1, c0, c1, c2) = (c0, c1, c2, 0);
        let (c0, c1, c2) = muladd(a[0], b[2], c0, c1, c2);
        let (c0, c1, c2) = muladd(a[1], b[1], c0, c1, c2);
        let (c0, c1, c2) = muladd(a[2], b[0], c0, c1, c2);
        let (l2, c0, c1, c2) = (c0, c1, c2, 0);
        let (c0, c1, c2) = muladd(a[0], b[3], c0, c1, c2);
        let (c0, c1, c2) = muladd(a[1], b[2], c0, c1, c2);
        let (c0, c1, c2) = muladd(a[2], b[1], c0, c1, c2);
        let (c0, c1, c2) = muladd(a[3], b[0], c0, c1, c2);
        let (l3, c0, c1, c2) = (c0, c1, c2, 0);
        let (c0, c1, c2) = muladd(a[1], b[3], c0, c1, c2);
        let (c0, c1, c2) = muladd(a[2], b[2], c0, c1, c2);
        let (c0, c1, c2) = muladd(a[3], b[1], c0, c1, c2);
        let (l4, c0, c1, c2) = (c0, c1, c2, 0);
        let (c0, c1, c2) = muladd(a[2], b[3], c0, c1, c2);
        let (c0, c1, c2) = muladd(a[3], b[2], c0, c1, c2);
        let (l5, c0, c1, _c2) = (c0, c1, c2, 0);
        let (c0, c1) = muladd_fast(a[3], b[3], c0, c1);
        let (l6, c0, _c1) = (c0, c1, 0);
        let l7 = c0;

        Self(U512::from_words([l0, l1, l2, l3, l4, l5, l6, l7]))
    }

    /// Multiplies `a` by `b` (without modulo reduction) and divides the result by `2^shift`
    /// (rounding to the nearest integer).
    /// Variable time in `shift`.
    pub(crate) fn mul_shift_vartime(a: &FieldElement, b: &FieldElement, shift: usize) -> FieldElement {
        debug_assert!(shift >= 256);

        let l = Self::mul_wide(a, b).0.to_words();
        let shiftlimbs = shift >> 6;
        let shiftlow = shift & 0x3F;
        let shifthigh = 64 - shiftlow;

        let r0 = if shift < 512 {
            let lo = l[shiftlimbs] >> shiftlow;
            let hi = if shift < 448 && shiftlow != 0 {
                l[1 + shiftlimbs] << shifthigh
            } else {
                0
            };
            hi | lo
        } else {
            0
        };

        let r1 = if shift < 448 {
            let lo = l[1 + shiftlimbs] >> shiftlow;
            let hi = if shift < 384 && shiftlow != 0 {
                l[2 + shiftlimbs] << shifthigh
            } else {
                0
            };
            hi | lo
        } else {
            0
        };

        let r2 = if shift < 384 {
            let lo = l[2 + shiftlimbs] >> shiftlow;
            let hi = if shift < 320 && shiftlow != 0 {
                l[3 + shiftlimbs] << shifthigh
            } else {
                0
            };
            hi | lo
        } else {
            0
        };

        let r3 = if shift < 320 {
            l[3 + shiftlimbs] >> shiftlow
        } else {
            0
        };

        let res = FieldElement(U256::from_words([r0, r1, r2, r3]));

        // Check the highmost discarded bit and round up if it is set.
        let c = (l[(shift - 1) >> 6] >> ((shift - 1) & 0x3f)) & 1;
        FieldElement::conditional_select(&res, &res.add(&FieldElement::ONE), Choice::from(c as u8))
    }

    fn reduce_impl(&self, modulus_minus_one: bool) -> FieldElement {
        let neg_modulus0 = if modulus_minus_one {
            NEG_MODULUS[0] + 1
        } else {
            NEG_MODULUS[0]
        };
        let modulus = if modulus_minus_one {
            MODULUS.wrapping_sub(&U256::ONE)
        } else {
            MODULUS
        };

        let w = self.0.to_words();
        let n0 = w[4];
        let n1 = w[5];
        let n2 = w[6];
        let n3 = w[7];

        // Reduce 512 bits into 385.
        // m[0..6] = self[0..3] + n[0..3] * neg_modulus.
        let c0 = w[0];
        let c1 = 0;
        let c2 = 0;
        let (c0, c1) = muladd_fast(n0, neg_modulus0, c0, c1);
        let (m0, c0, c1) = (c0, c1, 0);
        let (c0, c1) = sumadd_fast(w[1], c0, c1);
        let (c0, c1, c2) = muladd(n1, neg_modulus0, c0, c1, c2);
        let (c0, c1, c2) = muladd(n0, NEG_MODULUS[1], c0, c1, c2);
        let (m1, c0, c1, c2) = (c0, c1, c2, 0);
        let (c0, c1, c2) = sumadd(w[2], c0, c1, c2);
        let (c0, c1, c2) = muladd(n2, neg_modulus0, c0, c1, c2);
        let (c0, c1, c2) = muladd(n1, NEG_MODULUS[1], c0, c1, c2);
        let (c0, c1, c2) = sumadd(n0, c0, c1, c2);
        let (m2, c0, c1, c2) = (c0, c1, c2, 0);
        let (c0, c1, c2) = sumadd(w[3], c0, c1, c2);
        let (c0, c1, c2) = muladd(n3, neg_modulus0, c0, c1, c2);
        let (c0, c1, c2) = muladd(n2, NEG_MODULUS[1], c0, c1, c2);
        let (c0, c1, c2) = sumadd(n1, c0, c1, c2);
        let (m3, c0, c1, c2) = (c0, c1, c2, 0);
        let (c0, c1, c2) = muladd(n3, NEG_MODULUS[1], c0, c1, c2);
        let (c0, c1, c2) = sumadd(n2, c0, c1, c2);
        let (m4, c0, c1, _c2) = (c0, c1, c2, 0);
        let (c0, c1) = sumadd_fast(n3, c0, c1);
        let (m5, c0, _c1) = (c0, c1, 0);
        debug_assert!(c0 <= 1);
        let m6 = c0;

        // Reduce 385 bits into 258.
        // p[0..4] = m[0..3] + m[4..6] * neg_modulus.
        let c0 = m0;
        let c1 = 0;
        let c2 = 0;
        let (c0, c1) = muladd_fast(m4, neg_modulus0, c0, c1);
        let (p0, c0, c1) = (c0, c1, 0);
        let (c0, c1) = sumadd_fast(m1, c0, c1);
        let (c0, c1, c2) = muladd(m5, neg_modulus0, c0, c1, c2);
        let (c0, c1, c2) = muladd(m4, NEG_MODULUS[1], c0, c1, c2);
        let (p1, c0, c1) = (c0, c1, 0);
        let (c0, c1, c2) = sumadd(m2, c0, c1, c2);
        let (c0, c1, c2) = muladd(m6, neg_modulus0, c0, c1, c2);
        let (c0, c1, c2) = muladd(m5, NEG_MODULUS[1], c0, c1, c2);
        let (c0, c1, c2) = sumadd(m4, c0, c1, c2);
        let (p2, c0, c1, _c2) = (c0, c1, c2, 0);
        let (c0, c1) = sumadd_fast(m3, c0, c1);
        let (c0, c1) = muladd_fast(m6, NEG_MODULUS[1], c0, c1);
        let (c0, c1) = sumadd_fast(m5, c0, c1);
        let (p3, c0, _c1) = (c0, c1, 0);
        let p4 = c0 + m6;
        debug_assert!(p4 <= 2);

        // Reduce 258 bits into 256.
        // r[0..3] = p[0..3] + p[4] * neg_modulus.
        let mut c = (p0 as u128) + (neg_modulus0 as u128) * (p4 as u128);
        let r0 = (c & 0xFFFFFFFFFFFFFFFFu128) as u64;
        c >>= 64;
        c += (p1 as u128) + (NEG_MODULUS[1] as u128) * (p4 as u128);
        let r1 = (c & 0xFFFFFFFFFFFFFFFFu128) as u64;
        c >>= 64;
        c += (p2 as u128) + (p4 as u128);
        let r2 = (c & 0xFFFFFFFFFFFFFFFFu128) as u64;
        c >>= 64;
        c += p3 as u128;
        let r3 = (c & 0xFFFFFFFFFFFFFFFFu128) as u64;
        c >>= 64;

        // Final reduction of r.
        let r = U256::from([r0, r1, r2, r3]);
        let (r2, underflow) = r.sbb(&modulus, Limb::ZERO);
        let high_bit = Choice::from(c as u8);
        let underflow = Choice::from((underflow.0 >> 63) as u8);
        FieldElement(U256::conditional_select(&r, &r2, !underflow | high_bit))
    }

    #[inline(always)] // only used in Scalar::mul(), so won't cause binary bloat
    pub(super) fn reduce(&self) -> FieldElement {
        self.reduce_impl(false)
    }

    pub(super) fn reduce_nonzero(&self) -> FieldElement {
        self.reduce_impl(true) + FieldElement::ONE
    }
}

/// Constant-time comparison.
#[inline(always)]
fn ct_less(a: u64, b: u64) -> u64 {
    // Do not convert to Choice since it is only used internally,
    // and we don't want loss of performance.
    (a < b) as u64
}

/// Add a to the number defined by (c0,c1,c2). c2 must never overflow.
fn sumadd(a: u64, c0: u64, c1: u64, c2: u64) -> (u64, u64, u64) {
    let new_c0 = c0.wrapping_add(a); // overflow is handled on the next line
    let over = ct_less(new_c0, a);
    let new_c1 = c1.wrapping_add(over); // overflow is handled on the next line
    let new_c2 = c2 + ct_less(new_c1, over); // never overflows by contract
    (new_c0, new_c1, new_c2)
}

/// Add a to the number defined by (c0,c1). c1 must never overflow, c2 must be zero.
fn sumadd_fast(a: u64, c0: u64, c1: u64) -> (u64, u64) {
    let new_c0 = c0.wrapping_add(a); // overflow is handled on the next line
    let new_c1 = c1 + ct_less(new_c0, a); // never overflows by contract (verified the next line)
    debug_assert!((new_c1 != 0) | (new_c0 >= a));
    (new_c0, new_c1)
}

/// Add a*b to the number defined by (c0,c1,c2). c2 must never overflow.
fn muladd(a: u64, b: u64, c0: u64, c1: u64, c2: u64) -> (u64, u64, u64) {
    let t = (a as u128) * (b as u128);
    let th = (t >> 64) as u64; // at most 0xFFFFFFFFFFFFFFFE
    let tl = t as u64;

    let new_c0 = c0.wrapping_add(tl); // overflow is handled on the next line
    let new_th = th + u64::from(new_c0 < tl); // at most 0xFFFFFFFFFFFFFFFF
    let new_c1 = c1.wrapping_add(new_th); // overflow is handled on the next line
    let new_c2 = c2 + ct_less(new_c1, new_th); // never overflows by contract (verified in the next line)
    debug_assert!((new_c1 >= new_th) || (new_c2 != 0));
    (new_c0, new_c1, new_c2)
}

/// Add a*b to the number defined by (c0,c1). c1 must never overflow.
fn muladd_fast(a: u64, b: u64, c0: u64, c1: u64) -> (u64, u64) {
    let t = (a as u128) * (b as u128);
    let th = (t >> 64) as u64; // at most 0xFFFFFFFFFFFFFFFE
    let tl = t as u64;

    let new_c0 = c0.wrapping_add(tl); // overflow is handled on the next line
    let new_th = th + ct_less(new_c0, tl); // at most 0xFFFFFFFFFFFFFFFF
    let new_c1 = c1 + new_th; // never overflows by contract (verified in the next line)
    debug_assert!(new_c1 >= new_th);
    (new_c0, new_c1)
}

My test is:

let two = FieldElement::ONE + FieldElement::ONE;
let one = two * FieldElement::TWO_INV;
println!("1 (montgomery): {:02X?}", one);
println!("1 (canonical): {:02X?}", one.to_canonical());

let one_wide = WideFieldElement::mul_wide(&two, &FieldElement::TWO_INV);
println!("1 (wide montgomery): {:02X?}", one_wide);
println!("1 (wide reduced): {:02X?}", one_wide.reduce());
println!("1 (wide reduced canonical): {:02X?}", one_wide.reduce().to_canonical());

Output is:

1 (montgomery): FieldElement(Uint(0x00000000000000000000000000000000000000000000000000000000000000BD))
1 (canonical): Uint(0x0000000000000000000000000000000000000000000000000000000000000001)
1 (wide montgomery): WideFieldElement(Uint(0x00000000000000000000000000000000000000000000000000000000000000BD0000000000000000000000000000000000000000000000000000000000000000))
1 (wide reduced): FieldElement(Uint(0x000000000000000000000000000000BD00000000000000000000000000008B89))
1 (wide reduced canonical): Uint(0x00000000000000000000000000000001000000000000000000000000000000BD)

I expected 1 (wide reduced canonical) and 1 (canonical) to be the same, but 1 (wide reduced canonical) appears to still be in Montgomery form. I don't know what I am doing wrong.
Can I get help with that?
Thanks!
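
A note on reading the output above: the labels in the test indicate that FieldElement values are stored in Montgomery form (x·R mod p), and the printed Montgomery form of 1 is 0xBD = 189, which equals 2^256 mod p. Multiplying two Montgomery-form values gives a·b·R², so a plain reduction modulo p leaves an extra factor of R in the result. The sketch below illustrates this with a toy modulus of the same shape; `P = 2^16 - 15`, `R = 2^16`, and all function names are assumptions chosen for illustration, not part of bign256 or k256.

```rust
// Toy analogue of p = 2^256 - 189 with R = 2^256 (assumption for illustration only).
const P: u64 = 65_521; // prime, 2^16 - 15
const R: u64 = 1 << 16; // toy Montgomery radix; R mod P = 15

/// Square-and-multiply exponentiation modulo P.
fn pow_mod(mut base: u64, mut exp: u64) -> u64 {
    let mut acc = 1;
    base %= P;
    while exp > 0 {
        if exp & 1 == 1 {
            acc = acc * base % P;
        }
        base = base * base % P;
        exp >>= 1;
    }
    acc
}

/// R^-1 mod P via Fermat's little theorem (P is prime).
fn r_inv() -> u64 {
    pow_mod(R % P, P - 2)
}

/// Convert a canonical value to Montgomery form: x * R mod P.
fn to_mont(x: u64) -> u64 {
    x % P * (R % P) % P
}

/// Reduce a wide product modulo P only. The extra factor of R stays in.
fn plain_reduce(wide: u64) -> u64 {
    wide % P
}

/// Reduce a wide product and divide by R, as a Montgomery multiplication must.
fn mont_reduce(wide: u64) -> u64 {
    wide % P * r_inv() % P
}

/// Convert from Montgomery form back to canonical form: x * R^-1 mod P.
fn from_mont(x: u64) -> u64 {
    mont_reduce(x)
}
```

With these definitions, `to_mont(2) * to_mont((P + 1) / 2)` is `15 << 16`, `plain_reduce` of that is 225 = 15², and only `mont_reduce` returns `to_mont(1)` = 15. This is the same pattern as the 0xBD and 0x8B89 values in the low limbs of the output above.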

@makavity
Contributor Author

@tarcieri hey. Any idea about this?

@tarcieri
Member

Not sure

@makavity
Contributor Author

Do you have any idea where I can read about how this algorithm is implemented?

@tarcieri
Member

Perhaps @fjarri can help you, as he wrote it

@fjarri
Contributor

fjarri commented Aug 28, 2023

@makavity it's not quite clear to me what's happening here. The title refers to bign256, the code for wide reduction you quoted is from k256 and used for curve scalars, not field elements (field elements use lazy reduction), and WideFieldElement is nowhere to be found in master. Could you provide an MRE?

Also I can't open the link to STB 34.101.66-2014 from the top message.

@makavity
Contributor Author

@fjarri so I can't use the same code to implement WideFieldElement, right?
You can't open it because it is only accessible from Belarus, sorry. Please try this link instead:
bake-spec19.pdf

@makavity
Contributor Author

makavity commented Aug 28, 2023

Also, I have a question. Are these variables scalars or field elements? I can't tell.
[image: telegram-cloud-photo-size-2-5433959436143152812-y]
I suppose they are field elements, because:
[image]

@fjarri
Contributor

fjarri commented Aug 28, 2023

So the problem has nothing to do with bign256, and you're trying to implement a new curve? I would suggest that you start with the standard tools available in crypto-bigint and get that working. Then you can try to modify k256's optimized operations for your purpose, if your modulus allows it (that could be possible if it also has the form 2^uint_bits - c, where c is a small number). I can help you with specific issues, but I need something I can actually execute on my side.
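
To illustrate why a modulus of the form 2^uint_bits - c is convenient: since 2^256 ≡ c (mod p), the high half of a 512-bit value can be folded into the low half by multiplying it by c. The following is a minimal std-only sketch for p = 2^256 - 189 (the modulus quoted in the first post). The names `reduce_wide` and `add_small` are hypothetical, limbs are little-endian `u64`, the values are plain integers (not Montgomery form), and the final selection is not constant-time. For comparison, the quoted k256-derived code adds `n0`, `n1`, `n2` directly via `sumadd`, which appears to assume that limb 2 of 2^256 - m equals 1; for this modulus 2^256 - p is just 189, so that limb is 0.

```rust
/// c = 2^256 - p for the modulus p = 2^256 - 189.
const C: u128 = 189;

/// Adds a small value `x` to the 256-bit number `r` in place and
/// returns the carry out of the top limb (0 or 1).
fn add_small(r: &mut [u64; 4], x: u128) -> u64 {
    let mut t = x;
    for limb in r.iter_mut() {
        t += *limb as u128;
        *limb = t as u64;
        t >>= 64;
    }
    t as u64
}

/// Reduces a 512-bit little-endian value modulo p = 2^256 - 189.
/// Illustrative sketch only: variable-time, plain (non-Montgomery) integers.
fn reduce_wide(w: [u64; 8]) -> [u64; 4] {
    // Fold 1: lo + hi * C, using 2^256 ≡ C (mod p).
    let mut r = [0u64; 4];
    let mut carry: u128 = 0;
    for i in 0..4 {
        let t = w[i] as u128 + (w[i + 4] as u128) * C + carry;
        r[i] = t as u64;
        carry = t >> 64;
    }
    // Fold 2: the carry is at most 189 and stands for carry * 2^256.
    let carry = add_small(&mut r, carry * C);
    // Fold 3: if fold 2 overflowed, r is now tiny, so this add cannot overflow.
    let carry = add_small(&mut r, carry as u128 * C);
    debug_assert_eq!(carry, 0);
    // Final step: r >= p exactly when r + C overflows 2^256,
    // and the wrapped sum is then r - p.
    let mut s = r;
    if add_small(&mut s, C) == 1 {
        s
    } else {
        r
    }
}
```

For example, `reduce_wide([0, 0, 0, 0, 0xBD, 0, 0, 0])` yields `[0x8B89, 0, 0, 0]`, because 189 · 2^256 ≡ 189² = 0x8B89 (mod p).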

@makavity
Contributor Author

I am trying to implement wide operations for bign256, because I need them to implement the SWU algorithm.
Okay, thank you, I'll take a look at crypto-bigint.
Thank you for the help; if I need more, I will ask.

@fjarri
Contributor

fjarri commented Aug 28, 2023

Ah, I see. Both Field and Scalar in bign256 are wrappers around crypto_bigint::U256, so I think using crypto_bigint stuff as a first approximation is a good approach. The modulus seems very convenient for optimizations, so that may be possible (note that crypto_bigint has a few operations with the prefix _special that are specifically designed for a modulus of this form).

@makavity
Contributor Author

makavity commented Sep 19, 2023

@fjarri I didn't find a better method than this:

    pub fn mul_wide(a: &FieldElement, b: &FieldElement) -> Self {
        let a_w = a.0.as_words();
        let b_w = b.0.as_words();

        let lhs = U512::from_words([a_w[0], a_w[1], a_w[2], a_w[3], 0, 0, 0, 0]);
        let rhs = U512::from_words([b_w[0], b_w[1], b_w[2], b_w[3], 0, 0, 0, 0]);

        Self(lhs.wrapping_mul(&rhs))
    }

    fn reduce_impl(&self, _modulus_minus_one: bool) -> FieldElement {
        let m = MODULUS.as_words();
        let p = U512::from_words([m[0], m[1], m[2], m[3], 0, 0, 0, 0]);

        let res = self.0.const_rem(&p).0.to_words();

        FieldElement(U256::from_words([res[0], res[1], res[2], res[3]]))
    }
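
For comparison with the padded `U512` multiplication above, the same widening product can be written as a schoolbook loop over 64-bit limbs. This is a minimal std-only sketch; the array-based signature is an assumption for illustration and not the crate's API. It only computes the 512-bit product of the stored limbs, so if the inputs are in Montgomery form the product still carries a factor of R².

```rust
/// Schoolbook 256 x 256 -> 512-bit multiplication on little-endian u64 limbs.
/// Hypothetical helper for illustration; not constant-time audited.
fn mul_wide(a: [u64; 4], b: [u64; 4]) -> [u64; 8] {
    let mut out = [0u64; 8];
    for i in 0..4 {
        let mut carry: u128 = 0;
        for j in 0..4 {
            // a[i] * b[j] + out[i + j] + carry is at most 2^128 - 1,
            // so the u128 accumulator cannot overflow.
            let t = (a[i] as u128) * (b[j] as u128) + out[i + j] as u128 + carry;
            out[i + j] = t as u64;
            carry = t >> 64;
        }
        // Limb i + 4 has not been written yet in this row, so a plain store suffices.
        out[i + 4] = carry as u64;
    }
    out
}
```

Assuming the Montgomery forms of 2 and 2^-1 are 0x17A and 2^255 (derived from R mod p = 0xBD), `mul_wide([0x17A, 0, 0, 0], [0, 0, 0, 1 << 63])` gives `[0, 0, 0, 0, 0xBD, 0, 0, 0]`, matching the "1 (wide montgomery)" value printed in the first post.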
