1use crate::error::Error;
13use core::ops;
14
/// Number of 64-bit limbs per field element (4 × 64 = 256 bits of storage).
const LIMBS: usize = 4;

/// Modulus used for [`Curve::Pallas`], stored as little-endian 64-bit limbs
/// (`[0]` is the least significant limb), i.e. the integer
/// `0x40000000000000000000000000000000224698fc094cf91b992d30ed00000001`.
/// The top limb is `0x4000_0000_0000_0000`, so the value is below 2^255.
const PALLAS_MODULUS: [u64; LIMBS] = [
    0x992d_30ed_0000_0001,
    0x2246_98fc_094c_f91b,
    0x0000_0000_0000_0000,
    0x4000_0000_0000_0000,
];

/// Modulus used for [`Curve::Vesta`], stored as little-endian 64-bit limbs
/// (`[0]` is the least significant limb), i.e. the integer
/// `0x40000000000000000000000000000000224698fc0994a8dd8c46eb2100000001`.
/// The top limb is `0x4000_0000_0000_0000`, so the value is below 2^255.
const VESTA_MODULUS: [u64; LIMBS] = [
    0x8c46_eb21_0000_0001,
    0x2246_98fc_0994_a8dd,
    0x0000_0000_0000_0000,
    0x4000_0000_0000_0000,
];
33
/// Selects which of the two supported prime fields a field element belongs to.
///
/// The variant determines the modulus and the fifth-root exponent used by all
/// arithmetic on that element.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Curve {
    /// The field whose modulus is `PALLAS_MODULUS`.
    Pallas,
    /// The field whose modulus is `VESTA_MODULUS`.
    Vesta,
}
42
43impl Curve {
44 #[must_use]
46 pub fn modulus(self) -> [u64; LIMBS] {
47 match self {
48 Self::Pallas => PALLAS_MODULUS,
49 Self::Vesta => VESTA_MODULUS,
50 }
51 }
52
53 #[must_use]
56 pub fn fifth_root_exponent(self) -> [u64; LIMBS] {
57 match self {
58 Self::Pallas => PALLAS_FIFTH_ROOT_EXP,
59 Self::Vesta => VESTA_FIFTH_ROOT_EXP,
60 }
61 }
62
63 #[must_use]
65 pub fn exponent_bits(self) -> usize {
66 254
68 }
69}
70
/// Fifth-root exponent `e` for [`Curve::Pallas`], as little-endian 64-bit limbs.
///
/// It satisfies `5 * e = 4 * (p - 1) + 1`, i.e. `5 * e ≡ 1 (mod p - 1)` where
/// `p` is `PALLAS_MODULUS`. Hence `(x^e)^5 = x^(1 + 4(p-1)) = x` for every
/// field element `x`. The highest set bit is bit 253 (top limb `0x3333…`).
const PALLAS_FIFTH_ROOT_EXP: [u64; LIMBS] = [
    0xe0f0_f3f0_cccc_cccd,
    0x4e9e_e0c9_a10a_60e2,
    0x3333_3333_3333_3333,
    0x3333_3333_3333_3333,
];

