Skip to main content

minroot_core/
polynomial.rs

//! Redundant polynomial representation for hardware arithmetic.
//!
//! The hardware represents 256-bit field elements as polynomials with
//! [`NUM_COEFFS`] coefficients of [`COEFF_BITS`] bits each.  This
//! redundant representation avoids long carry chains in the critical
//! path, enabling higher clock frequencies.
//!
//! Coefficient `i` carries weight `2^(i * WORD_BITS)`. Because each
//! coefficient is [`REDUNDANT_BITS`] wider than the word stride, neighbouring
//! coefficients overlap, so a carry can be absorbed locally instead of
//! rippling through the whole value.
//!
//! # Parameters (matching the `SystemVerilog` `mrt_pkg`)
//!
//! - `TargetBits = 256`
//! - `WordBits = 16`
//! - `RedundantBits = 1`
//! - `ExtraCoeffs = 1`
//! - `NumCoeffs = ceil(256 / 16) + 1 = 17`
//! - `CoeffBits = 16 + 1 = 17`
16
17use crate::error::Error;
18use crate::field::{Curve, FieldElement};
19
/// Width of the field elements being represented, in bits (`TargetBits`).
const TARGET_BITS: usize = 256;

/// Number of data bits per coefficient word.
pub const WORD_BITS: usize = 16;

/// Number of redundant bits per coefficient (carry absorption).
pub const REDUNDANT_BITS: usize = 1;

/// Total bits per coefficient.
pub const COEFF_BITS: usize = WORD_BITS + REDUNDANT_BITS;

/// Extra coefficients beyond `ceil(TARGET_BITS / WORD_BITS)`.
pub const EXTRA_COEFFS: usize = 1;

/// Total number of coefficients per polynomial.
pub const NUM_COEFFS: usize = TARGET_BITS.div_ceil(WORD_BITS) + EXTRA_COEFFS;

// Coefficients are stored in `u32`; `COEFF_MASK` below would overflow (and the
// storage would be too narrow) if the parameters were ever widened past that.
const _: () = assert!(COEFF_BITS < u32::BITS as usize);

