summaryrefslogtreecommitdiff
path: root/src/api/client_server/search.rs
blob: 51255d5a121001f86d7bc2f410367f3cdf777a16 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
use crate::{services, Error, Result, Ruma};
use ruma::api::client::{
    error::ErrorKind,
    search::search_events::{
        self,
        v3::{EventContextResult, ResultCategories, ResultRoomEvents, SearchResult},
    },
};

use std::collections::BTreeMap;

/// # `POST /_matrix/client/r0/search`
///
/// Searches rooms for messages.
///
/// - Only works if the user is currently joined to the room (TODO: Respect history visibility)
pub async fn search_events_route(
    body: Ruma<search_events::v3::Request>,
) -> Result<search_events::v3::Response> {
    let sender_user = body.sender_user.as_ref().expect("user is authenticated");

    let search_criteria = body.search_categories.room_events.as_ref().unwrap();
    let filter = &search_criteria.filter;

    let room_ids = filter.rooms.clone().unwrap_or_else(|| {
        services()
            .rooms
            .state_cache
            .rooms_joined(sender_user)
            .filter_map(|r| r.ok())
            .collect()
    });

    let limit = filter.limit.map_or(10, |l| u64::from(l) as usize);

    let mut searches = Vec::new();

    for room_id in room_ids {
        if !services()
            .rooms
            .state_cache
            .is_joined(sender_user, &room_id)?
        {
            return Err(Error::BadRequest(
                ErrorKind::Forbidden,
                "You don't have permission to view this room.",
            ));
        }

        if let Some(search) = services()
            .rooms
            .search
            .search_pdus(&room_id, &search_criteria.search_term)?
        {
            searches.push(search.0.peekable());
        }
    }

    let skip = match body.next_batch.as_ref().map(|s| s.parse()) {
        Some(Ok(s)) => s,
        Some(Err(_)) => {
            return Err(Error::BadRequest(
                ErrorKind::InvalidParam,
                "Invalid next_batch token.",
            ))
        }
        None => 0, // Default to the start
    };

    let mut results = Vec::new();
    for _ in 0..skip + limit {
        if let Some(s) = searches
            .iter_mut()
            .map(|s| (s.peek().cloned(), s))
            .max_by_key(|(peek, _)| peek.clone())
            .and_then(|(_, i)| i.next())
        {
            results.push(s);
        }
    }

    let results: Vec<_> = results
        .iter()
        .map(|result| {
            Ok::<_, Error>(SearchResult {
                context: EventContextResult {
                    end: None,
                    events_after: Vec::new(),
                    events_before: Vec::new(),
                    profile_info: BTreeMap::new(),
                    start: None,
                },
                rank: None,
                result: services()
                    .rooms
                    .timeline
                    .get_pdu_from_id(result)?
                    .map(|pdu| pdu.to_room_event()),
            })
        })
        .filter_map(|r| r.ok())
        .skip(skip)
        .take(limit)
        .collect();

    let next_batch = if results.len() < limit {
        None
    } else {
        Some((skip + limit).to_string())
    };

    Ok(search_events::v3::Response::new(ResultCategories {
        room_events: ResultRoomEvents {
            count: Some((results.len() as u32).into()), // TODO: set this to none. Element shouldn't depend on it
            groups: BTreeMap::new(),                    // TODO
            next_batch,
            results,
            state: BTreeMap::new(), // TODO
            highlights: search_criteria
                .search_term
                .split_terminator(|c: char| !c.is_alphanumeric())
                .map(str::to_lowercase)
                .collect(),
        },
    }))
}