Skip to main content

minroot_core/
montgomery.rs

1//! Montgomery form conversion for field elements.
2//!
3//! The hardware operates in Montgomery domain with `R = 2^128`.
4//! This module provides conversion between standard and Montgomery
5//! representations.
6//!
7//! In Montgomery form, an element `a` is represented as `aR mod p`.
8//! Montgomery multiplication: `MontMul(aR, bR) = abR mod p`.
9//!
10//! This reference implementation stores the **standard** value
11//! internally and performs standard arithmetic for correctness.
12//! The Montgomery representation is available via [`MontgomeryElement::to_mont_repr`]
13//! for comparison with hardware outputs.
14
15use crate::error::Error;
16use crate::field::{Curve, FieldElement};
17use core::ops;
18
/// Bit width of the Montgomery radix, i.e. `R = 2^MONT_BITS`.
///
/// Kept equal to the hardware parameter `LowerTriBits = 128`.
const MONT_BITS: usize = 128;
23
24/// A field element that tracks both standard and Montgomery forms.
25///
26/// Internally stores the standard value.  Arithmetic is performed
27/// in standard form for correctness.  The Montgomery representation
28/// (`aR mod p`) is available for hardware comparison.
29#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
30pub struct MontgomeryElement {
31    /// The standard (non-Montgomery) value.
32    standard: FieldElement,
33}
34
35impl MontgomeryElement {
36    /// Wraps a standard field element for Montgomery-domain operations.
37    #[must_use]
38    pub fn from_field(a: FieldElement) -> Self {
39        Self { standard: a }
40    }
41
42    /// Returns the standard (non-Montgomery) field element.
43    #[must_use]
44    pub fn to_field(self) -> FieldElement {
45        self.standard
46    }
47
48    /// Returns the Montgomery representation `aR mod p`.
49    ///
50    /// This is what the hardware stores internally.
51    #[must_use]
52    pub fn to_mont_repr(self) -> FieldElement {
53        // Compute a * 2^128 mod p by repeated doubling.
54        (0..MONT_BITS).fold(self.standard, |acc, _| acc + acc)
55    }
56
57    /// Constructs from a Montgomery representation `aR mod p`.
58    ///
59    /// Recovers the standard value `a` by halving 128 times
60    /// (each halving computes `x * 2^{-1} mod p`).
61    #[must_use]
62    pub fn from_mont_repr(mont_repr: FieldElement) -> Self {
63        let standard =
64            (0..MONT_BITS).fold(mont_repr, |acc, _| halve_mod_p(acc));
65        Self { standard }
66    }
67
68    /// Montgomery squaring.
69    #[must_use]
70    pub fn sqr(self) -> Self {
71        self * self
72    }
73
74    /// Returns the curve.
75    #[must_use]
76    pub fn curve(&self) -> Curve {
77        self.standard.curve()
78    }
79
80    /// The zero element.
81    #[must_use]
82    pub fn zero(curve: Curve) -> Self {
83        Self {
84            standard: FieldElement::zero(curve),
85        }
86    }
87
88    /// The multiplicative identity.
89    #[must_use]
90    pub fn one(curve: Curve) -> Self {
91        Self {
92            standard: FieldElement::one(curve),
93        }
94    }
95
96    /// Modular exponentiation.
97    #[must_use]
98    pub fn pow(self, exp: &[u64; 4], num_bits: usize) -> Self {
99        Self {
100            standard: self.standard.pow(exp, num_bits),
101        }
102    }
103
104    /// Computes the fifth root.
105    #[must_use]
106    pub fn fifth_root(self) -> Self {
107        Self {
108            standard: self.standard.fifth_root(),
109        }
110    }
111
112    /// Constructs from raw limbs in Montgomery representation.
113    ///
114    /// The limbs are interpreted as `aR mod p` and converted to
115    /// the standard value internally.
116    ///
117    /// # Errors
118    ///
119    /// Returns [`Error::OutOfRange`] if the limbs are not less than the modulus.
120    pub fn from_raw_mont_limbs(
121        limbs: [u64; 4],
122        curve: Curve,
123    ) -> Result<Self, Error> {
124        FieldElement::from_limbs(limbs, curve).map(Self::from_mont_repr)
125    }
126}
127
128impl ops::Mul for MontgomeryElement {
129    type Output = Self;
130
131    /// Montgomery multiplication: produces `a * b` in the Montgomery domain.
132    fn mul(self, rhs: Self) -> Self {
133        debug_assert_eq!(self.standard.curve(), rhs.standard.curve());
134        Self {
135            standard: self.standard * rhs.standard,
136        }
137    }
138}
139
140/// Halves a field element modulo p.
141///
142/// Computes `a / 2 mod p`.  If `a` is odd, adds `p` first to make
143/// it even, then shifts right.
144#[must_use]
145#[allow(clippy::cast_possible_truncation)]
146fn halve_mod_p(a: FieldElement) -> FieldElement {
147    let limbs = a.limbs();
148    let is_odd = limbs[0] & 1 == 1;
149    let modulus = a.curve().modulus();
150
151    // If odd, compute (a + p) / 2; if even, compute a / 2.
152    // Since p is odd, a + p is even when a is odd.
153    let (words, high_bit) = if is_odd {
154        let mut result = [0u64; 4];
155        let carry =
156            limbs
157                .iter()
158                .zip(modulus.iter())
159                .enumerate()
160                .fold(0u128, |carry, (i, (&ai, &mi))| {
161                    let sum = u128::from(ai) + u128::from(mi) + carry;
162                    result[i] = sum as u64;
163                    sum >> 64
164                });
165        (result, carry as u64)
166    } else {
167        (*limbs, 0u64)
168    };
169
170    // Shift right by 1
171    let shifted: [u64; 4] = core::array::from_fn(|i| {
172        let current = words[i] >> 1;
173        let from_above = if i + 1 < 4 {
174            words[i + 1] << 63
175        } else {
176            high_bit << 63
177        };
178        current | from_above
179    });
180
181    // Result is guaranteed < p since we started with a < p.
182    FieldElement::from_limbs(shifted, a.curve())
183        .unwrap_or_else(|_| FieldElement::zero(a.curve()))
184}
185
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a Pallas field element from a small integer.
    fn pallas(value: u64) -> FieldElement {
        FieldElement::from_u64(value, Curve::Pallas)
    }

    #[test]
    fn roundtrip_to_from_field() {
        let original = pallas(42);
        assert_eq!(MontgomeryElement::from_field(original).to_field(), original);
    }

    #[test]
    fn mont_repr_roundtrip() {
        let original = pallas(42);
        let in_mont = MontgomeryElement::from_field(original).to_mont_repr();
        let restored = MontgomeryElement::from_mont_repr(in_mont).to_field();
        assert_eq!(restored, original);
    }

    #[test]
    fn mont_mul_matches_field_mul() {
        let (lhs, rhs) = (pallas(123), pallas(456));
        let product = MontgomeryElement::from_field(lhs) * MontgomeryElement::from_field(rhs);
        assert_eq!(product.to_field(), lhs * rhs);
    }

    #[test]
    fn mont_fifth_root_roundtrip() {
        let target = pallas(7);
        let root = MontgomeryElement::from_field(target).fifth_root();
        // root^5 == (root^2)^2 * root
        let fifth_power = root.sqr().sqr() * root;
        assert_eq!(fifth_power.to_field(), target);
    }

    #[test]
    fn mont_repr_of_zero_is_zero() {
        let repr = MontgomeryElement::zero(Curve::Pallas).to_mont_repr();
        assert_eq!(repr, FieldElement::zero(Curve::Pallas));
    }

    #[test]
    fn halve_double_roundtrip() {
        let value = pallas(99);
        assert_eq!(halve_mod_p(value + value), value);
    }

    #[test]
    fn halve_odd_value() {
        // 7 is odd, so halving goes through (7 + p) / 2; doubling it must give 7 back.
        let seven = pallas(7);
        let half = halve_mod_p(seven);
        assert_eq!(half + half, seven);
    }
}