orinium_browser/platform/network/
core.rs

1//! ネットワークコア
2//! HTTP通信とレスポンス処理を担当する。
3
4use super::{HostKey, HttpSender, NetworkConfig, NetworkError, SenderPool};
5
6use http_body_util::{BodyExt, Empty};
7use hyper::{
8    Method, Request, Uri,
9    body::{Bytes, Incoming},
10    client::conn,
11    http::uri::Scheme,
12};
13use hyper_util::rt::TokioIo;
14use rustls::{ClientConfig, RootCertStore};
15use rustls_native_certs::load_native_certs;
16use std::sync::Arc;
17use tokio::{net::TcpStream, runtime::Runtime, task::LocalSet};
18use tokio_rustls::TlsConnector;
19
20pub(super) struct AsyncNetworkCore {
21    local: LocalSet,
22    rt: Runtime,
23    inner: NetworkInner,
24}
25
26impl AsyncNetworkCore {
27    pub fn new() -> Self {
28        let rt = tokio::runtime::Builder::new_current_thread()
29            .enable_all()
30            .build()
31            .expect("failed to build tokio runtime");
32
33        let local = LocalSet::new();
34
35        Self {
36            rt,
37            local,
38            inner: NetworkInner::new(),
39        }
40    }
41
42    pub fn set_network_config(&mut self, config: NetworkConfig) {
43        self.inner.set_network_config(config)
44    }
45
46    /// UI スレッドなどから呼ばれる blocking API
47    pub fn fetch_blocking(&self, url: &str) -> Result<Response, NetworkError> {
48        // network スレッド内で完結させる
49        self.local
50            .block_on(&self.rt, async { self.inner.fetch_url(url).await })
51    }
52}
53
54/// HTTP response
55pub struct Response {
56    pub url: String,
57    pub status: hyper::StatusCode,
58    pub reason_phrase: String,
59    pub headers: Vec<(String, String)>,
60    pub body: Vec<u8>,
61}
62
63pub(super) struct NetworkInner {
64    sender_pool: Arc<std::sync::RwLock<SenderPool>>,
65    tls_config: Arc<ClientConfig>,
66    network_config: Arc<NetworkConfig>,
67}
68
69impl NetworkInner {
70    pub fn new() -> Self {
71        Self {
72            sender_pool: Arc::new(std::sync::RwLock::new(SenderPool::new())),
73            tls_config: Arc::new(Self::build_tls_config()),
74            network_config: Arc::new(NetworkConfig::default()),
75        }
76    }
77
78    pub fn set_network_config(&mut self, confing: NetworkConfig) {
79        self.network_config = Arc::new(confing)
80    }
81
82    fn build_tls_config() -> ClientConfig {
83        let mut roots = RootCertStore::empty();
84        let result = load_native_certs();
85
86        for cert in result.certs {
87            let _ = roots.add(cert);
88        }
89
90        ClientConfig::builder()
91            .with_root_certificates(roots)
92            .with_no_client_auth()
93    }
94
95    pub async fn fetch_url(&self, url: &str) -> Result<Response, NetworkError> {
96        let mut current: Uri = url.parse().map_err(|_| NetworkError::InvalidUri)?;
97        let mut redirects = 0usize;
98
99        loop {
100            let resp = self.send_request(&current).await?;
101
102            if self.network_config.follow_redirects && resp.status.is_redirection() {
103                if redirects >= 10 {
104                    return Err(NetworkError::TooManyRedirects);
105                }
106
107                if let Some(loc) = resp
108                    .headers
109                    .iter()
110                    .find(|(k, _)| k.eq_ignore_ascii_case("location"))
111                    .map(|(_, v)| v)
112                {
113                    current = resolve_redirect(&current, loc)?;
114                    redirects += 1;
115                    continue;
116                }
117            }
118
119            return Ok(resp);
120        }
121    }
122
123    async fn send_request(&self, uri: &Uri) -> Result<Response, NetworkError> {
124        let host = uri.host().ok_or(NetworkError::MissingHost)?;
125        let scheme = uri.scheme().unwrap_or(&Scheme::HTTP);
126        let port = uri
127            .port_u16()
128            .unwrap_or(if scheme == &Scheme::HTTPS { 443 } else { 80 });
129
130        let key = HostKey {
131            scheme: scheme.clone(),
132            host: host.to_string(),
133            port,
134        };
135
136        let mut sender = self.get_or_create_sender(&key).await?;
137
138        let req = Request::builder()
139            .method(Method::GET)
140            .uri(uri.path_and_query().map(|p| p.as_str()).unwrap_or("/"))
141            .header("Host", host)
142            .header("User-Agent", self.network_config.user_agent.as_str())
143            .body(Empty::<Bytes>::new())
144            .map_err(|_| NetworkError::HttpRequestFailed)?;
145
146        let mut res = match &mut sender {
147            HttpSender::Http1(s) => s
148                .send_request(req)
149                .await
150                .map_err(|_| NetworkError::HttpRequestFailed)?,
151            _ => {
152                return Err(NetworkError::UnsupportedHttpVersion);
153            }
154        };
155
156        let response = Self::collect_response(uri.to_string(), &mut res).await?;
157
158        self.sender_pool
159            .write()
160            .unwrap()
161            .add_connection(key, sender);
162
163        Ok(response)
164    }
165
166    async fn collect_response(
167        url: String,
168        res: &mut hyper::Response<Incoming>,
169    ) -> Result<Response, NetworkError> {
170        let status = res.status();
171        let reason_phrase = status.canonical_reason().unwrap_or("").to_string();
172
173        let headers = res
174            .headers()
175            .iter()
176            .map(|(k, v)| (k.as_str().to_string(), v.to_str().unwrap_or("").to_string()))
177            .collect();
178
179        let mut body = Vec::new();
180        while let Some(frame) = res.frame().await {
181            let frame = frame.map_err(|_| NetworkError::HttpResponseFailed)?;
182            if let Some(chunk) = frame.data_ref() {
183                body.extend_from_slice(chunk);
184            }
185        }
186
187        Ok(Response {
188            url,
189            status,
190            reason_phrase,
191            headers,
192            body,
193        })
194    }
195
196    async fn get_or_create_sender(&self, key: &HostKey) -> Result<HttpSender, NetworkError> {
197        if let Some(s) = self.sender_pool.write().unwrap().get_connection(key) {
198            return Ok(s);
199        }
200
201        self.create_connection(key).await
202    }
203
204    async fn create_connection(&self, key: &HostKey) -> Result<HttpSender, NetworkError> {
205        let addr = format!("{}:{}", key.host, key.port);
206        let stream = TcpStream::connect(addr)
207            .await
208            .map_err(|_| NetworkError::ConnectionFailed)?;
209
210        if key.scheme == Scheme::HTTPS {
211            let tls = TlsConnector::from(self.tls_config.clone());
212            let key = key.clone();
213            let domain = rustls::pki_types::ServerName::try_from(key.host.clone())
214                .map_err(|_| NetworkError::InvalidDnsName)?;
215
216            let stream = tls
217                .connect(domain, stream)
218                .await
219                .map_err(|_| NetworkError::TlsFailed)?;
220
221            let (sender, conn) = conn::http1::handshake(TokioIo::new(stream))
222                .await
223                .map_err(|_| NetworkError::HttpHandshakeFailed)?;
224
225            self.spawn_connection_task(conn, key);
226            Ok(HttpSender::Http1(sender))
227        } else {
228            let (sender, conn) = conn::http1::handshake(TokioIo::new(stream))
229                .await
230                .map_err(|_| NetworkError::HttpHandshakeFailed)?;
231
232            self.spawn_connection_task(conn, key.clone());
233            Ok(HttpSender::Http1(sender))
234        }
235    }
236
237    fn spawn_connection_task(
238        &self,
239        conn: conn::http1::Connection<
240            TokioIo<impl tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + 'static>,
241            Empty<Bytes>,
242        >,
243        key: HostKey,
244    ) {
245        let pool = self.sender_pool.clone();
246        tokio::task::spawn_local(async move {
247            let _ = conn.await;
248            pool.write().unwrap().remove_connection(&key);
249        });
250    }
251}
252
253fn resolve_redirect(base: &Uri, location: &str) -> Result<Uri, NetworkError> {
254    if location.starts_with("http://") || location.starts_with("https://") {
255        return location.parse().map_err(|_| NetworkError::InvalidUri);
256    }
257
258    let scheme = base.scheme_str().unwrap_or("https");
259    let authority = base.authority().ok_or(NetworkError::InvalidUri)?;
260
261    let next = if location.starts_with("//") {
262        format!("{scheme}:{location}")
263    } else if location.starts_with('/') {
264        format!("{scheme}://{}{location}", authority)
265    } else {
266        let base_path = base.path();
267        let prefix = base_path.rsplit_once('/').map(|x| x.0).unwrap_or("");
268        format!("{scheme}://{}{prefix}/{location}", authority)
269    };
270
271    next.parse().map_err(|_| NetworkError::InvalidUri)
272}