minroot_core/
montgomery.rs1use crate::error::Error;
16use crate::field::{Curve, FieldElement};
17use core::ops;
18
/// Exponent `k` of the Montgomery radix `R = 2^k` used in this module.
///
/// `MontgomeryElement::to_mont_repr` multiplies by `R` through `k` modular
/// doublings, and `MontgomeryElement::from_mont_repr` divides `R` back out
/// through `k` modular halvings.
// NOTE(review): 4-limb (256-bit) Montgomery implementations usually use
// R = 2^256. Confirm that 128 matches whatever produces the limbs passed to
// `MontgomeryElement::from_raw_mont_limbs`.
const MONT_BITS: usize = 128;
23
24#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
30pub struct MontgomeryElement {
31 standard: FieldElement,
33}
34
35impl MontgomeryElement {
36 #[must_use]
38 pub fn from_field(a: FieldElement) -> Self {
39 Self { standard: a }
40 }
41
42 #[must_use]
44 pub fn to_field(self) -> FieldElement {
45 self.standard
46 }
47
48 #[must_use]
52 pub fn to_mont_repr(self) -> FieldElement {
53 (0..MONT_BITS).fold(self.standard, |acc, _| acc + acc)
55 }
56
57 #[must_use]
62 pub fn from_mont_repr(mont_repr: FieldElement) -> Self {
63 let standard =
64 (0..MONT_BITS).fold(mont_repr, |acc, _| halve_mod_p(acc));
65 Self { standard }
66 }
67
68 #[must_use]
70 pub fn sqr(self) -> Self {
71 self * self
72 }
73
74 #[must_use]
76 pub fn curve(&self) -> Curve {
77 self.standard.curve()
78 }
79
80 #[must_use]
82 pub fn zero(curve: Curve) -> Self {
83 Self {
84 standard: FieldElement::zero(curve),
85 }
86 }
87
88 #[must_use]
90 pub fn one(curve: Curve) -> Self {
91 Self {
92 standard: FieldElement::one(curve),
93 }
94 }
95
96 #[must_use]
98 pub fn pow(self, exp: &[u64; 4], num_bits: usize) -> Self {
99 Self {
100 standard: self.standard.pow(exp, num_bits),
101 }
102 }
103
104 #[must_use]
106 pub fn fifth_root(self) -> Self {
107 Self {
108 standard: self.standard.fifth_root(),
109 }
110 }
111
112 pub fn from_raw_mont_limbs(
121 limbs: [u64; 4],
122 curve: Curve,
123 ) -> Result<Self, Error> {
124 FieldElement::from_limbs(limbs, curve).map(Self::from_mont_repr)
125 }
126}
127
128impl ops::Mul for MontgomeryElement {
129 type Output = Self;
130
131 fn mul(self, rhs: Self) -> Self {
133 debug_assert_eq!(self.standard.curve(), rhs.standard.curve());
134 Self {
135 standard: self.standard * rhs.standard,
136 }
137 }
138}
139
140#[must_use]
145#[allow(clippy::cast_possible_truncation)]
146fn halve_mod_p(a: FieldElement) -> FieldElement {
147 let limbs = a.limbs();
148 let is_odd = limbs[0] & 1 == 1;
149 let modulus = a.curve().modulus();
150
151 let (words, high_bit) = if is_odd {
154 let mut result = [0u64; 4];
155 let carry =
156 limbs
157 .iter()
158 .zip(modulus.iter())
159 .enumerate()
160 .fold(0u128, |carry, (i, (&ai, &mi))| {
161 let sum = u128::from(ai) + u128::from(mi) + carry;
162 result[i] = sum as u64;
163 sum >> 64
164 });
165 (result, carry as u64)
166 } else {
167 (*limbs, 0u64)
168 };
169
170 let shifted: [u64; 4] = core::array::from_fn(|i| {
172 let current = words[i] >> 1;
173 let from_above = if i + 1 < 4 {
174 words[i + 1] << 63
175 } else {
176 high_bit << 63
177 };
178 current | from_above
179 });
180
181 FieldElement::from_limbs(shifted, a.curve())
183 .unwrap_or_else(|_| FieldElement::zero(a.curve()))
184}
185
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a Pallas base-field element built from a small integer.
    fn pallas(v: u64) -> FieldElement {
        FieldElement::from_u64(v, Curve::Pallas)
    }

    #[test]
    fn roundtrip_to_from_field() {
        let x = pallas(42);
        let m = MontgomeryElement::from_field(x);
        assert_eq!(m.to_field(), x);
    }

    #[test]
    fn mont_repr_roundtrip() {
        let x = pallas(42);
        let repr = MontgomeryElement::from_field(x).to_mont_repr();
        let recovered = MontgomeryElement::from_mont_repr(repr);
        assert_eq!(recovered.to_field(), x);
    }

    #[test]
    fn mont_mul_matches_field_mul() {
        let a = pallas(123);
        let b = pallas(456);
        let expected = a * b;

        let ma = MontgomeryElement::from_field(a);
        let mb = MontgomeryElement::from_field(b);
        assert_eq!((ma * mb).to_field(), expected);
    }

    #[test]
    fn mont_fifth_root_roundtrip() {
        let x = pallas(7);
        let root = MontgomeryElement::from_field(x).fifth_root();
        let root5 = root * root * root * root * root;
        assert_eq!(root5.to_field(), x);
    }

    #[test]
    fn mont_repr_of_zero_is_zero() {
        let z = MontgomeryElement::zero(Curve::Pallas);
        assert_eq!(z.to_mont_repr(), FieldElement::zero(Curve::Pallas));
    }

    #[test]
    fn sqr_matches_self_mul() {
        let x = MontgomeryElement::from_field(pallas(12));
        assert_eq!(x.sqr(), x * x);
        assert_eq!(x.sqr().to_field(), pallas(144));
    }

    #[test]
    fn one_is_multiplicative_identity() {
        let x = MontgomeryElement::from_field(pallas(31));
        assert_eq!(MontgomeryElement::one(Curve::Pallas) * x, x);
    }

    #[test]
    fn raw_mont_limbs_roundtrip() {
        // Limbs taken from a Montgomery representation must parse back to
        // the original standard-form value.
        let x = pallas(42);
        let repr = MontgomeryElement::from_field(x).to_mont_repr();
        let parsed = MontgomeryElement::from_raw_mont_limbs(*repr.limbs(), Curve::Pallas);
        assert!(matches!(parsed, Ok(m) if m.to_field() == x));
    }

    #[test]
    fn halve_double_roundtrip() {
        // Even input: exercises the plain right-shift branch.
        let x = pallas(99);
        assert_eq!(halve_mod_p(x + x), x);
    }

    #[test]
    fn halve_odd_value() {
        // 7 is odd, so this exercises the `a + p` branch; doubling the half
        // must give back the original value mod p.
        let x = pallas(7);
        let half = halve_mod_p(x);
        assert_eq!(half + half, x);
    }

    #[test]
    fn halve_zero_is_zero() {
        let z = FieldElement::zero(Curve::Pallas);
        assert_eq!(halve_mod_p(z), z);
    }
}