/// Fifth-root exponent `e` for [`Curve::Vesta`], as little-endian 64-bit limbs.
///
/// It satisfies `5 * e = 4 * (q - 1) + 1`, i.e. `5 * e ≡ 1 (mod q - 1)` where
/// `q` is `VESTA_MODULUS`. Hence `(x^e)^5 = x` for every field element `x`.
/// The highest set bit is bit 253 (top limb `0x3333…`).
const VESTA_FIFTH_ROOT_EXP: [u64; LIMBS] = [
    0xd69f_2280_cccc_cccd,
    0x4e9e_e0c9_a143_ba4a,
    0x3333_3333_3333_3333,
    0x3333_3333_3333_3333,
];
90
91#[derive(Clone, Copy, PartialEq, Eq, Hash)]
93pub struct FieldElement {
94 limbs: [u64; LIMBS],
95 curve: Curve,
96}
97
98impl core::fmt::Debug for FieldElement {
99 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
100 write!(
101 f,
102 "FieldElement({:?}, 0x{:016x}{:016x}{:016x}{:016x})",
103 self.curve, self.limbs[3], self.limbs[2], self.limbs[1], self.limbs[0],
104 )
105 }
106}
107
108impl FieldElement {
109 #[must_use]
111 pub fn zero(curve: Curve) -> Self {
112 Self {
113 limbs: [0; LIMBS],
114 curve,
115 }
116 }
117
118 #[must_use]
120 pub fn one(curve: Curve) -> Self {
121 Self {
122 limbs: [1, 0, 0, 0],
123 curve,
124 }
125 }
126
127 pub fn from_limbs(limbs: [u64; LIMBS], curve: Curve) -> Result<Self, Error> {
133 let modulus = curve.modulus();
134 if gte_modulus(&limbs, &modulus) {
135 Err(Error::OutOfRange {
136 context: "from_limbs",
137 })
138 } else {
139 Ok(Self { limbs, curve })
140 }
141 }
142
143 #[must_use]
146 pub fn from_u64(val: u64, curve: Curve) -> Self {
147 Self {
148 limbs: [val, 0, 0, 0],
149 curve,
150 }
151 }
152
153 #[must_use]
155 pub fn limbs(&self) -> &[u64; LIMBS] {
156 &self.limbs
157 }
158
159 #[must_use]
161 pub fn curve(&self) -> Curve {
162 self.curve
163 }
164
165 #[must_use]
167 pub fn is_zero(&self) -> bool {
168 self.limbs.iter().all(|&l| l == 0)
169 }
170
171 #[must_use]
173 pub fn sqr(self) -> Self {
174 self * self
175 }
176
177 #[must_use]
182 pub fn pow(self, exp: &[u64; LIMBS], num_bits: usize) -> Self {
183 (0..num_bits).rev().fold(Self::one(self.curve), |acc, i| {
184 let squared = acc.sqr();
185 let limb_idx = i / 64;
186 let bit_idx = i % 64;
187 if (exp[limb_idx] >> bit_idx) & 1 == 1 {
188 squared * self
189 } else {
190 squared
191 }
192 })
193 }
194
195 #[must_use]
197 pub fn fifth_root(self) -> Self {
198 let exp = self.curve.fifth_root_exponent();
199 let bits = self.curve.exponent_bits();
200 self.pow(&exp, bits)
201 }
202
203 #[must_use]
205 pub fn bit(&self, i: usize) -> bool {
206 let limb_idx = i / 64;
207 let bit_idx = i % 64;
208 if limb_idx < LIMBS {
209 (self.limbs[limb_idx] >> bit_idx) & 1 == 1
210 } else {
211 false
212 }
213 }
214}
215
216impl ops::Add for FieldElement {
217 type Output = Self;
218
219 fn add(self, rhs: Self) -> Self {
221 debug_assert_eq!(self.curve, rhs.curve);
222 let modulus = self.curve.modulus();
223 let (sum, carry) = add_limbs(&self.limbs, &rhs.limbs);
224 let result = if carry || gte_modulus(&sum, &modulus) {
225 sub_limbs(&sum, &modulus).0
226 } else {
227 sum
228 };
229 Self {
230 limbs: result,
231 curve: self.curve,
232 }
233 }
234}
235
236impl ops::Sub for FieldElement {
237 type Output = Self;
238
239 fn sub(self, rhs: Self) -> Self {
241 debug_assert_eq!(self.curve, rhs.curve);
242 let modulus = self.curve.modulus();
243 let (diff, borrow) = sub_limbs(&self.limbs, &rhs.limbs);
244 let result = if borrow {
245 add_limbs(&diff, &modulus).0
246 } else {
247 diff
248 };
249 Self {
250 limbs: result,
251 curve: self.curve,
252 }
253 }
254}
255
256impl ops::Mul for FieldElement {
257 type Output = Self;
258
259 fn mul(self, rhs: Self) -> Self {
263 debug_assert_eq!(self.curve, rhs.curve);
264 let wide = mul_wide(&self.limbs, &rhs.limbs);
265 let reduced = reduce_wide(&wide, &self.curve.modulus());
266 Self {
267 limbs: reduced,
268 curve: self.curve,
269 }
270 }
271}
272
273#[allow(clippy::cast_possible_truncation)]
277fn add_limbs(a: &[u64; LIMBS], b: &[u64; LIMBS]) -> ([u64; LIMBS], bool) {
278 let mut result = [0u64; LIMBS];
279 let carry = a.iter().zip(b.iter()).enumerate().fold(
280 0u128,
281 |carry, (i, (&ai, &bi))| {
282 let sum = u128::from(ai) + u128::from(bi) + carry;
283 result[i] = sum as u64;
284 sum >> 64
285 },
286 );
287 (result, carry != 0)
288}
289
290#[allow(clippy::cast_possible_truncation)]
292fn sub_limbs(a: &[u64; LIMBS], b: &[u64; LIMBS]) -> ([u64; LIMBS], bool) {
293 let mut result = [0u64; LIMBS];
294 let borrow = a.iter().zip(b.iter()).enumerate().fold(
295 0u128,
296 |borrow, (i, (&ai, &bi))| {
297 let diff = u128::from(ai).wrapping_sub(u128::from(bi)).wrapping_sub(borrow);
298 result[i] = diff as u64;
299 u128::from(diff >> 127 != 0)
300 },
301 );
302 (result, borrow != 0)
303}
304
305fn gte_modulus(a: &[u64; LIMBS], modulus: &[u64; LIMBS]) -> bool {
307 a.iter()
308 .zip(modulus.iter())
309 .rev()
310 .fold(core::cmp::Ordering::Equal, |ord, (&ai, &mi)| match ord {
311 core::cmp::Ordering::Equal => ai.cmp(&mi),
312 other => other,
313 })
314 != core::cmp::Ordering::Less
315}
316
317#[allow(clippy::cast_possible_truncation)]
319fn mul_wide(a: &[u64; LIMBS], b: &[u64; LIMBS]) -> [u64; LIMBS * 2] {
320 let mut result = [0u64; LIMBS * 2];
321 a.iter().enumerate().for_each(|(i, &ai)| {
322 let carry = b.iter().enumerate().fold(0u128, |carry, (j, &bj)| {
323 let prod =
324 u128::from(ai) * u128::from(bj) + u128::from(result[i + j]) + carry;
325 result[i + j] = prod as u64;
326 prod >> 64
327 });
328 result[i + LIMBS] = carry as u64;
329 });
330 result
331}
332
333fn reduce_wide(wide: &[u64; LIMBS * 2], modulus: &[u64; LIMBS]) -> [u64; LIMBS] {
335 let total_bits = LIMBS * 2 * 64;
336 (0..total_bits).rev().fold([0u64; LIMBS], |acc, bit| {
337 let shifted = shift_left_one(&acc);
339 let limb_idx = bit / 64;
341 let bit_idx = bit % 64;
342 let incoming = (wide[limb_idx] >> bit_idx) & 1;
343 let with_bit = [
344 shifted[0] | incoming,
345 shifted[1],
346 shifted[2],
347 shifted[3],
348 ];
349 if gte_modulus(&with_bit, modulus) {
351 sub_limbs(&with_bit, modulus).0
352 } else {
353 with_bit
354 }
355 })
356}
357
358fn shift_left_one(a: &[u64; LIMBS]) -> [u64; LIMBS] {
360 let mut result = [0u64; LIMBS];
361 (0..LIMBS).rev().for_each(|i| {
362 result[i] = a[i] << 1;
363 if i > 0 {
364 result[i] |= a[i - 1] >> 63;
365 }
366 });
367 result
368}
369
#[cfg(test)]
mod tests {
    use super::*;