/// Mask for a single coefficient value.
const COEFF_MASK: u32 = (1 << COEFF_BITS) - 1;
37
38/// A field element in redundant polynomial form.
39///
40/// Each coefficient `c[i]` holds up to [`COEFF_BITS`] bits.
41/// The integer value is `sum(c[i] * 2^(i * WORD_BITS))` for `i` in `0..NUM_COEFFS`.
42#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
43pub struct PolyElement {
44    coeffs: [u32; NUM_COEFFS],
45    curve: Curve,
46}
47
48impl PolyElement {
49    /// Constructs a polynomial element from an array of coefficients.
50    ///
51    /// Each coefficient must fit in [`COEFF_BITS`] bits.
52    ///
53    /// # Errors
54    ///
55    /// Returns [`Error::OutOfRange`] if any coefficient exceeds the bit width.
56    pub fn from_coeffs(
57        coeffs: [u32; NUM_COEFFS],
58        curve: Curve,
59    ) -> Result<Self, Error> {
60        if coeffs.iter().all(|&c| c <= COEFF_MASK) {
61            Ok(Self { coeffs, curve })
62        } else {
63            Err(Error::OutOfRange {
64                context: "polynomial coefficient exceeds COEFF_BITS",
65            })
66        }
67    }
68
69    /// Converts a [`FieldElement`] into polynomial form.
70    ///
71    /// Extracts [`WORD_BITS`]-wide chunks from the integer representation.
72    #[must_use]
73    pub fn from_field(fe: FieldElement) -> Self {
74        let limbs = fe.limbs();
75        let word_mask = (1u64 << WORD_BITS) - 1;
76        let coeffs = core::array::from_fn(|i| {
77            let bit_offset = i * WORD_BITS;
78            let limb_idx = bit_offset / 64;
79            let bit_idx = bit_offset % 64;
80            if limb_idx < 4 {
81                let val = limbs[limb_idx] >> bit_idx;
82                // Handle crossing a limb boundary
83                let combined = if bit_idx + WORD_BITS > 64 && limb_idx + 1 < 4 {
84                    val | (limbs[limb_idx + 1] << (64 - bit_idx))
85                } else {
86                    val
87                };
88                #[allow(clippy::cast_possible_truncation)]
89                { (combined & word_mask) as u32 }
90            } else {
91                0
92            }
93        });
94        Self {
95            coeffs,
96            curve: fe.curve(),
97        }
98    }
99
100    /// Converts back to a [`FieldElement`].
101    ///
102    /// Performs carry propagation and modular reduction.
103    ///
104    /// # Errors
105    ///
106    /// Returns [`Error::OutOfRange`] if the polynomial's value exceeds the modulus.
107    #[allow(clippy::cast_possible_truncation)]
108    pub fn to_field(self) -> Result<FieldElement, Error> {
109        // Accumulate into a 512-bit intermediate to handle overflow,
110        // then reduce mod p.
111        let mut accum = [0u128; 5];
112
113        self.coeffs.iter().enumerate().for_each(|(i, &c)| {
114            let bit_offset = i * WORD_BITS;
115            let chunk_idx = bit_offset / 128;
116            let chunk_bit = bit_offset % 128;
117            accum[chunk_idx] += u128::from(c) << chunk_bit;
118        });
119
120        // Propagate carries across chunks into 64-bit limbs
121        let mut limbs = [0u64; 4];
122        let _ = accum.iter().enumerate().fold(0u128, |carry, (i, &val)| {
123            let total = val + carry;
124            let base = i * 2;
125            if base < 4 {
126                limbs[base] = total as u64;
127                if base + 1 < 4 {
128                    limbs[base + 1] = (total >> 64) as u64;
129                    0
130                } else {
131                    total >> 64
132                }
133            } else {
134                total
135            }
136        });
137
138        FieldElement::from_limbs(limbs, self.curve).map_err(|_| Error::OutOfRange {
139            context: "polynomial to_field: value exceeds modulus after carry propagation",
140        })
141    }
142
143    /// Returns the coefficient array.
144    #[must_use]
145    pub fn coeffs(&self) -> &[u32; NUM_COEFFS] {
146        &self.coeffs
147    }
148
149    /// Returns the curve.
150    #[must_use]
151    pub fn curve(&self) -> Curve {
152        self.curve
153    }
154
155    /// Coefficient-wise addition without carry propagation.
156    ///
157    /// This is the hardware-friendly operation: coefficients may
158    /// temporarily exceed [`WORD_BITS`], using the redundant bit.
159    #[must_use]
160    pub fn add_no_reduce(self, rhs: Self) -> Self {
161        debug_assert_eq!(self.curve, rhs.curve);
162        let coeffs = core::array::from_fn(|i| self.coeffs[i] + rhs.coeffs[i]);
163        Self {
164            coeffs,
165            curve: self.curve,
166        }
167    }
168
169    /// Zero polynomial.
170    #[must_use]
171    pub fn zero(curve: Curve) -> Self {
172        Self {
173            coeffs: [0; NUM_COEFFS],
174            curve,
175        }
176    }
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn roundtrip_small_value() {
        let fe = FieldElement::from_u64(0xDEAD_BEEF, Curve::Pallas);
        let poly = PolyElement::from_field(fe);
        let back = poly.to_field();
        assert_eq!(back, Ok(fe));
    }

    #[test]
    fn roundtrip_large_value() {
        // A large value with the top limb populated, intended to lie below the
        // Pallas modulus. If `from_limbs` rejects it, fail loudly: the previous
        // `fe.iter().all(..)` form silently passed on `Err`.
        let Ok(fe) = FieldElement::from_limbs(
            [
                0x992d_30ec_ffff_ffff,
                0x2246_98fc_094c_f91a,
                0,
                0x3fff_ffff_ffff_ffff,
            ],
            Curve::Pallas,
        ) else {
            panic!("test value must be accepted by FieldElement::from_limbs");
        };
        let poly = PolyElement::from_field(fe);
        assert_eq!(poly.to_field(), Ok(fe));
    }

    #[test]
    fn zero_roundtrip() {
        let fe = FieldElement::zero(Curve::Pallas);
        let poly = PolyElement::from_field(fe);
        assert_eq!(poly.to_field(), Ok(fe));
    }

    #[test]
    fn from_coeffs_enforces_coeff_bits() {
        let mut coeffs = [0u32; NUM_COEFFS];
        coeffs[0] = COEFF_MASK;
        assert!(PolyElement::from_coeffs(coeffs, Curve::Pallas).is_ok());
        coeffs[0] = COEFF_MASK + 1;
        assert!(PolyElement::from_coeffs(coeffs, Curve::Pallas).is_err());
    }

    #[test]
    fn add_zero_is_identity() {
        let poly = PolyElement::from_field(FieldElement::from_u64(0xDEAD_BEEF, Curve::Pallas));
        assert_eq!(poly.add_no_reduce(PolyElement::zero(Curve::Pallas)), poly);
    }

    #[test]
    fn coefficient_count() {
        assert_eq!(NUM_COEFFS, 17);
    }

    #[test]
    fn coeff_bits_correct() {
        assert_eq!(COEFF_BITS, 17);
    }
}