1use crate::error::Error;
17use crate::field::{Curve, FieldElement};
18
19#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
21pub struct MinRootState {
22 x: FieldElement,
23 y: FieldElement,
24 iteration: u64,
25}
26
27impl MinRootState {
28 #[must_use]
30 pub fn new(x: FieldElement, y: FieldElement) -> Self {
31 debug_assert_eq!(x.curve(), y.curve());
32 Self {
33 x,
34 y,
35 iteration: 0,
36 }
37 }
38
39 #[must_use]
41 pub fn x(&self) -> FieldElement {
42 self.x
43 }
44
45 #[must_use]
47 pub fn y(&self) -> FieldElement {
48 self.y
49 }
50
51 #[must_use]
53 pub fn iteration(&self) -> u64 {
54 self.iteration
55 }
56
57 #[must_use]
59 pub fn curve(&self) -> Curve {
60 self.x.curve()
61 }
62}
63
64#[must_use]
72pub fn step(state: MinRootState) -> MinRootState {
73 let i_field = FieldElement::from_u64(state.iteration, state.curve());
74 let temp = state.x + state.y;
75 let new_y = state.x + i_field;
76 let new_x = temp.fifth_root();
77 MinRootState {
78 x: new_x,
79 y: new_y,
80 iteration: state.iteration + 1,
81 }
82}
83
84pub fn evaluate(
92 x: FieldElement,
93 y: FieldElement,
94 num_iterations: u64,
95) -> Result<MinRootState, Error> {
96 if num_iterations == 0 {
97 Err(Error::ZeroIterations)
98 } else {
99 let init = MinRootState::new(x, y);
100 Ok((0..num_iterations).fold(init, |state, _| step(state)))
101 }
102}
103
104pub fn evaluate_trace(
112 x: FieldElement,
113 y: FieldElement,
114 num_iterations: u64,
115) -> Result<Vec<MinRootState>, Error> {
116 if num_iterations == 0 {
117 Err(Error::ZeroIterations)
118 } else {
119 let init = MinRootState::new(x, y);
120 Ok((0..num_iterations)
121 .fold(vec![init], |mut trace, _| {
122 let current = trace[trace.len() - 1];
124 trace.push(step(current));
125 trace
126 }))
127 }
128}
129
130pub fn verify(
138 x: FieldElement,
139 y: FieldElement,
140 num_iterations: u64,
141 claimed_x: FieldElement,
142 claimed_y: FieldElement,
143) -> Result<bool, Error> {
144 evaluate(x, y, num_iterations)
145 .map(|result| result.x == claimed_x && result.y == claimed_y)
146}
147
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a small-integer element of the Pallas field.
    fn pallas(v: u64) -> FieldElement {
        FieldElement::from_u64(v, Curve::Pallas)
    }

    #[test]
    fn single_step_deterministic() {
        let x = pallas(3);
        let y = pallas(5);
        let s1 = step(MinRootState::new(x, y));
        let s2 = step(MinRootState::new(x, y));
        assert_eq!(s1, s2);
    }

    #[test]
    fn step_modifies_state() {
        let x = pallas(3);
        let y = pallas(5);
        let init = MinRootState::new(x, y);
        let after = step(init);
        assert_ne!(after.x(), init.x());
        // At iteration 0 the index added to x is zero, so the new y is the old x.
        assert_eq!(after.y(), x);
        // Exactly one round has been applied.
        assert_eq!(after.iteration(), 1);
    }

    #[test]
    fn fifth_root_consistency() {
        // Raising the new x to the fifth power must give back x + y.
        let x = pallas(3);
        let y = pallas(5);
        let after = step(MinRootState::new(x, y));
        let x_prime = after.x();
        let x5 = x_prime * x_prime * x_prime * x_prime * x_prime;
        assert_eq!(x5, x + y);
    }

    #[test]
    fn evaluate_matches_iterated_step() {
        let x = pallas(10);
        let y = pallas(20);
        let n = 3;
        let eval_result = evaluate(x, y, n);
        let step_result = (0..n).fold(MinRootState::new(x, y), |s, _| step(s));
        assert_eq!(
            eval_result.map(|r| (r.x(), r.y())),
            Ok((step_result.x(), step_result.y()))
        );
    }

    #[test]
    fn verify_accepts_correct_result() {
        let x = pallas(7);
        let y = pallas(11);
        let n = 2;
        // Unwrap explicitly: iterating an `Err` yields nothing, which would make an
        // `.iter().all(..)` check pass vacuously.
        let result = evaluate(x, y, n).expect("evaluate must succeed for n > 0");
        assert_eq!(verify(x, y, n, result.x(), result.y()), Ok(true));
    }

    #[test]
    fn verify_rejects_wrong_result() {
        let x = pallas(7);
        let y = pallas(11);
        let wrong = pallas(999);
        assert_eq!(verify(x, y, 2, wrong, wrong), Ok(false));
    }

    #[test]
    fn zero_iterations_is_error() {
        let x = pallas(1);
        let y = pallas(2);
        assert!(evaluate(x, y, 0).is_err());
        assert!(evaluate_trace(x, y, 0).is_err());
        assert!(verify(x, y, 0, x, y).is_err());
    }

    #[test]
    fn trace_has_correct_length() {
        let x = pallas(1);
        let y = pallas(2);
        let trace = evaluate_trace(x, y, 3).expect("evaluate_trace must succeed for n > 0");
        assert_eq!(trace.len(), 4);
    }

    #[test]
    fn trace_is_consistent_with_step_and_evaluate() {
        let x = pallas(1);
        let y = pallas(2);
        let n = 3;
        let trace = evaluate_trace(x, y, n).expect("evaluate_trace must succeed for n > 0");
        assert_eq!(trace[0], MinRootState::new(x, y));
        assert!(trace.iter().zip(0u64..).all(|(s, i)| s.iteration() == i));
        assert!(trace.windows(2).all(|w| step(w[0]) == w[1]));
        assert_eq!(trace.last().copied(), evaluate(x, y, n).ok());
    }
}