rocket_http/listener.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248
use std::fmt;
use std::future::Future;
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use std::sync::Arc;
use log::warn;
use tokio::time::Sleep;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use hyper::server::accept::Accept;
use state::InitCell;
pub use tokio::net::TcpListener;
/// A thin wrapper over raw, DER-encoded X.509 client certificate data.
// NOTE: `rustls::Certificate` is exactly isomorphic to `CertificateData`.
#[doc(inline)]
#[cfg(feature = "tls")]
pub use rustls::Certificate as CertificateData;
/// A thin wrapper over raw, DER-encoded X.509 client certificate data.
///
/// This definition is only used when the `tls` feature is disabled; with the
/// feature enabled, `rustls::Certificate` is re-exported under this name
/// instead. `Debug` is derived so that formatting a `CertificateData` with
/// `{:?}` compiles regardless of which feature set is active.
#[cfg(not(feature = "tls"))]
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct CertificateData(pub Vec<u8>);
/// A shared collection of raw certificate data that can be set at most once.
///
/// The data lives behind an `Arc`, so clones of a `Certificates` value share
/// the same underlying storage.
#[derive(Clone, Default)]
pub struct Certificates(Arc<InitCell<Vec<CertificateData>>>);

impl From<Vec<CertificateData>> for Certificates {
    /// Wraps an already-known certificate chain.
    fn from(value: Vec<CertificateData>) -> Self {
        let cell: InitCell<Vec<CertificateData>> = value.into();
        Self(Arc::new(cell))
    }
}

impl Certificates {
    /// Sets the raw certificate chain data. Only the first call has any
    /// effect; subsequent calls leave the stored data unchanged.
    #[cfg(feature = "tls")]
    pub(crate) fn set(&self, data: Vec<CertificateData>) {
        let _ = self.0.set(data);
    }

    /// Returns the raw certificate chain data, or `None` if no chain has been
    /// set.
    pub fn chain_data(&self) -> Option<&[CertificateData]> {
        let chain = self.0.try_get()?;
        Some(chain.as_slice())
    }
}
// TODO.async: 'Listener' and 'Connection' provide common enough functionality
// that they could be introduced in upstream libraries.
/// A `Listener` yields incoming connections.
///
/// [`Incoming`] bridges any `Listener` to hyper's `Accept` trait.
pub trait Listener {
    /// The connection type returned by this listener.
    type Connection: Connection;

    /// Return the actual address this listener bound to, or `None` if it
    /// cannot be determined.
    fn local_addr(&self) -> Option<SocketAddr>;

    /// Try to accept an incoming connection if one is ready, returning
    /// `Poll::Pending` otherwise. This should only return an `Err` when a
    /// fatal problem occurs, as Hyper kills the server on `Err`.
    ///
    /// When driven by [`Incoming`], errors of kind `ConnectionRefused`,
    /// `ConnectionAborted`, and `ConnectionReset` are logged and skipped.
    /// Any other error either triggers a retry delay or is returned, depending
    /// on [`Incoming::sleep_on_errors()`].
    fn poll_accept(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>
    ) -> Poll<io::Result<Self::Connection>>;
}
/// A `Connection` represents an open connection to a client.
pub trait Connection: AsyncRead + AsyncWrite {
    /// The remote address, i.e. the client's socket address, if it is known.
    fn peer_address(&self) -> Option<SocketAddr>;

    /// Requests that the connection not delay reading or writing data as much
    /// as possible. For connections backed by TCP, this corresponds to setting
    /// `TCP_NODELAY`.
    fn enable_nodelay(&self) -> io::Result<()>;

    /// DER-encoded X.509 certificate chain presented by the client, if any.
    ///
    /// The certificate order must be as it appears in the TLS protocol: the
    /// first certificate relates to the peer, the second certifies the first,
    /// the third certifies the second, and so on.
    ///
    /// The default implementation returns `None`, indicating that no
    /// certificates were presented.
    fn peer_certificates(&self) -> Option<Certificates> { None }
}
pin_project_lite::pin_project! {
    /// This is a generic version of hyper's AddrIncoming that is intended to be
    /// usable with listeners other than a plain TCP stream, e.g. TLS and/or Unix
    /// sockets. It does so by bridging the `Listener` trait to what hyper wants
    /// (an `Accept`). This type is internal to Rocket.
    #[must_use = "streams do nothing unless polled"]
    pub struct Incoming<L> {
        // How long to wait after an accept error that is not specific to a
        // single connection before trying again; `None` means such errors
        // are returned to the caller instead.
        sleep_on_errors: Option<Duration>,
        // Whether to call `enable_nodelay()` on every accepted connection.
        nodelay: bool,
        // Timer armed after a recoverable accept error; accepting resumes
        // only once it has elapsed.
        #[pin]
        pending_error_delay: Option<Sleep>,
        // The wrapped listener that connections are accepted from.
        #[pin]
        listener: L,
    }
}
impl<L: Listener> Incoming<L> {
    /// Construct an `Incoming` from an existing `Listener`.
    ///
    /// By default, the returned value sleeps for 250 milliseconds after an
    /// accept error that is not specific to a single connection (see
    /// [`Incoming::sleep_on_errors()`]). It also does not request `NODELAY` on
    /// accepted connections (see [`Incoming::nodelay()`]).
    pub fn new(listener: L) -> Self {
        Self {
            listener,
            sleep_on_errors: Some(Duration::from_millis(250)),
            pending_error_delay: None,
            nodelay: false,
        }
    }

