Skip to main content

minroot_core/
field.rs

1//! Prime field arithmetic for the Pasta curves (Pallas and Vesta).
2//!
3//! Field elements are 256-bit integers stored as four 64-bit limbs
4//! in little-endian order.  All arithmetic is modular with respect
5//! to the chosen curve's prime modulus.
6//!
7//! # Moduli
8//!
9//! - **Pallas**: `0x40000000000000000000000000000000224698fc094cf91b992d30ed00000001`
10//! - **Vesta**:  `0x40000000000000000000000000000000224698fc0994a8dd8c46eb2100000001`
11
12use crate::error::Error;
13use core::ops;
14
/// How many 64-bit words make up one 256-bit field element.
const LIMBS: usize = 4;

/// Pallas base-field modulus `p = 2^254 + 0x224698fc094cf91b992d30ed00000001`,
/// stored least-significant limb first.
const PALLAS_MODULUS: [u64; LIMBS] = [
    0x992d_30ed_0000_0001, // bits   0..64
    0x2246_98fc_094c_f91b, // bits  64..128
    0,                     // bits 128..192
    1 << 62,               // bits 192..256: the 2^254 term
];

/// Vesta base-field modulus `q = 2^254 + 0x224698fc0994a8dd8c46eb2100000001`,
/// stored least-significant limb first.
const VESTA_MODULUS: [u64; LIMBS] = [
    0x8c46_eb21_0000_0001, // bits   0..64
    0x2246_98fc_0994_a8dd, // bits  64..128
    0,                     // bits 128..192
    1 << 62,               // bits 192..256: the 2^254 term
];
33
34/// Identifies which Pasta curve modulus to use.
35#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
36pub enum Curve {
37    /// The Pallas curve base field.
38    Pallas,
39    /// The Vesta curve base field.
40    Vesta,
41}
42
43impl Curve {
44    /// Returns the modulus limbs for this curve.
45    #[must_use]
46    pub fn modulus(self) -> [u64; LIMBS] {
47        match self {
48            Self::Pallas => PALLAS_MODULUS,
49            Self::Vesta => VESTA_MODULUS,
50        }
51    }
52
53    /// Returns the fifth-root exponent `(4p - 3) / 5` for this curve,
54    /// as little-endian limbs.
55    #[must_use]
56    pub fn fifth_root_exponent(self) -> [u64; LIMBS] {
57        match self {
58            Self::Pallas => PALLAS_FIFTH_ROOT_EXP,
59            Self::Vesta => VESTA_FIFTH_ROOT_EXP,
60        }
61    }
62
63    /// Number of significant bits in the fifth-root exponent.
64    #[must_use]
65    pub fn exponent_bits(self) -> usize {
66        // Both Pallas and Vesta exponents are 254 bits.
67        254
68    }
69}
70
71/// Fifth-root exponent for Pallas: `(4p - 3) / 5`, little-endian limbs.
72///
73/// `0x333333333333333333333333333333334e9ee0c9a10a60e2e0f0f3f0cccccccd`
74const PALLAS_FIFTH_ROOT_EXP: [u64; LIMBS] = [
75    0xe0f0_f3f0_cccc_cccd,
76    0x4e9e_e0c9_a10a_60e2,
77    0x3333_3333_3333_3333,
78    0x3333_3333_3333_3333,
79];
80
81/// Fifth-root exponent for Vesta: `(4p - 3) / 5`, little-endian limbs.
82///
83/// `0x333333333333333333333333333333334e9ee0c9a143ba4ad69f2280cccccccd`
84const VESTA_FIFTH_ROOT_EXP: [u64; LIMBS] = [
85    0xd69f_2280_cccc_cccd,
86    0x4e9e_e0c9_a143_ba4a,
87    0x3333_3333_3333_3333,
88    0x3333_3333_3333_3333,
89];
90
91/// A 256-bit prime field element stored as four little-endian 64-bit limbs.
92#[derive(Clone, Copy, PartialEq, Eq, Hash)]
93pub struct FieldElement {
94    limbs: [u64; LIMBS],
95    curve: Curve,
96}
97
98impl core::fmt::Debug for FieldElement {
99    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
100        write!(
101            f,
102            "FieldElement({:?}, 0x{:016x}{:016x}{:016x}{:016x})",
103            self.curve, self.limbs[3], self.limbs[2], self.limbs[1], self.limbs[0],
104        )
105    }
106}
107
108impl FieldElement {
109    /// The additive identity (zero) for the given curve.
110    #[must_use]
111    pub fn zero(curve: Curve) -> Self {
112        Self {
113            limbs: [0; LIMBS],
114            curve,
115        }
116    }
117
118    /// The multiplicative identity (one) for the given curve.
119    #[must_use]
120    pub fn one(curve: Curve) -> Self {
121        Self {
122            limbs: [1, 0, 0, 0],
123            curve,
124        }
125    }
126
127    /// Constructs a field element from little-endian limbs.
128    ///
129    /// # Errors
130    ///
131    /// Returns [`Error::OutOfRange`] if the value is not less than the modulus.
132    pub fn from_limbs(limbs: [u64; LIMBS], curve: Curve) -> Result<Self, Error> {
133        let modulus = curve.modulus();
134        if gte_modulus(&limbs, &modulus) {
135            Err(Error::OutOfRange {
136                context: "from_limbs",
137            })
138        } else {
139            Ok(Self { limbs, curve })
140        }
141    }
142
143    /// Constructs a field element from a single `u64`, placed in the
144    /// lowest limb.
145    #[must_use]
146    pub fn from_u64(val: u64, curve: Curve) -> Self {
147        Self {
148            limbs: [val, 0, 0, 0],
149            curve,
150        }
151    }
152
153    /// Returns the little-endian limb representation.
154    #[must_use]
155    pub fn limbs(&self) -> &[u64; LIMBS] {
156        &self.limbs
157    }
158
159    /// Returns the curve this element belongs to.
160    #[must_use]
161    pub fn curve(&self) -> Curve {
162        self.curve
163    }
164
165    /// Returns `true` if this element is zero.
166    #[must_use]
167    pub fn is_zero(&self) -> bool {
168        self.limbs.iter().all(|&l| l == 0)
169    }
170
171    /// Modular squaring: `self * self mod p`.
172    #[must_use]
173    pub fn sqr(self) -> Self {
174        self * self
175    }
176
177    /// Modular exponentiation via square-and-multiply.
178    ///
179    /// The exponent is given as little-endian limbs with `num_bits`
180    /// significant bits.
181    #[must_use]
182    pub fn pow(self, exp: &[u64; LIMBS], num_bits: usize) -> Self {
183        (0..num_bits).rev().fold(Self::one(self.curve), |acc, i| {
184            let squared = acc.sqr();
185            let limb_idx = i / 64;
186            let bit_idx = i % 64;
187            if (exp[limb_idx] >> bit_idx) & 1 == 1 {
188                squared * self
189            } else {
190                squared
191            }
192        })
193    }
194
195    /// Computes the fifth root: `self^((4p-3)/5) mod p`.
196    #[must_use]
197    pub fn fifth_root(self) -> Self {
198        let exp = self.curve.fifth_root_exponent();
199        let bits = self.curve.exponent_bits();
200        self.pow(&exp, bits)
201    }
202
203    /// Extracts bit `i` from the element (bit 0 is LSB).
204    #[must_use]
205    pub fn bit(&self, i: usize) -> bool {
206        let limb_idx = i / 64;
207        let bit_idx = i % 64;
208        if limb_idx < LIMBS {
209            (self.limbs[limb_idx] >> bit_idx) & 1 == 1
210        } else {
211            false
212        }
213    }
214}
215
216impl ops::Add for FieldElement {
217    type Output = Self;
218
219    /// Modular addition: `self + rhs mod p`.
220    fn add(self, rhs: Self) -> Self {
221        debug_assert_eq!(self.curve, rhs.curve);
222        let modulus = self.curve.modulus();
223        let (sum, carry) = add_limbs(&self.limbs, &rhs.limbs);
224        let result = if carry || gte_modulus(&sum, &modulus) {
225            sub_limbs(&sum, &modulus).0
226        } else {
227            sum
228        };
229        Self {
230            limbs: result,
231            curve: self.curve,
232        }
233    }
234}
235
236impl ops::Sub for FieldElement {
237    type Output = Self;
238
239    /// Modular subtraction: `self - rhs mod p`.
240    fn sub(self, rhs: Self) -> Self {
241        debug_assert_eq!(self.curve, rhs.curve);
242        let modulus = self.curve.modulus();
243        let (diff, borrow) = sub_limbs(&self.limbs, &rhs.limbs);
244        let result = if borrow {
245            add_limbs(&diff, &modulus).0
246        } else {
247            diff
248        };
249        Self {
250            limbs: result,
251            curve: self.curve,
252        }
253    }
254}
255
256impl ops::Mul for FieldElement {
257    type Output = Self;
258
259    /// Modular multiplication: `self * rhs mod p`.
260    ///
261    /// Uses schoolbook multiplication followed by shift-and-subtract reduction.
262    fn mul(self, rhs: Self) -> Self {
263        debug_assert_eq!(self.curve, rhs.curve);
264        let wide = mul_wide(&self.limbs, &rhs.limbs);
265        let reduced = reduce_wide(&wide, &self.curve.modulus());
266        Self {
267            limbs: reduced,
268            curve: self.curve,
269        }
270    }
271}
272
273// ── Multi-limb arithmetic helpers ──────────────────────────────────
274
275/// Adds two 4-limb numbers, returning (result, carry).
276#[allow(clippy::cast_possible_truncation)]
277fn add_limbs(a: &[u64; LIMBS], b: &[u64; LIMBS]) -> ([u64; LIMBS], bool) {
278    let mut result = [0u64; LIMBS];
279    let carry = a.iter().zip(b.iter()).enumerate().fold(
280        0u128,
281        |carry, (i, (&ai, &bi))| {
282            let sum = u128::from(ai) + u128::from(bi) + carry;
283            result[i] = sum as u64;
284            sum >> 64
285        },
286    );
287    (result, carry != 0)
288}
289
290/// Subtracts two 4-limb numbers, returning (result, borrow).
291#[allow(clippy::cast_possible_truncation)]
292fn sub_limbs(a: &[u64; LIMBS], b: &[u64; LIMBS]) -> ([u64; LIMBS], bool) {
293    let mut result = [0u64; LIMBS];
294    let borrow = a.iter().zip(b.iter()).enumerate().fold(
295        0u128,
296        |borrow, (i, (&ai, &bi))| {
297            let diff = u128::from(ai).wrapping_sub(u128::from(bi)).wrapping_sub(borrow);
298            result[i] = diff as u64;
299            u128::from(diff >> 127 != 0)
300        },
301    );
302    (result, borrow != 0)
303}
304
305/// Returns `true` if `a >= modulus`.
306fn gte_modulus(a: &[u64; LIMBS], modulus: &[u64; LIMBS]) -> bool {
307    a.iter()
308        .zip(modulus.iter())
309        .rev()
310        .fold(core::cmp::Ordering::Equal, |ord, (&ai, &mi)| match ord {
311            core::cmp::Ordering::Equal => ai.cmp(&mi),
312            other => other,
313        })
314        != core::cmp::Ordering::Less
315}
316
317/// Schoolbook multiplication producing an 8-limb (512-bit) result.
318#[allow(clippy::cast_possible_truncation)]
319fn mul_wide(a: &[u64; LIMBS], b: &[u64; LIMBS]) -> [u64; LIMBS * 2] {
320    let mut result = [0u64; LIMBS * 2];
321    a.iter().enumerate().for_each(|(i, &ai)| {
322        let carry = b.iter().enumerate().fold(0u128, |carry, (j, &bj)| {
323            let prod =
324                u128::from(ai) * u128::from(bj) + u128::from(result[i + j]) + carry;
325            result[i + j] = prod as u64;
326            prod >> 64
327        });
328        result[i + LIMBS] = carry as u64;
329    });
330    result
331}
332
333/// Reduces a 512-bit product modulo `p` via shift-and-subtract.
334fn reduce_wide(wide: &[u64; LIMBS * 2], modulus: &[u64; LIMBS]) -> [u64; LIMBS] {
335    let total_bits = LIMBS * 2 * 64;
336    (0..total_bits).rev().fold([0u64; LIMBS], |acc, bit| {
337        // Shift accumulator left by 1
338        let shifted = shift_left_one(&acc);
339        // Bring in the current bit from the wide product
340        let limb_idx = bit / 64;
341        let bit_idx = bit % 64;
342        let incoming = (wide[limb_idx] >> bit_idx) & 1;
343        let with_bit = [
344            shifted[0] | incoming,
345            shifted[1],
346            shifted[2],
347            shifted[3],
348        ];
349        // Conditional subtract
350        if gte_modulus(&with_bit, modulus) {
351            sub_limbs(&with_bit, modulus).0
352        } else {
353            with_bit
354        }
355    })
356}
357
358/// Shifts a 4-limb number left by one bit.
359fn shift_left_one(a: &[u64; LIMBS]) -> [u64; LIMBS] {
360    let mut result = [0u64; LIMBS];
361    (0..LIMBS).rev().for_each(|i| {
362        result[i] = a[i] << 1;
363        if i > 0 {
364            result[i] |= a[i - 1] >> 63;
365        }
366    });
367    result
368}
369
#[cfg(test)]
mod tests {
    use super::*;

