summaryrefslogtreecommitdiff
path: root/src/api/appservice_server.rs
blob: 082a1bc2599b9c3143e15717009f416e1904859d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
use crate::{services, utils, Error, Result};
use bytes::BytesMut;
use ruma::api::{IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken};
use std::{fmt::Debug, mem, time::Duration};
use tracing::warn;

#[tracing::instrument(skip(request))]
pub(crate) async fn send_request<T: OutgoingRequest>(
    registration: serde_yaml::Value,
    request: T,
) -> Result<T::IncomingResponse>
where
    T: Debug,
{
    let destination = registration.get("url").unwrap().as_str().unwrap();
    let hs_token = registration.get("hs_token").unwrap().as_str().unwrap();

    let mut http_request = request
        .try_into_http_request::<BytesMut>(
            destination,
            SendAccessToken::IfRequired(hs_token),
            &[MatrixVersion::V1_0],
        )
        .unwrap()
        .map(|body| body.freeze());

    let mut parts = http_request.uri().clone().into_parts();
    let old_path_and_query = parts.path_and_query.unwrap().as_str().to_owned();
    let symbol = if old_path_and_query.contains('?') {
        "&"
    } else {
        "?"
    };

    parts.path_and_query = Some(
        (old_path_and_query + symbol + "access_token=" + hs_token)
            .parse()
            .unwrap(),
    );
    *http_request.uri_mut() = parts.try_into().expect("our manipulation is always valid");

    let mut reqwest_request = reqwest::Request::try_from(http_request)
        .expect("all http requests are valid reqwest requests");

    *reqwest_request.timeout_mut() = Some(Duration::from_secs(30));

    let url = reqwest_request.url().clone();
    let mut response = match services()
        .globals
        .default_client()
        .execute(reqwest_request)
        .await
    {
        Ok(r) => r,
        Err(e) => {
            warn!(
                "Could not send request to appservice {:?} at {}: {}",
                registration.get("id"),
                destination,
                e
            );
            return Err(e.into());
        }
    };

    // reqwest::Response -> http::Response conversion
    let status = response.status();
    let mut http_response_builder = http::Response::builder()
        .status(status)
        .version(response.version());
    mem::swap(
        response.headers_mut(),
        http_response_builder
            .headers_mut()
            .expect("http::response::Builder is usable"),
    );

    let body = response.bytes().await.unwrap_or_else(|e| {
        warn!("server error: {}", e);
        Vec::new().into()
    }); // TODO: handle timeout

    if status != 200 {
        warn!(
            "Appservice returned bad response {} {}\n{}\n{:?}",
            destination,
            status,
            url,
            utils::string_from_bytes(&body)
        );
    }

    let response = T::IncomingResponse::try_from_http_response(
        http_response_builder
            .body(body)
            .expect("reqwest body is valid http body"),
    );
    response.map_err(|_| {
        warn!(
            "Appservice returned invalid response bytes {}\n{}",
            destination, url
        );
        Error::BadServerResponse("Server returned bad response.")
    })
}