    const CURVES: [Curve; 2] = [Curve::Pallas, Curve::Vesta];

    /// Returns `p - 1` (the field's `-1`) for `curve`, built via `from_limbs`.
    fn minus_one(curve: Curve) -> FieldElement {
        let mut limbs = curve.modulus();
        limbs[0] -= 1;
        match FieldElement::from_limbs(limbs, curve) {
            Ok(x) => x,
            Err(_) => panic!("p - 1 must be a canonical field element"),
        }
    }

    #[test]
    fn zero_add_identity() {
        let a = FieldElement::from_u64(42, Curve::Pallas);
        let z = FieldElement::zero(Curve::Pallas);
        assert_eq!(a + z, a);
        assert_eq!(z + a, a);
    }

    #[test]
    fn one_mul_identity() {
        let a = FieldElement::from_u64(12345, Curve::Pallas);
        let one = FieldElement::one(Curve::Pallas);
        assert_eq!(a * one, a);
        assert_eq!(one * a, a);
    }

    #[test]
    fn add_sub_roundtrip() {
        let a = FieldElement::from_u64(100, Curve::Pallas);
        let b = FieldElement::from_u64(200, Curve::Pallas);
        assert_eq!((a + b) - b, a);
    }

    #[test]
    fn sqr_equals_mul_self() {
        let a = FieldElement::from_u64(9999, Curve::Pallas);
        assert_eq!(a.sqr(), a * a);
    }