    /// Set whether and how long to sleep on accept errors.
    ///
    /// A possible scenario is that the process has hit the max open files
    /// allowed, and so trying to accept a new connection will fail with
    /// `EMFILE`. In some cases, it's preferable to just wait for some time, if
    /// the application will likely close some files (or connections), and try
    /// to accept the connection again. If this option is `Some(duration)`, the
    /// error will be logged at the `warn` level, since it is still a big deal.
    /// The listener will then sleep for `duration` before accepting again.
    ///
    /// In other cases, hitting the max open files should be treated similarly
    /// to being out-of-memory, and simply error (and shutdown). Setting this
    /// option to `None` will allow that: the error is returned to the caller.
    ///
    /// Errors that affect only a single connection (`ConnectionRefused`,
    /// `ConnectionAborted`, `ConnectionReset`) never cause a sleep. The next
    /// connection is accepted immediately regardless of this setting.
    ///
    /// Default is 250 milliseconds.
    pub fn sleep_on_errors(mut self, val: Option<Duration>) -> Self {
        self.sleep_on_errors = val;
        self
    }

    /// Set whether to request no delay on all incoming connections. The default
    /// is `false`. See [`Connection::enable_nodelay()`] for details.
    ///
    /// A failure to enable no-delay on a connection is logged as a warning.
    /// The connection is still returned.
    pub fn nodelay(mut self, nodelay: bool) -> Self {
        self.nodelay = nodelay;
        self
    }

    /// Polls the wrapped listener for the next connection, applying the
    /// `nodelay` and `sleep_on_errors` policies.
    ///
    /// Returns `Poll::Ready(Err(_))` only for an error that is not a
    /// per-connection error while `sleep_on_errors` is `None`.
    fn poll_accept_next(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>
    ) -> Poll<io::Result<L::Connection>> {
        /// This function defines per-connection errors: errors that affect only
        /// a single connection's accept() and don't imply anything about the
        /// success probability of the next accept(). Thus, we can attempt to
        /// `accept()` another connection immediately. All other errors will
        /// incur a delay before the next `accept()` is performed. The delay is
        /// useful to handle resource exhaustion errors like ENFILE and EMFILE.
        /// Otherwise, we could enter a tight loop.
        fn is_connection_error(e: &io::Error) -> bool {
            matches!(e.kind(),
                | io::ErrorKind::ConnectionRefused
                | io::ErrorKind::ConnectionAborted
                | io::ErrorKind::ConnectionReset)
        }

        let mut this = self.project();
        loop {
            // Check if a previous sleep timer is active, set on I/O errors.
            // `ready!` returns `Poll::Pending` until the timer has elapsed.
            if let Some(delay) = this.pending_error_delay.as_mut().as_pin_mut() {
                futures::ready!(delay.poll(cx));
            }

            // Either no timer was armed or it has elapsed; clear it.
            this.pending_error_delay.set(None);

            match futures::ready!(this.listener.as_mut().poll_accept(cx)) {
                Ok(stream) => {
                    if *this.nodelay {
                        if let Err(e) = stream.enable_nodelay() {
                            warn!("failed to enable NODELAY: {}", e);
                        }
                    }

                    return Poll::Ready(Ok(stream));
                },
                Err(e) => {
                    if is_connection_error(&e) {
                        warn!("single connection accept error {}; accepting next now", e);
                    } else if let Some(duration) = this.sleep_on_errors {
                        // We might be able to recover. Try again in a bit.
                        warn!("accept error {}; recovery attempt in {}ms", e, duration.as_millis());
                        // Arm the timer. The next loop iteration polls it, which
                        // yields `Pending` until the delay has passed.
                        this.pending_error_delay.set(Some(tokio::time::sleep(*duration)));
                    } else {
                        return Poll::Ready(Err(e));
                    }
                },
            }
        }
    }
}
impl<L: Listener> Accept for Incoming<L> {
    type Error = io::Error;
    type Conn = L::Connection;

    // Delegates to the inherent accept loop. That loop never signals
    // end-of-stream, so every ready result is wrapped in `Some`.
    #[inline]
    fn poll_accept(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>
    ) -> Poll<Option<io::Result<Self::Conn>>> {
        match self.poll_accept_next(cx) {
            Poll::Ready(outcome) => Poll::Ready(Some(outcome)),
            Poll::Pending => Poll::Pending,
        }
    }
}
impl<L: fmt::Debug> fmt::Debug for Incoming<L> {
    // Only the wrapped listener is shown; the error-handling configuration
    // and any pending retry timer are omitted from the output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let listener = &self.listener;
        let mut builder = f.debug_struct("Incoming");
        builder.field("listener", listener);
        builder.finish()
    }
}
impl Listener for TcpListener {
type Connection = TcpStream;
#[inline]
fn local_addr(&self) -> Option<SocketAddr> {
self.local_addr().ok()
}
#[inline]
fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>
) -> Poll<io::Result<Self::Connection>> {
(*self).poll_accept(cx).map_ok(|(stream, _addr)| stream)
}
}
impl Connection for TcpStream {
#[inline]
fn peer_address(&self) -> Option<SocketAddr> {
self.peer_addr().ok()
}
#[inline]
fn enable_nodelay(&self) -> io::Result<()> {
self.set_nodelay(true)
}
}