tinc/private/
oneof.rs

1use std::marker::PhantomData;
2
3use serde::de::{Unexpected, VariantAccess};
4
5use super::{
6    DeserializeContent, DeserializeHelper, Expected, IdentifiedValue, Identifier, IdentifierDeserializer, IdentifierFor,
7    MapAccessValueDeserializer, SerdeDeserializer, SerdePathToken, TrackedError, Tracker, TrackerDeserializer, TrackerFor,
8    TrackerWrapper, report_de_error, report_tracked_error, set_irrecoverable,
9};
10
/// Names the value type wrapped by a oneof container.
///
/// In this file the trait is implemented for `Option<T>`, whose target is `T`.
pub trait OneOfHelper {
    /// The wrapped oneof value type.
    type Target;
}
14
15impl<T> OneOfHelper for Option<T> {
16    type Target = T;
17}
18
19pub trait TaggedOneOfIdentifier: Identifier {
20    const TAG: Self;
21    const CONTENT: Self;
22}
23
24pub trait TrackerDeserializeIdentifier<'de>: Tracker
25where
26    Self::Target: IdentifierFor,
27{
28    fn deserialize<D>(
29        &mut self,
30        value: &mut Self::Target,
31        identifier: <Self::Target as IdentifierFor>::Identifier,
32        deserializer: D,
33    ) -> Result<(), D::Error>
34    where
35        D: DeserializeContent<'de>;
36}
37
38pub trait TrackedOneOfVariant {
39    type Variant: Identifier;
40}
41
42pub trait TrackedOneOfDeserializer<'de>: TrackerFor + IdentifierFor + TrackedOneOfVariant + Sized
43where
44    Self::Tracker: TrackerWrapper,
45{
46    const DENY_UNKNOWN_FIELDS: bool = false;
47
48    fn deserialize<D>(
49        value: &mut Option<Self>,
50        identifier: Self::Variant,
51        tracker: &mut Option<<Self::Tracker as TrackerWrapper>::Tracker>,
52        deserializer: D,
53    ) -> Result<(), D::Error>
54    where
55        D: DeserializeContent<'de>;
56
57    fn tracker_to_identifier(tracker: &<Self::Tracker as TrackerWrapper>::Tracker) -> Self::Variant;
58    fn value_to_identifier(value: &Self) -> Self::Variant;
59}
60
61impl<'de, T> serde::de::Visitor<'de> for DeserializeHelper<'_, TaggedOneOfTracker<T>>
62where
63    T: Tracker,
64    T::Target: TrackedOneOfDeserializer<'de, Tracker = TaggedOneOfTracker<T>>,
65    T::Target: IdentifierFor,
66    <T::Target as IdentifierFor>::Identifier: TaggedOneOfIdentifier,
67{
68    type Value = ();
69
70    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
71        <T::Target as Expected>::expecting(formatter)
72    }
73
74    fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
75    where
76        A: serde::de::MapAccess<'de>,
77    {
78        while let Some(key) = map
79            .next_key_seed(IdentifierDeserializer::<<T::Target as IdentifierFor>::Identifier>::new())
80            .inspect_err(|_| {
81                set_irrecoverable();
82            })?
83        {
84            let _token = SerdePathToken::push_field(match &key {
85                IdentifiedValue::Found(tag) => tag.name(),
86                IdentifiedValue::Unknown(v) => v.as_ref(),
87            });
88
89            let mut deserialized = false;
90
91            match &key {
92                IdentifiedValue::Found(tag) => {
93                    TrackerDeserializeIdentifier::deserialize(
94                        self.tracker,
95                        self.value,
96                        *tag,
97                        MapAccessValueDeserializer {
98                            map: &mut map,
99                            deserialized: &mut deserialized,
100                        },
101                    )?;
102                }
103                IdentifiedValue::Unknown(_) => {
104                    report_tracked_error(TrackedError::unknown_field(T::Target::DENY_UNKNOWN_FIELDS))?;
105                }
106            }
107
108            if !deserialized {
109                map.next_value::<serde::de::IgnoredAny>().inspect_err(|_| {
110                    set_irrecoverable();
111                })?;
112            }
113        }
114
115        Ok(())
116    }
117}
118
119impl<'de, T> TrackerDeserializer<'de> for OneOfTracker<T>
120where
121    T: Tracker,
122    T::Target: TrackedOneOfDeserializer<'de, Tracker = OneOfTracker<T>, Variant = <T::Target as IdentifierFor>::Identifier>,
123{
124    fn deserialize<D>(&mut self, value: &mut Self::Target, deserializer: D) -> Result<(), D::Error>
125    where
126        D: DeserializeContent<'de>,
127    {
128        deserializer.deserialize_seed(DeserializeHelper { value, tracker: self })
129    }
130}
131
/// Zero-sized marker used to pick [`OneOfTracker`] as the tracker for `T`.
pub struct TrackerForOneOf<T>(PhantomData<T>);
133
134impl<T: TrackerFor> TrackerFor for TrackerForOneOf<T> {
135    type Tracker = OneOfTracker<T::Tracker>;
136}
137
/// State bit: a tag was seen that cannot be used (unknown name).
const TAGGED_ONE_OF_TRACKER_STATE_TAG_INVALID: u8 = 1 << 0;
/// State bit: the content key has been seen at least once.
const TAGGED_ONE_OF_TRACKER_STATE_HAS_CONTENT: u8 = 1 << 1;
140
141pub struct TaggedOneOfTracker<T>
142where
143    T: Tracker,
144    T::Target: TrackedOneOfVariant,
145{
146    tracker: Option<T>,
147    state: u8,
148    tag_buffer: Option<<T::Target as TrackedOneOfVariant>::Variant>,
149    content_buffer: Vec<serde_json::Value>,
150}
151
152impl<T: Tracker> TrackerWrapper for TaggedOneOfTracker<T>
153where
154    T::Target: TrackedOneOfVariant,
155{
156    type Tracker = T;
157}
158
159impl<'de, T> TrackerDeserializeIdentifier<'de> for TaggedOneOfTracker<T>
160where
161    T: Tracker,
162    T::Target: TrackedOneOfVariant + IdentifierFor,
163    <T::Target as IdentifierFor>::Identifier: TaggedOneOfIdentifier,
164    T::Target: TrackedOneOfDeserializer<'de, Tracker = TaggedOneOfTracker<T>>,
165{
166    fn deserialize<D>(
167        &mut self,
168        value: &mut Self::Target,
169        identifier: <Self::Target as IdentifierFor>::Identifier,
170        deserializer: D,
171    ) -> Result<(), D::Error>
172    where
173        D: DeserializeContent<'de>,
174    {
175        if identifier == <T::Target as IdentifierFor>::Identifier::TAG {
176            let tag = deserializer.deserialize_seed(IdentifierDeserializer::new())?;
177            match (tag, self.tag_buffer) {
178                (IdentifiedValue::Found(tag), _) if !self.tag_invalid() => {
179                    if let Some(existing_tag) = self.tag_buffer {
180                        if existing_tag != tag {
181                            let error = <D::Error as serde::de::Error>::invalid_value(
182                                Unexpected::Str(tag.name()),
183                                &existing_tag.name(),
184                            );
185                            report_de_error(error)?;
186                        }
187                    } else {
188                        self.tag_buffer = Some(tag);
189                    }
190
191                    let _token = SerdePathToken::replace_field(<T::Target as IdentifierFor>::Identifier::CONTENT.name());
192                    for content in self.content_buffer.drain(..) {
193                        let result: Result<(), D::Error> = T::Target::deserialize(
194                            value,
195                            tag,
196                            &mut self.tracker,
197                            SerdeDeserializer {
198                                deserializer: serde::de::IntoDeserializer::into_deserializer(content),
199                            },
200                        )
201                        .map_err(serde::de::Error::custom);
202
203                        if let Err(e) = result {
204                            report_de_error(e)?;
205                        }
206                    }
207                }
208                (IdentifiedValue::Unknown(v), None) => {
209                    self.set_tag_invalid();
210                    let error = <D::Error as serde::de::Error>::unknown_variant(
211                        v.as_ref(),
212                        <T::Target as TrackedOneOfVariant>::Variant::OPTIONS,
213                    );
214                    report_de_error(error)?;
215                }
216                (IdentifiedValue::Unknown(v), Some(tag)) => {
217                    self.set_tag_invalid();
218                    let error = <D::Error as serde::de::Error>::invalid_value(Unexpected::Str(v.as_ref()), &tag.name());
219                    report_de_error(error)?;
220                }
221                _ => {}
222            }
223        } else if identifier == <T::Target as IdentifierFor>::Identifier::CONTENT {
224            self.set_has_content();
225            if !self.tag_invalid() {
226                if let Some(tag) = self.tag_buffer {
227                    let result: Result<(), D::Error> = T::Target::deserialize(value, tag, &mut self.tracker, deserializer);
228                    if let Err(e) = result {
229                        report_de_error(e)?;
230                    }
231                } else {
232                    self.content_buffer
233                        .push(deserializer.deserialize::<serde_json::Value>().inspect_err(|_| {
234                            set_irrecoverable();
235                        })?);
236                }
237            }
238        } else {
239            report_tracked_error(TrackedError::unknown_field(T::Target::DENY_UNKNOWN_FIELDS))?;
240        }
241
242        Ok(())
243    }
244}
245
246impl<'de, T> TrackerDeserializer<'de> for TaggedOneOfTracker<T>
247where
248    T: Tracker,
249    T::Target: TrackedOneOfDeserializer<'de, Tracker = TaggedOneOfTracker<T>>,
250    <T::Target as IdentifierFor>::Identifier: TaggedOneOfIdentifier,
251{
252    fn deserialize<D>(&mut self, value: &mut Self::Target, deserializer: D) -> Result<(), D::Error>
253    where
254        D: DeserializeContent<'de>,
255    {
256        deserializer.deserialize_seed(DeserializeHelper { value, tracker: self })
257    }
258}
259
260impl<T> std::ops::Deref for TaggedOneOfTracker<T>
261where
262    T: Tracker,
263    T::Target: TrackedOneOfVariant,
264{
265    type Target = Option<T>;
266
267    fn deref(&self) -> &Self::Target {
268        &self.tracker
269    }
270}
271
272impl<T> std::ops::DerefMut for TaggedOneOfTracker<T>
273where
274    T: Tracker,
275    T::Target: TrackedOneOfVariant,
276{
277    fn deref_mut(&mut self) -> &mut Self::Target {
278        &mut self.tracker
279    }
280}
281
282impl<T> std::fmt::Debug for TaggedOneOfTracker<T>
283where
284    T: Tracker + std::fmt::Debug,
285    T::Target: TrackedOneOfVariant,
286{
287    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
288        f.debug_struct("TaggedOneOfTracker")
289            .field("tracker", &self.tracker)
290            .field("state", &self.state)
291            .field("tag_buffer", &self.tag_buffer.map(|t| t.name()))
292            .field("value_buffer", &self.content_buffer)
293            .finish()
294    }
295}
296
297impl<T> Default for TaggedOneOfTracker<T>
298where
299    T: Tracker,
300    T::Target: TrackedOneOfVariant,
301{
302    fn default() -> Self {
303        Self {
304            tracker: None,
305            state: 0,
306            tag_buffer: None,
307            content_buffer: Vec::new(),
308        }
309    }
310}
311
312impl<T> TaggedOneOfTracker<T>
313where
314    T: Tracker,
315    T::Target: TrackedOneOfVariant,
316{
317    pub fn tag_invalid(&self) -> bool {
318        self.state & TAGGED_ONE_OF_TRACKER_STATE_TAG_INVALID != 0
319    }
320
321    pub fn set_tag_invalid(&mut self) {
322        self.state |= TAGGED_ONE_OF_TRACKER_STATE_TAG_INVALID;
323    }
324
325    pub fn has_content(&self) -> bool {
326        self.state & TAGGED_ONE_OF_TRACKER_STATE_HAS_CONTENT != 0
327    }
328
329    pub fn set_has_content(&mut self) {
330        self.state |= TAGGED_ONE_OF_TRACKER_STATE_HAS_CONTENT;
331    }
332}
333
334impl<T> Tracker for TaggedOneOfTracker<T>
335where
336    T: Tracker,
337    T::Target: TrackedOneOfVariant,
338{
339    type Target = Option<T::Target>;
340
341    fn allow_duplicates(&self) -> bool {
342        self.tracker.as_ref().is_none_or(|t| t.allow_duplicates())
343    }
344}
345
346impl<'de, T> serde::de::DeserializeSeed<'de> for DeserializeHelper<'_, TaggedOneOfTracker<T>>
347where
348    T: Tracker,
349    T::Target: TrackedOneOfDeserializer<'de, Tracker = TaggedOneOfTracker<T>>,
350    <T::Target as IdentifierFor>::Identifier: TaggedOneOfIdentifier,
351{
352    type Value = ();
353
354    fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
355    where
356        D: serde::Deserializer<'de>,
357    {
358        deserializer.deserialize_struct(T::Target::NAME, <T::Target as IdentifierFor>::Identifier::OPTIONS, self)
359    }
360}
361
/// Tracker for an enum-style oneof: holds the variant tracker once one exists.
#[derive(Debug)]
pub struct OneOfTracker<T>(
    /// Tracker of the selected variant, if any.
    pub Option<T>,
);
364
365impl<T: Tracker> TrackerWrapper for OneOfTracker<T> {
366    type Tracker = T;
367}
368
369impl<T> std::ops::Deref for OneOfTracker<T> {
370    type Target = Option<T>;
371
372    fn deref(&self) -> &Self::Target {
373        &self.0
374    }
375}
376
377impl<T> std::ops::DerefMut for OneOfTracker<T> {
378    fn deref_mut(&mut self) -> &mut Self::Target {
379        &mut self.0
380    }
381}
382
383impl<T> Default for OneOfTracker<T> {
384    fn default() -> Self {
385        Self(None)
386    }
387}
388
389impl<T: Tracker> Tracker for OneOfTracker<T> {
390    type Target = Option<T::Target>;
391
392    fn allow_duplicates(&self) -> bool {
393        self.0.as_ref().is_none_or(|value| value.allow_duplicates())
394    }
395}
396
397impl<'de, T> serde::de::DeserializeSeed<'de> for DeserializeHelper<'_, OneOfTracker<T>>
398where
399    T: Tracker,
400    T::Target: TrackedOneOfDeserializer<'de, Tracker = OneOfTracker<T>, Variant = <T::Target as IdentifierFor>::Identifier>,
401{
402    type Value = ();
403
404    fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
405    where
406        D: serde::Deserializer<'de>,
407    {
408        deserializer.deserialize_enum(T::Target::NAME, <T::Target as IdentifierFor>::Identifier::OPTIONS, self)
409    }
410}
411
412impl<'de, T> serde::de::Visitor<'de> for DeserializeHelper<'_, OneOfTracker<T>>
413where
414    T: Tracker,
415    T::Target: TrackedOneOfDeserializer<'de, Tracker = OneOfTracker<T>, Variant = <T::Target as IdentifierFor>::Identifier>,
416{
417    type Value = ();
418
419    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
420        write!(formatter, "one of")
421    }
422
423    fn visit_enum<A>(self, data: A) -> Result<Self::Value, A::Error>
424    where
425        A: serde::de::EnumAccess<'de>,
426    {
427        let (variant, variant_access) =
428            data.variant_seed(IdentifierDeserializer::<<T::Target as IdentifierFor>::Identifier>::new())?;
429        match variant {
430            IdentifiedValue::Found(variant) => {
431                let _token = SerdePathToken::push_field(variant.name());
432                TrackerDeserializeIdentifier::deserialize(
433                    self.tracker,
434                    self.value,
435                    variant,
436                    VariantAccessDeserializer { de: variant_access },
437                )
438            }
439            IdentifiedValue::Unknown(variant) => {
440                let error = <A::Error as serde::de::Error>::unknown_variant(
441                    variant.as_ref(),
442                    <T::Target as IdentifierFor>::Identifier::OPTIONS,
443                );
444                report_de_error(error)?;
445                variant_access.newtype_variant::<serde::de::IgnoredAny>().inspect_err(|_| {
446                    set_irrecoverable();
447                })?;
448                Ok(())
449            }
450        }
451    }
452}
453
454impl<'de, T> TrackerDeserializeIdentifier<'de> for OneOfTracker<T>
455where
456    T: Tracker,
457    T::Target: TrackedOneOfDeserializer<'de, Tracker = OneOfTracker<T>, Variant = <T::Target as IdentifierFor>::Identifier>,
458{
459    fn deserialize<D>(
460        &mut self,
461        value: &mut Self::Target,
462        identifier: <Self::Target as IdentifierFor>::Identifier,
463        deserializer: D,
464    ) -> Result<(), D::Error>
465    where
466        D: DeserializeContent<'de>,
467    {
468        T::Target::deserialize(value, identifier, self, deserializer)
469    }
470}
471
/// Adapter that wraps a serde variant accessor so it can be used as a
/// `DeserializeContent` source.
struct VariantAccessDeserializer<D> {
    /// The wrapped variant accessor.
    de: D,
}
475
476impl<'de, D> DeserializeContent<'de> for VariantAccessDeserializer<D>
477where
478    D: serde::de::VariantAccess<'de>,
479{
480    type Error = D::Error;
481
482    fn deserialize_seed<T>(self, seed: T) -> Result<T::Value, Self::Error>
483    where
484        T: serde::de::DeserializeSeed<'de>,
485    {
486        self.de.newtype_variant_seed(seed)
487    }
488}