    #[test]
    fn fifth_root_roundtrip() {
        let x = FieldElement::from_u64(7, Curve::Pallas);
        let r = x.fifth_root();
        let r5 = r * r * r * r * r;
        assert_eq!(r5, x);
    }

    #[test]
    fn fifth_root_roundtrip_vesta() {
        let x = FieldElement::from_u64(13, Curve::Vesta);
        let r = x.fifth_root();
        let r5 = r * r * r * r * r;
        assert_eq!(r5, x);
    }

    #[test]
    fn from_limbs_rejects_modulus() {
        let result = FieldElement::from_limbs(PALLAS_MODULUS, Curve::Pallas);
        assert!(result.is_err());
    }

    #[test]
    fn modulus_minus_one_is_valid() {
        let mut limbs = PALLAS_MODULUS;
        limbs[0] -= 1;
        let result = FieldElement::from_limbs(limbs, Curve::Pallas);
        assert!(result.is_ok());
    }

    /// Exercises the borrow branch of `Sub`: 0 - 1 must wrap to p - 1.
    #[test]
    fn sub_underflow_wraps_to_minus_one() {
        for curve in CURVES {
            let diff = FieldElement::zero(curve) - FieldElement::one(curve);
            assert_eq!(diff, minus_one(curve));
        }
    }

    /// Exercises the reduction branch of `Add`: (p - 1) + 1 must wrap to 0.
    #[test]
    fn add_at_modulus_wraps_to_zero() {
        for curve in CURVES {
            let sum = minus_one(curve) + FieldElement::one(curve);
            assert!(sum.is_zero());
            assert_eq!(sum, FieldElement::zero(curve));
        }
    }

    /// Exercises `reduce_wide` on a near-maximal product: (-1)^2 == 1.
    #[test]
    fn minus_one_squared_is_one() {
        for curve in CURVES {
            let m = minus_one(curve);
            assert_eq!(m * m, FieldElement::one(curve));
            assert_eq!(m.sqr(), FieldElement::one(curve));
        }
    }

    #[test]
    fn bit_reads_little_endian_and_out_of_range_is_false() {
        let x = FieldElement::from_u64(0b101, Curve::Pallas);
        assert!(x.bit(0));
        assert!(!x.bit(1));
        assert!(x.bit(2));
        assert!(!x.bit(64));
        assert!(!x.bit(1000));
    }
}