// orinium_browser/platform/network/core.rs
use super::{HostKey, HttpSender, NetworkConfig, NetworkError, SenderPool};

use http_body_util::{BodyExt, Empty};
use hyper::{
    Method, Request, Uri,
    body::{Bytes, Incoming},
    client::conn,
    http::uri::Scheme,
};
use hyper_util::rt::TokioIo;
use rustls::{ClientConfig, RootCertStore};
use rustls_native_certs::load_native_certs;
use std::sync::Arc;
use tokio::{net::TcpStream, runtime::Runtime, task::LocalSet};
use tokio_rustls::TlsConnector;

20pub(super) struct AsyncNetworkCore {
21 local: LocalSet,
22 rt: Runtime,
23 inner: NetworkInner,
24}
25
26impl AsyncNetworkCore {
27 pub fn new() -> Self {
28 let rt = tokio::runtime::Builder::new_current_thread()
29 .enable_all()
30 .build()
31 .expect("failed to build tokio runtime");
32
33 let local = LocalSet::new();
34
35 Self {
36 rt,
37 local,
38 inner: NetworkInner::new(),
39 }
40 }
41
42 pub fn set_network_config(&mut self, config: NetworkConfig) {
43 self.inner.set_network_config(config)
44 }
45
46 pub fn fetch_blocking(&self, url: &str) -> Result<Response, NetworkError> {
48 self.local
50 .block_on(&self.rt, async { self.inner.fetch_url(url).await })
51 }
52}
53
54pub struct Response {
56 pub url: String,
57 pub status: hyper::StatusCode,
58 pub reason_phrase: String,
59 pub headers: Vec<(String, String)>,
60 pub body: Vec<u8>,
61}
62
63pub(super) struct NetworkInner {
64 sender_pool: Arc<std::sync::RwLock<SenderPool>>,
65 tls_config: Arc<ClientConfig>,
66 network_config: Arc<NetworkConfig>,
67}
68
69impl NetworkInner {
70 pub fn new() -> Self {
71 Self {
72 sender_pool: Arc::new(std::sync::RwLock::new(SenderPool::new())),
73 tls_config: Arc::new(Self::build_tls_config()),
74 network_config: Arc::new(NetworkConfig::default()),
75 }
76 }
77
78 pub fn set_network_config(&mut self, confing: NetworkConfig) {
79 self.network_config = Arc::new(confing)
80 }
81
82 fn build_tls_config() -> ClientConfig {
83 let mut roots = RootCertStore::empty();
84 let result = load_native_certs();
85
86 for cert in result.certs {
87 let _ = roots.add(cert);
88 }
89
90 ClientConfig::builder()
91 .with_root_certificates(roots)
92 .with_no_client_auth()
93 }
94
95 pub async fn fetch_url(&self, url: &str) -> Result<Response, NetworkError> {
96 let mut current: Uri = url.parse().map_err(|_| NetworkError::InvalidUri)?;
97 let mut redirects = 0usize;
98
99 loop {
100 let resp = self.send_request(¤t).await?;
101
102 if self.network_config.follow_redirects && resp.status.is_redirection() {
103 if redirects >= 10 {
104 return Err(NetworkError::TooManyRedirects);
105 }
106
107 if let Some(loc) = resp
108 .headers
109 .iter()
110 .find(|(k, _)| k.eq_ignore_ascii_case("location"))
111 .map(|(_, v)| v)
112 {
113 current = resolve_redirect(¤t, loc)?;
114 redirects += 1;
115 continue;
116 }
117 }
118
119 return Ok(resp);
120 }
121 }
122
123 async fn send_request(&self, uri: &Uri) -> Result<Response, NetworkError> {
124 let host = uri.host().ok_or(NetworkError::MissingHost)?;
125 let scheme = uri.scheme().unwrap_or(&Scheme::HTTP);
126 let port = uri
127 .port_u16()
128 .unwrap_or(if scheme == &Scheme::HTTPS { 443 } else { 80 });
129
130 let key = HostKey {
131 scheme: scheme.clone(),
132 host: host.to_string(),
133 port,
134 };
135
136 let mut sender = self.get_or_create_sender(&key).await?;
137
138 let req = Request::builder()
139 .method(Method::GET)
140 .uri(uri.path_and_query().map(|p| p.as_str()).unwrap_or("/"))
141 .header("Host", host)
142 .header("User-Agent", self.network_config.user_agent.as_str())
143 .body(Empty::<Bytes>::new())
144 .map_err(|_| NetworkError::HttpRequestFailed)?;
145
146 let mut res = match &mut sender {
147 HttpSender::Http1(s) => s
148 .send_request(req)
149 .await
150 .map_err(|_| NetworkError::HttpRequestFailed)?,
151 _ => {
152 return Err(NetworkError::UnsupportedHttpVersion);
153 }
154 };
155
156 let response = Self::collect_response(uri.to_string(), &mut res).await?;
157
158 self.sender_pool
159 .write()
160 .unwrap()
161 .add_connection(key, sender);
162
163 Ok(response)
164 }
165
166 async fn collect_response(
167 url: String,
168 res: &mut hyper::Response<Incoming>,
169 ) -> Result<Response, NetworkError> {
170 let status = res.status();
171 let reason_phrase = status.canonical_reason().unwrap_or("").to_string();
172
173 let headers = res
174 .headers()
175 .iter()
176 .map(|(k, v)| (k.as_str().to_string(), v.to_str().unwrap_or("").to_string()))
177 .collect();
178
179 let mut body = Vec::new();
180 while let Some(frame) = res.frame().await {
181 let frame = frame.map_err(|_| NetworkError::HttpResponseFailed)?;
182 if let Some(chunk) = frame.data_ref() {
183 body.extend_from_slice(chunk);
184 }
185 }
186
187 Ok(Response {
188 url,
189 status,
190 reason_phrase,
191 headers,
192 body,
193 })
194 }
195
196 async fn get_or_create_sender(&self, key: &HostKey) -> Result<HttpSender, NetworkError> {
197 if let Some(s) = self.sender_pool.write().unwrap().get_connection(key) {
198 return Ok(s);
199 }
200
201 self.create_connection(key).await
202 }
203
204 async fn create_connection(&self, key: &HostKey) -> Result<HttpSender, NetworkError> {
205 let addr = format!("{}:{}", key.host, key.port);
206 let stream = TcpStream::connect(addr)
207 .await
208 .map_err(|_| NetworkError::ConnectionFailed)?;
209
210 if key.scheme == Scheme::HTTPS {
211 let tls = TlsConnector::from(self.tls_config.clone());
212 let key = key.clone();
213 let domain = rustls::pki_types::ServerName::try_from(key.host.clone())
214 .map_err(|_| NetworkError::InvalidDnsName)?;
215
216 let stream = tls
217 .connect(domain, stream)
218 .await
219 .map_err(|_| NetworkError::TlsFailed)?;
220
221 let (sender, conn) = conn::http1::handshake(TokioIo::new(stream))
222 .await
223 .map_err(|_| NetworkError::HttpHandshakeFailed)?;
224
225 self.spawn_connection_task(conn, key);
226 Ok(HttpSender::Http1(sender))
227 } else {
228 let (sender, conn) = conn::http1::handshake(TokioIo::new(stream))
229 .await
230 .map_err(|_| NetworkError::HttpHandshakeFailed)?;
231
232 self.spawn_connection_task(conn, key.clone());
233 Ok(HttpSender::Http1(sender))
234 }
235 }
236
237 fn spawn_connection_task(
238 &self,
239 conn: conn::http1::Connection<
240 TokioIo<impl tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static>,
241 Empty<Bytes>,
242 >,
243 key: HostKey,
244 ) {
245 let pool = self.sender_pool.clone();
246 tokio::task::spawn_local(async move {
247 let _ = conn.await;
248 pool.write().unwrap().remove_connection(&key);
249 });
250 }
251}
252
253fn resolve_redirect(base: &Uri, location: &str) -> Result<Uri, NetworkError> {
254 if location.starts_with("http://") || location.starts_with("https://") {
255 return location.parse().map_err(|_| NetworkError::InvalidUri);
256 }
257
258 let scheme = base.scheme_str().unwrap_or("https");
259 let authority = base.authority().ok_or(NetworkError::InvalidUri)?;
260
261 let next = if location.starts_with("//") {
262 format!("{scheme}:{location}")
263 } else if location.starts_with('/') {
264 format!("{scheme}://{}{location}", authority)
265 } else {
266 let base_path = base.path();
267 let prefix = base_path.rsplit_once('/').map(|x| x.0).unwrap_or("");
268 format!("{scheme}://{}{prefix}/{location}", authority)
269 };
270
271 next.parse().map_err(|_| NetworkError::InvalidUri)
272}