    /// Unwraps a `from_limbs` result without requiring `Error: Debug`.
    fn element(limbs: [u64; LIMBS], curve: Curve) -> FieldElement {
        match FieldElement::from_limbs(limbs, curve) {
            Ok(fe) => fe,
            Err(_) => panic!("limbs are not below the {:?} modulus", curve),
        }
    }

    /// The largest canonical Pallas element, `p - 1`.
    fn pallas_minus_one() -> FieldElement {
        let mut limbs = PALLAS_MODULUS;
        limbs[0] -= 1;
        element(limbs, Curve::Pallas)
    }

    /// An arbitrary multi-limb value that is below both Pasta moduli.
    const SAMPLE: [u64; LIMBS] = [
        0x1234_5678_9abc_def0,
        0x0fed_cba9_8765_4321,
        0x1111_2222_3333_4444,
        0x0123_4567_89ab_cdef,
    ];

    #[test]
    fn zero_add_identity() {
        let a = FieldElement::from_u64(42, Curve::Pallas);
        let z = FieldElement::zero(Curve::Pallas);
        assert_eq!(a + z, a);
        assert_eq!(z + a, a);
    }

    #[test]
    fn one_mul_identity() {
        let a = FieldElement::from_u64(12345, Curve::Pallas);
        let one = FieldElement::one(Curve::Pallas);
        assert_eq!(a * one, a);
        assert_eq!(one * a, a);
    }

