
refactor a bit

merge-requests/195/head
Jonathan de Jong, 4 years ago
commit 1ea8a2ceb5
src/client_server/sync.rs (204 changed lines)

@@ -11,7 +11,7 @@ use std::{
sync::Arc,
time::Duration,
};
use tokio::sync::watch::Sender;
use tokio::sync::watch::{Receiver, Sender};
use tracing::error;
#[cfg(feature = "conduit_bin")]
@@ -63,6 +63,27 @@ pub async fn sync_events_route(
let arc_db = Arc::new(db);
fn create_helper(
arc_db: Arc<DatabaseGuard>,
user: &UserId,
device: &Box<DeviceId>,
body: &Ruma<sync_events::Request<'_>>,
) -> Receiver<Option<ConduitResult<sync_events::Response>>> {
let (tx, rx) = tokio::sync::watch::channel(None);
tokio::spawn(sync_helper_wrapper(
arc_db,
user.clone(),
device.clone(),
body.since.clone(),
body.full_state,
body.timeout,
tx,
));
rx
}
let mut rx = match arc_db
.globals
.sync_receivers
@@ -71,33 +92,15 @@ pub async fn sync_events_route(
.entry((sender_user.clone(), sender_device.clone()))
{
Entry::Vacant(v) => {
let (tx, rx) = tokio::sync::watch::channel(None);
tokio::spawn(sync_helper_wrapper(
Arc::clone(&arc_db),
sender_user.clone(),
sender_device.clone(),
body.since.clone(),
body.full_state,
body.timeout,
tx,
));
v.insert((body.since.clone(), rx)).1.clone()
let rx = create_helper(Arc::clone(&arc_db), &sender_user, &sender_device, &body);
v.insert((body.since.clone(), rx.clone()));
rx
}
Entry::Occupied(mut o) => {
if o.get().0 != body.since {
let (tx, rx) = tokio::sync::watch::channel(None);
tokio::spawn(sync_helper_wrapper(
Arc::clone(&arc_db),
sender_user.clone(),
sender_device.clone(),
body.since.clone(),
body.full_state,
body.timeout,
tx,
));
let rx = create_helper(Arc::clone(&arc_db), &sender_user, &sender_device, &body);
o.insert((body.since.clone(), rx.clone()));
@@ -137,10 +140,10 @@ async fn sync_helper_wrapper(
tx: Sender<Option<ConduitResult<sync_events::Response>>>,
) {
let r = sync_helper(
Arc::clone(&db),
sender_user.clone(),
sender_device.clone(),
since.clone(),
&db,
&sender_user,
&sender_device,
&since,
full_state,
timeout,
)
@@ -171,11 +174,13 @@ async fn sync_helper_wrapper(
let _ = tx.send(Some(r.map(|(r, _)| r.into())));
}
const THIRTY_SECONDS: Duration = Duration::from_secs(30);
async fn sync_helper(
db: Arc<DatabaseGuard>,
sender_user: UserId,
sender_device: Box<DeviceId>,
since: Option<String>,
db: &Database,
sender_user: &UserId,
sender_device: &DeviceId,
since: &Option<String>,
full_state: bool,
timeout: Option<Duration>,
// bool = caching allowed
@@ -186,28 +191,68 @@ async fn sync_helper(
// Setup watchers, so if there's no response, we can wait for them
let watcher = db.watch(&sender_user, &sender_device);
let next_batch = db.globals.current_count()?;
let next_batch_string = next_batch.to_string();
let mut joined_rooms = BTreeMap::new();
let mut next_batch = db.globals.current_count()?;
let since = since
.clone()
.and_then(|string| string.parse().ok())
.unwrap_or(0);
let mut response = collect_response(&db, since, &sender_user, &sender_device, next_batch)?;
// If...
if !full_state // ...the user hasn't requested full-state...
// ...and data between since and next_batch is empty...
&& response.rooms.is_empty()
&& response.presence.is_empty()
&& response.account_data.is_empty()
&& response.to_device.is_empty()
&& response.device_lists.is_empty()
&& response.device_one_time_keys_count.is_empty()
{
// ...we wait until we get new data, or until timeout.
if tokio::time::timeout(timeout.unwrap_or_default().min(THIRTY_SECONDS), watcher)
.await
.is_ok()
// But if we get new data (is_err is timeout)...
{
// ...update the next_batch counter to "now"...
next_batch = db.globals.current_count()?;
// ...generate a new response item, store it...
response = collect_response(&db, since, &sender_user, &sender_device, next_batch)?;
// ...and have it get returned to the user below.
// If the response is empty (for some reason), we know that at least since->(this)next_batch does not contain any data,
// so returning a response with the "now" next_batch is okay.
}
}
// (!=) Only cache if we made progress, else it'll loop around to hit the cached item again.
Ok((response, since != next_batch))
}
fn collect_response(
db: &Database,
since: u64,
user: &UserId,
device: &DeviceId,
next_batch: u64,
) -> std::result::Result<sync_events::Response, Error> {
let mut joined_rooms = BTreeMap::new();
let mut presence_updates = HashMap::new();
let mut left_encrypted_users = HashSet::new(); // Users that have left any encrypted rooms the sender was in
let mut device_list_updates = HashSet::new();
let mut device_list_left = HashSet::new();
let next_batch_string = next_batch.to_string();
// Look for device list updates of this account
device_list_updates.extend(
db.users
.keys_changed(&sender_user.to_string(), since, None)
.keys_changed(&user.to_string(), since, None)
.filter_map(|r| r.ok()),
);
let all_joined_rooms = db.rooms.rooms_joined(&sender_user).collect::<Vec<_>>();
let all_joined_rooms = db.rooms.rooms_joined(&user).collect::<Vec<_>>();
for room_id in all_joined_rooms {
let room_id = room_id?;
@@ -226,7 +271,7 @@ async fn sync_helper(
let mut non_timeline_pdus = db
.rooms
.pdus_until(&sender_user, &room_id, u64::MAX)?
.pdus_until(&user, &room_id, u64::MAX)?
.filter_map(|r| {
// Filter out buggy events
if r.is_err() {
@@ -250,11 +295,7 @@ async fn sync_helper(
.collect::<Vec<_>>();
let send_notification_counts = !timeline_pdus.is_empty()
|| db
.rooms
.edus
.last_privateread_update(&sender_user, &room_id)?
> since;
|| db.rooms.edus.last_privateread_update(&user, &room_id)? > since;
// The /sync response doesn't always return all messages, so we say the output is
// limited unless there are events in non_timeline_pdus
@@ -283,7 +324,7 @@ async fn sync_helper(
for hero in db
.rooms
.all_pdus(&sender_user, &room_id)?
.all_pdus(&user, &room_id)?
.filter_map(|pdu| pdu.ok()) // Ignore all broken pdus
.filter(|(_, pdu)| pdu.kind == EventType::RoomMember)
.map(|(_, pdu)| {
@@ -317,7 +358,7 @@ async fn sync_helper(
// Filter for possible heroes
.flatten()
{
if heroes.contains(&hero) || hero == sender_user.as_str() {
if heroes.contains(&hero) || hero == user.as_str() {
continue;
}
@@ -365,11 +406,7 @@ async fn sync_helper(
let since_sender_member = db
.rooms
.state_get(
since_shortstatehash,
&EventType::RoomMember,
sender_user.as_str(),
)?
.state_get(since_shortstatehash, &EventType::RoomMember, user.as_str())?
.and_then(|pdu| {
serde_json::from_value::<Raw<ruma::events::room::member::MemberEventContent>>(
pdu.content.clone(),
@@ -428,7 +465,7 @@ async fn sync_helper(
let user_id = UserId::try_from(state_key.clone())
.map_err(|_| Error::bad_database("Invalid UserId in member PDU."))?;
if user_id == sender_user {
if user_id == *user {
continue;
}
@@ -443,7 +480,7 @@ async fn sync_helper(
match new_membership {
MembershipState::Join => {
// A new user joined an encrypted room
if !share_encrypted_room(&db, &sender_user, &user_id, &room_id)? {
if !share_encrypted_room(&db, &user, &user_id, &room_id)? {
device_list_updates.insert(user_id);
}
}
@@ -465,12 +502,11 @@ async fn sync_helper(
.flatten()
.filter(|user_id| {
// Don't send key updates from the sender to the sender
&sender_user != user_id
user != user_id
})
.filter(|user_id| {
// Only send keys if the sender doesn't share an encrypted room with the target already
!share_encrypted_room(&db, &sender_user, user_id, &room_id)
.unwrap_or(false)
!share_encrypted_room(&db, &user, user_id, &room_id).unwrap_or(false)
}),
);
}
@@ -500,7 +536,7 @@ async fn sync_helper(
let notification_count = if send_notification_counts {
Some(
db.rooms
.notification_count(&sender_user, &room_id)?
.notification_count(&user, &room_id)?
.try_into()
.expect("notification count can't go that high"),
)
@@ -511,7 +547,7 @@ async fn sync_helper(
let highlight_count = if send_notification_counts {
Some(
db.rooms
.highlight_count(&sender_user, &room_id)?
.highlight_count(&user, &room_id)?
.try_into()
.expect("highlight count can't go that high"),
)
@@ -558,7 +594,7 @@ async fn sync_helper(
account_data: sync_events::RoomAccountData {
events: db
.account_data
.changes_since(Some(&room_id), &sender_user, since)?
.changes_since(Some(&room_id), &user, since)?
.into_iter()
.filter_map(|(_, v)| {
serde_json::from_str(v.json().get())
@@ -630,7 +666,7 @@ async fn sync_helper(
}
let mut left_rooms = BTreeMap::new();
let all_left_rooms = db.rooms.rooms_left(&sender_user).collect::<Vec<_>>();
let all_left_rooms = db.rooms.rooms_left(&user).collect::<Vec<_>>();
for result in all_left_rooms {
let (room_id, left_state_events) = result?;
@@ -646,7 +682,7 @@ async fn sync_helper(
let insert_lock = mutex_insert.lock().unwrap();
drop(insert_lock);
let left_count = db.rooms.get_left_count(&room_id, &sender_user)?;
let left_count = db.rooms.get_left_count(&room_id, &user)?;
// Left before last sync
if Some(since) >= left_count {
@@ -670,7 +706,7 @@ async fn sync_helper(
}
let mut invited_rooms = BTreeMap::new();
let all_invited_rooms = db.rooms.rooms_invited(&sender_user).collect::<Vec<_>>();
let all_invited_rooms = db.rooms.rooms_invited(&user).collect::<Vec<_>>();
for result in all_invited_rooms {
let (room_id, invite_state_events) = result?;
@@ -686,7 +722,7 @@ async fn sync_helper(
let insert_lock = mutex_insert.lock().unwrap();
drop(insert_lock);
let invite_count = db.rooms.get_invite_count(&room_id, &sender_user)?;
let invite_count = db.rooms.get_invite_count(&room_id, &user)?;
// Invited before last sync
if Some(since) >= invite_count {
@@ -706,7 +742,7 @@ async fn sync_helper(
for user_id in left_encrypted_users {
let still_share_encrypted_room = db
.rooms
.get_shared_rooms(vec![sender_user.clone(), user_id.clone()])?
.get_shared_rooms(vec![user.clone(), user_id.clone()])?
.filter_map(|r| r.ok())
.filter_map(|other_room_id| {
Some(
@@ -725,10 +761,9 @@ async fn sync_helper(
}
// Remove all to-device events the device received *last time*
db.users
.remove_to_device_events(&sender_user, &sender_device, since)?;
db.users.remove_to_device_events(&user, &device, since)?;
let response = sync_events::Response {
Ok(sync_events::Response {
next_batch: next_batch_string,
rooms: sync_events::Rooms {
leave: left_rooms,
@ -745,7 +780,7 @@ async fn sync_helper( @@ -745,7 +780,7 @@ async fn sync_helper(
account_data: sync_events::GlobalAccountData {
events: db
.account_data
.changes_since(None, &sender_user, since)?
.changes_since(None, &user, since)?
.into_iter()
.filter_map(|(_, v)| {
serde_json::from_str(v.json().get())
@ -758,40 +793,17 @@ async fn sync_helper( @@ -758,40 +793,17 @@ async fn sync_helper(
changed: device_list_updates.into_iter().collect(),
left: device_list_left.into_iter().collect(),
},
device_one_time_keys_count: if db.users.last_one_time_keys_update(&sender_user)? > since
device_one_time_keys_count: if db.users.last_one_time_keys_update(&user)? > since
|| since == 0
{
db.users.count_one_time_keys(&sender_user, &sender_device)?
db.users.count_one_time_keys(&user, &device)?
} else {
BTreeMap::new()
},
to_device: sync_events::ToDevice {
events: db
.users
.get_to_device_events(&sender_user, &sender_device)?,
events: db.users.get_to_device_events(&user, &device)?,
},
};
// TODO: Retry the endpoint instead of returning (waiting for #118)
if !full_state
&& response.rooms.is_empty()
&& response.presence.is_empty()
&& response.account_data.is_empty()
&& response.device_lists.is_empty()
&& response.device_one_time_keys_count.is_empty()
&& response.to_device.is_empty()
{
// Hang a few seconds so requests are not spammed
// Stop hanging if new info arrives
let mut duration = timeout.unwrap_or_default();
if duration.as_secs() > 30 {
duration = Duration::from_secs(30);
}
let _ = tokio::time::timeout(duration, watcher).await;
Ok((response, false))
} else {
Ok((response, since != next_batch)) // Only cache if we made progress
}
})
}
#[tracing::instrument(skip(db))]
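
For orientation, here is a compact, self-contained sketch of the pattern this commit converges on: the two duplicated spawn-and-watch blocks become a single `create_helper` that publishes its result over a tokio watch channel, and the response-building code moves into a `collect_response` that the helper can call again after waiting for new data when the first delta is empty. The names and types below (`UserDevice`, `SyncResult`, the `collect_response` stub, the hard-coded key and counters) are invented for illustration and are not Conduit's actual signatures.

```rust
// Minimal sketch (not Conduit's real API) of the pattern the commit factors out:
// one background task per (user_id, device_id) publishes its /sync result through
// a tokio watch channel, and the response build is retried after waiting for data.
// Requires tokio with the "sync", "time", "rt-multi-thread" and "macros" features.

use std::{
    collections::{hash_map::Entry, HashMap},
    sync::{Arc, Mutex},
    time::Duration,
};
use tokio::sync::watch::{self, Receiver};

type UserDevice = (String, String);               // stand-in for (UserId, Box<DeviceId>)
type SyncResult = Option<Result<String, String>>; // simplified response/error pair

const THIRTY_SECONDS: Duration = Duration::from_secs(30);

// Stand-in for the commit's `collect_response`: the delta between `since` and
// `next_batch` is empty unless the counter has advanced.
fn collect_response(since: u64, next_batch: u64) -> Result<String, String> {
    if next_batch == since {
        Ok(String::new())
    } else {
        Ok(format!("events ({since}, {next_batch}]"))
    }
}

// Mirror of the new `create_helper`: spawn one task and hand back a watch
// receiver that every concurrent request for the same key can share.
fn create_helper(since: u64, timeout: Option<Duration>) -> Receiver<SyncResult> {
    let (tx, rx) = watch::channel(None);
    tokio::spawn(async move {
        let mut next_batch = since; // stand-in for db.globals.current_count()
        let mut response = collect_response(since, next_batch);

        // Empty delta: wait (client timeout, capped at thirty seconds) for new
        // data, then build the response again, as the reworked sync_helper does.
        if matches!(&response, Ok(r) if r.is_empty()) {
            tokio::time::sleep(timeout.unwrap_or_default().min(THIRTY_SECONDS)).await;
            next_batch += 1; // pretend the server-wide count advanced meanwhile
            response = collect_response(since, next_batch);
        }
        let _ = tx.send(Some(response));
    });
    rx
}

#[tokio::main]
async fn main() {
    // Map of in-flight helpers keyed by (user, device), like db.globals.sync_receivers.
    let receivers: Arc<Mutex<HashMap<UserDevice, (u64, Receiver<SyncResult>)>>> =
        Arc::new(Mutex::new(HashMap::new()));
    let key: UserDevice = ("@alice:example.org".into(), "DEVICEID".into());
    let since = 0;

    // A vacant entry (or, in the real code, an entry stored with a different
    // `since`) spawns a new helper; otherwise the cached receiver is reused.
    let mut rx = match receivers.lock().unwrap().entry(key) {
        Entry::Vacant(v) => {
            let rx = create_helper(since, Some(Duration::from_millis(10)));
            v.insert((since, rx.clone()));
            rx
        }
        Entry::Occupied(o) => o.get().1.clone(),
    };

    // The route handler just awaits whatever the shared helper publishes.
    while rx.borrow().is_none() {
        rx.changed().await.expect("helper dropped its sender");
    }
    println!("{:?}", (*rx.borrow()).clone());
}
```

The watch channel is what lets a second /sync request with the same `since` token simply clone the cached receiver instead of repeating the work, which is why the commit can collapse both `Entry::Vacant` and `Entry::Occupied` arms onto the same `create_helper` call.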
