minroot_core/
polynomial.rs1use crate::error::Error;
18use crate::field::{Curve, FieldElement};
19
/// Bit width of one coefficient word in canonical form (radix `2^WORD_BITS`).
pub const WORD_BITS: usize = 16;

/// Headroom bits a coefficient may carry above [`WORD_BITS`].
pub const REDUNDANT_BITS: usize = 1;

/// Widest coefficient, in bits, that `PolyElement::from_coeffs` accepts.
pub const COEFF_BITS: usize = REDUNDANT_BITS + WORD_BITS;

/// Coefficients allocated beyond the 256 bits covered by a field element's
/// four 64-bit limbs.
pub const EXTRA_COEFFS: usize = 1;

/// Total coefficient count: enough [`WORD_BITS`]-bit words to cover 256 bits,
/// plus [`EXTRA_COEFFS`].
pub const NUM_COEFFS: usize = EXTRA_COEFFS + usize::div_ceil(256, WORD_BITS);

/// Mask selecting the low [`COEFF_BITS`] bits of a coefficient.
const COEFF_MASK: u32 = !(u32::MAX << COEFF_BITS);
37
38#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
43pub struct PolyElement {
44 coeffs: [u32; NUM_COEFFS],
45 curve: Curve,
46}
47
48impl PolyElement {
49 pub fn from_coeffs(
57 coeffs: [u32; NUM_COEFFS],
58 curve: Curve,
59 ) -> Result<Self, Error> {
60 if coeffs.iter().all(|&c| c <= COEFF_MASK) {
61 Ok(Self { coeffs, curve })
62 } else {
63 Err(Error::OutOfRange {
64 context: "polynomial coefficient exceeds COEFF_BITS",
65 })
66 }
67 }
68
69 #[must_use]
73 pub fn from_field(fe: FieldElement) -> Self {
74 let limbs = fe.limbs();
75 let word_mask = (1u64 << WORD_BITS) - 1;
76 let coeffs = core::array::from_fn(|i| {
77 let bit_offset = i * WORD_BITS;
78 let limb_idx = bit_offset / 64;
79 let bit_idx = bit_offset % 64;
80 if limb_idx < 4 {
81 let val = limbs[limb_idx] >> bit_idx;
82 let combined = if bit_idx + WORD_BITS > 64 && limb_idx + 1 < 4 {
84 val | (limbs[limb_idx + 1] << (64 - bit_idx))
85 } else {
86 val
87 };
88 #[allow(clippy::cast_possible_truncation)]
89 { (combined & word_mask) as u32 }
90 } else {
91 0
92 }
93 });
94 Self {
95 coeffs,
96 curve: fe.curve(),
97 }
98 }
99
100 #[allow(clippy::cast_possible_truncation)]
108 pub fn to_field(self) -> Result<FieldElement, Error> {
109 let mut accum = [0u128; 5];
112
113 self.coeffs.iter().enumerate().for_each(|(i, &c)| {
114 let bit_offset = i * WORD_BITS;
115 let chunk_idx = bit_offset / 128;
116 let chunk_bit = bit_offset % 128;
117 accum[chunk_idx] += u128::from(c) << chunk_bit;
118 });
119
120 let mut limbs = [0u64; 4];
122 let _ = accum.iter().enumerate().fold(0u128, |carry, (i, &val)| {
123 let total = val + carry;
124 let base = i * 2;
125 if base < 4 {
126 limbs[base] = total as u64;
127 if base + 1 < 4 {
128 limbs[base + 1] = (total >> 64) as u64;
129 0
130 } else {
131 total >> 64
132 }
133 } else {
134 total
135 }
136 });
137
138 FieldElement::from_limbs(limbs, self.curve).map_err(|_| Error::OutOfRange {
139 context: "polynomial to_field: value exceeds modulus after carry propagation",
140 })
141 }
142
143 #[must_use]
145 pub fn coeffs(&self) -> &[u32; NUM_COEFFS] {
146 &self.coeffs
147 }
148
149 #[must_use]
151 pub fn curve(&self) -> Curve {
152 self.curve
153 }
154
155 #[must_use]
160 pub fn add_no_reduce(self, rhs: Self) -> Self {
161 debug_assert_eq!(self.curve, rhs.curve);
162 let coeffs = core::array::from_fn(|i| self.coeffs[i] + rhs.coeffs[i]);
163 Self {
164 coeffs,
165 curve: self.curve,
166 }
167 }
168
169 #[must_use]
171 pub fn zero(curve: Curve) -> Self {
172 Self {
173 coeffs: [0; NUM_COEFFS],
174 curve,
175 }
176 }
177}
178
179#[cfg(test)]
180mod tests {
181 use super::*;
182
183 #[test]
184 fn roundtrip_small_value() {
185 let fe = FieldElement::from_u64(0xDEAD_BEEF, Curve::Pallas);
186 let poly = PolyElement::from_field(fe);
187 let back = poly.to_field();
188 assert_eq!(back, Ok(fe));
189 }
190
191 #[test]
192 fn roundtrip_large_value() {
193 let fe = FieldElement::from_limbs(
195 [
196 0x992d_30ec_ffff_ffff,
197 0x2246_98fc_094c_f91a,
198 0,
199 0x3fff_ffff_ffff_ffff,
200 ],
201 Curve::Pallas,
202 );
203 assert!(
204 fe.iter().all(|fe| {
205 let poly = PolyElement::from_field(*fe);
206 poly.to_field() == Ok(*fe)
207 })
208 );
209 }
210
211 #[test]
212 fn zero_roundtrip() {
213 let fe = FieldElement::zero(Curve::Pallas);
214 let poly = PolyElement::from_field(fe);
215 assert_eq!(poly.to_field(), Ok(fe));
216 }
217
218 #[test]
219 fn coefficient_count() {
220 assert_eq!(NUM_COEFFS, 17);
221 }
222
223 #[test]
224 fn coeff_bits_correct() {
225 assert_eq!(COEFF_BITS, 17);
226 }
227}