// scuffle_http/backend/hyper/mod.rs

1//! Hyper backend.
2use std::fmt::Debug;
3use std::net::SocketAddr;
4
5use scuffle_context::ContextFutExt;
6#[cfg(feature = "tracing")]
7use tracing::Instrument;
8
9use crate::error::HttpError;
10use crate::service::{HttpService, HttpServiceFactory};
11
12mod handler;
13mod stream;
14mod utils;
15
16/// A backend that handles incoming HTTP connections using a hyper backend.
17///
18/// This is used internally by the [`HttpServer`](crate::server::HttpServer) but can be used directly if preferred.
19///
20/// Call [`run`](HyperBackend::run) to start the server.
21#[derive(Debug, Clone, bon::Builder)]
22pub struct HyperBackend<F> {
23    /// The [`scuffle_context::Context`] this server will live by.
24    #[builder(default = scuffle_context::Context::global())]
25    ctx: scuffle_context::Context,
26    /// The number of worker tasks to spawn for each server backend.
27    #[builder(default = 1)]
28    worker_tasks: usize,
29    /// The service factory that will be used to create new services.
30    service_factory: F,
31    /// The address to bind to.
32    ///
33    /// Use `[::]` for a dual-stack listener.
34    /// For example, use `[::]:80` to bind to port 80 on both IPv4 and IPv6.
35    bind: SocketAddr,
36    /// rustls config.
37    ///
38    /// Use this field to set the server into TLS mode.
39    /// It will only accept TLS connections when this is set.
40    #[cfg(feature = "tls-rustls")]
41    rustls_config: Option<rustls::ServerConfig>,
42    /// Enable HTTP/1.1.
43    #[cfg(feature = "http1")]
44    #[builder(default = true)]
45    http1_enabled: bool,
46    /// Enable HTTP/2.
47    #[cfg(feature = "http2")]
48    #[builder(default = true)]
49    http2_enabled: bool,
50}
51
52impl<F> HyperBackend<F>
53where
54    F: HttpServiceFactory + Clone + Send + 'static,
55    F::Error: std::error::Error + Send,
56    F::Service: Clone + Send + 'static,
57    <F::Service as HttpService>::Error: std::error::Error + Send + Sync,
58    <F::Service as HttpService>::ResBody: Send,
59    <<F::Service as HttpService>::ResBody as http_body::Body>::Data: Send,
60    <<F::Service as HttpService>::ResBody as http_body::Body>::Error: std::error::Error + Send + Sync,
61{
62    /// Run the HTTP server
63    ///
64    /// This function will bind to the address specified in `bind`, listen for incoming connections and handle requests.
65    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(bind = %self.bind)))]
66    #[allow(unused_mut)] // allow the unused `mut self`
67    pub async fn run(mut self) -> Result<(), HttpError<F>> {
68        #[cfg(feature = "tracing")]
69        tracing::debug!("starting server");
70
71        // reset to 0 because everything explodes if it's not
72        // https://github.com/hyperium/hyper/issues/3841
73        #[cfg(feature = "tls-rustls")]
74        if let Some(rustls_config) = self.rustls_config.as_mut() {
75            rustls_config.max_early_data_size = 0;
76        }
77
78        // We have to create an std listener first because the tokio listener isn't clonable
79        let listener = tokio::net::TcpListener::bind(self.bind).await?.into_std()?;
80
81        #[cfg(feature = "tls-rustls")]
82        let tls_acceptor = self
83            .rustls_config
84            .map(|c| tokio_rustls::TlsAcceptor::from(std::sync::Arc::new(c)));
85
86        // Create a child context for the workers so we can shut them down if one of them fails without shutting down the main context
87        let (worker_ctx, worker_handler) = self.ctx.new_child();
88
89        let workers = (0..self.worker_tasks)
90            .map(|_n| {
91                let service_factory = self.service_factory.clone();
92                let ctx = worker_ctx.clone();
93                let std_listener = listener.try_clone()?;
94                let listener = tokio::net::TcpListener::from_std(std_listener)?;
95                #[cfg(feature = "tls-rustls")]
96                let tls_acceptor = tls_acceptor.clone();
97
98                let worker_fut = async move {
99                    loop {
100                        #[cfg(feature = "tracing")]
101                        tracing::trace!("waiting for connections");
102
103                        let (mut stream, addr) = match listener.accept().with_context(ctx.clone()).await {
104                            Some(Ok((tcp_stream, addr))) => (stream::Stream::Tcp(tcp_stream), addr),
105                            Some(Err(e)) if utils::is_fatal_tcp_error(&e) => {
106                                #[cfg(feature = "tracing")]
107                                tracing::error!(err = %e, "failed to accept tcp connection");
108                                return Err(HttpError::<F>::from(e));
109                            }
110                            Some(Err(_)) => continue,
111                            None => {
112                                #[cfg(feature = "tracing")]
113                                tracing::trace!("context done, stopping listener");
114                                break;
115                            }
116                        };
117
118                        #[cfg(feature = "tracing")]
119                        tracing::trace!(addr = %addr, "accepted tcp connection");
120
121                        let ctx = ctx.clone();
122                        #[cfg(feature = "tls-rustls")]
123                        let tls_acceptor = tls_acceptor.clone();
124                        let mut service_factory = service_factory.clone();
125
126                        let connection_fut = async move {
127                            // Perform the TLS handshake if the acceptor is set
128                            #[cfg(feature = "tls-rustls")]
129                            if let Some(tls_acceptor) = tls_acceptor {
130                                #[cfg(feature = "tracing")]
131                                tracing::trace!("accepting tls connection");
132
133                                stream = match stream.try_accept_tls(&tls_acceptor).with_context(&ctx).await {
134                                    Some(Ok(stream)) => stream,
135                                    Some(Err(_err)) => {
136                                        #[cfg(feature = "tracing")]
137                                        tracing::warn!(err = %_err, "failed to accept tls connection");
138                                        return;
139                                    }
140                                    None => {
141                                        #[cfg(feature = "tracing")]
142                                        tracing::trace!("context done, stopping tls acceptor");
143                                        return;
144                                    }
145                                };
146
147                                #[cfg(feature = "tracing")]
148                                tracing::trace!("accepted tls connection");
149                            }
150
151                            // make a new service
152                            let http_service = match service_factory.new_service(addr).await {
153                                Ok(service) => service,
154                                Err(_e) => {
155                                    #[cfg(feature = "tracing")]
156                                    tracing::warn!(err = %_e, "failed to create service");
157                                    return;
158                                }
159                            };
160
161                            #[cfg(feature = "tracing")]
162                            tracing::trace!("handling connection");
163
164                            #[cfg(feature = "http1")]
165                            let http1 = self.http1_enabled;
166                            #[cfg(not(feature = "http1"))]
167                            let http1 = false;
168
169                            #[cfg(feature = "http2")]
170                            let http2 = self.http2_enabled;
171                            #[cfg(not(feature = "http2"))]
172                            let http2 = false;
173
174                            let _res = handler::handle_connection::<F, _, _>(ctx, http_service, stream, http1, http2).await;
175
176                            #[cfg(feature = "tracing")]
177                            if let Err(e) = _res {
178                                tracing::warn!(err = %e, "error handling connection");
179                            }
180
181                            #[cfg(feature = "tracing")]
182                            tracing::trace!("connection closed");
183                        };
184
185                        #[cfg(feature = "tracing")]
186                        let connection_fut = connection_fut.instrument(tracing::trace_span!("connection", addr = %addr));
187
188                        tokio::spawn(connection_fut);
189                    }
190
191                    #[cfg(feature = "tracing")]
192                    tracing::trace!("listener closed");
193
194                    Ok(())
195                };
196
197                #[cfg(feature = "tracing")]
198                let worker_fut = worker_fut.instrument(tracing::trace_span!("worker", n = _n));
199
200                Ok(tokio::spawn(worker_fut))
201            })
202            .collect::<std::io::Result<Vec<_>>>()?;
203
204        match futures::future::try_join_all(workers).await {
205            Ok(res) => {
206                for r in res {
207                    if let Err(e) = r {
208                        drop(worker_ctx);
209                        worker_handler.shutdown().await;
210                        return Err(e);
211                    }
212                }
213            }
214            Err(_e) => {
215                #[cfg(feature = "tracing")]
216                tracing::error!(err = %_e, "error running workers");
217            }
218        }
219
220        drop(worker_ctx);
221        worker_handler.shutdown().await;
222
223        #[cfg(feature = "tracing")]
224        tracing::debug!("all workers finished");
225
226        Ok(())
227    }
228}