scuffle_http/backend/h3/
body.rs1use std::future::Future;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use bytes::{Buf, Bytes};
7use h3::server::RequestStream;
8
9#[derive(thiserror::Error, Debug)]
11pub enum H3BodyError {
12 #[error("h3 error: {0}")]
16 StreamError(#[from] h3::error::StreamError),
17 #[error("unexpected data after trailers")]
19 DataAfterTrailers,
20 #[error("the given buffer size hint was exceeded")]
22 BufferExceeded,
23}
24
/// Reading phase of a [`QuicIncomingBody`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum State {
    /// Receiving data chunks. Carries how many more data bytes the size hint
    /// still permits, or `None` when no hint was supplied.
    Data(Option<u64>),
    /// The data section has ended; trailers have not been read yet.
    Trailers,
    /// The body has finished (normally or after an error); no further frames.
    Done,
}
31
32pub struct QuicIncomingBody<S> {
36 stream: RequestStream<S, Bytes>,
37 state: State,
38}
39
40impl<S> QuicIncomingBody<S> {
41 pub fn new(stream: RequestStream<S, Bytes>, size_hint: Option<u64>) -> Self {
43 Self {
44 stream,
45 state: State::Data(size_hint),
46 }
47 }
48}
49
50impl<S: h3::quic::RecvStream> http_body::Body for QuicIncomingBody<S> {
51 type Data = Bytes;
52 type Error = H3BodyError;
53
54 fn poll_frame(
55 mut self: Pin<&mut Self>,
56 cx: &mut Context<'_>,
57 ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
58 let QuicIncomingBody { stream, state } = self.as_mut().get_mut();
59
60 if *state == State::Done {
61 return Poll::Ready(None);
62 }
63
64 if let State::Data(remaining) = state {
65 match stream.poll_recv_data(cx) {
66 Poll::Ready(Ok(Some(mut buf))) => {
67 let buf_size = buf.remaining() as u64;
68
69 if let Some(remaining) = remaining {
70 if buf_size > *remaining {
71 *state = State::Done;
72 return Poll::Ready(Some(Err(H3BodyError::BufferExceeded)));
73 }
74
75 *remaining -= buf_size;
76 }
77
78 return Poll::Ready(Some(Ok(http_body::Frame::data(buf.copy_to_bytes(buf_size as usize)))));
79 }
80 Poll::Ready(Ok(None)) => {
81 *state = State::Trailers;
82 }
83 Poll::Ready(Err(err)) => {
84 *state = State::Done;
85 return Poll::Ready(Some(Err(err.into())));
86 }
87 Poll::Pending => {
88 return Poll::Pending;
89 }
90 }
91 }
92
93 let resp = match stream.poll_recv_data(cx) {
99 Poll::Ready(Ok(None)) => match std::pin::pin!(stream.recv_trailers()).poll(cx) {
100 Poll::Ready(Ok(Some(trailers))) => Poll::Ready(Some(Ok(http_body::Frame::trailers(trailers)))),
101 Poll::Pending => {
103 #[cfg(feature = "tracing")]
104 tracing::warn!("recv_trailers is pending");
105 Poll::Ready(None)
106 }
107 Poll::Ready(Ok(None)) => Poll::Ready(None),
108 Poll::Ready(Err(err)) => Poll::Ready(Some(Err(err.into()))),
109 },
110 Poll::Ready(Ok(Some(_))) => Poll::Ready(Some(Err(H3BodyError::DataAfterTrailers))),
112 Poll::Ready(Err(err)) => Poll::Ready(Some(Err(err.into()))),
113 Poll::Pending => return Poll::Pending,
114 };
115
116 *state = State::Done;
117
118 resp
119 }
120
121 fn size_hint(&self) -> http_body::SizeHint {
122 match self.state {
123 State::Data(Some(remaining)) => http_body::SizeHint::with_exact(remaining),
124 State::Data(None) => http_body::SizeHint::default(),
125 State::Trailers | State::Done => http_body::SizeHint::with_exact(0),
126 }
127 }
128
129 fn is_end_stream(&self) -> bool {
130 match self.state {
131 State::Data(Some(0)) | State::Trailers | State::Done => true,
132 State::Data(_) => false,
133 }
134 }
135}