    #[test]
    fn add_sub_roundtrip() {
        let a = FieldElement::from_u64(100, Curve::Pallas);
        let b = FieldElement::from_u64(200, Curve::Pallas);
        assert_eq!((a + b) - b, a);
    }

    #[test]
    fn sqr_equals_mul_self() {
        let a = FieldElement::from_u64(9999, Curve::Pallas);
        assert_eq!(a.sqr(), a * a);
    }

    #[test]
    fn fifth_root_roundtrip() {
        // x^5 should be the inverse of fifth_root for nonzero elements.
        let x = FieldElement::from_u64(7, Curve::Pallas);
        let r = x.fifth_root();
        let r5 = r * r * r * r * r;
        assert_eq!(r5, x);
    }

    #[test]
    fn fifth_root_roundtrip_vesta() {
        let x = FieldElement::from_u64(13, Curve::Vesta);
        let r = x.fifth_root();
        let r5 = r * r * r * r * r;
        assert_eq!(r5, x);
    }

    #[test]
    fn fifth_root_roundtrip_multi_limb() {
        for &curve in &[Curve::Pallas, Curve::Vesta] {
            let x = element(SAMPLE, curve);
            let r = x.fifth_root();
            assert_eq!(r * r * r * r * r, x);
        }
    }

    #[test]
    fn fifth_root_of_zero_is_zero() {
        for &curve in &[Curve::Pallas, Curve::Vesta] {
            assert!(FieldElement::zero(curve).fifth_root().is_zero());
        }
    }

    #[test]
    fn add_wraps_to_zero_at_modulus() {
        let sum = pallas_minus_one() + FieldElement::one(Curve::Pallas);
        assert!(sum.is_zero());
    }

    #[test]
    fn sub_wraps_below_zero() {
        let diff = FieldElement::zero(Curve::Pallas) - FieldElement::one(Curve::Pallas);
        assert_eq!(diff, pallas_minus_one());
    }

    #[test]
    fn mul_reduces_minus_one_squared() {
        let m = pallas_minus_one();
        assert_eq!(m * m, FieldElement::one(Curve::Pallas));
    }

    #[test]
    fn from_limbs_rejects_modulus() {
        let result = FieldElement::from_limbs(PALLAS_MODULUS, Curve::Pallas);
        assert!(result.is_err());
    }

    #[test]
    fn from_limbs_rejects_vesta_modulus() {
        assert!(FieldElement::from_limbs(VESTA_MODULUS, Curve::Vesta).is_err());
    }

    #[test]
    fn from_limbs_range_is_per_curve() {
        // q > p, so the Pallas modulus is itself a valid Vesta element.
        assert!(FieldElement::from_limbs(PALLAS_MODULUS, Curve::Vesta).is_ok());
    }

    #[test]
    fn modulus_minus_one_is_valid() {
        let mut limbs = PALLAS_MODULUS;
        limbs[0] -= 1;
        let result = FieldElement::from_limbs(limbs, Curve::Pallas);
        assert!(result.is_ok());
    }

    #[test]
    fn bit_reads_lsb_first_and_is_false_out_of_range() {
        let x = FieldElement::from_u64(0b101, Curve::Pallas);
        assert!(x.bit(0));
        assert!(!x.bit(1));
        assert!(x.bit(2));
        assert!(!x.bit(LIMBS * 64));
    }
}