From fcefa346f28bc1c622069499f3cf926accecea49 Mon Sep 17 00:00:00 2001 From: Kurt Roeckx Date: Sun, 29 Aug 2021 13:25:20 +0200 Subject: [PATCH] fixup! Get required keys in batch when joining a room --- src/database/globals.rs | 15 +++++++++++++-- src/server_server.rs | 42 ++++++++++++++++++++--------------------- 2 files changed, 33 insertions(+), 24 deletions(-) diff --git a/src/database/globals.rs b/src/database/globals.rs index 823ce34..24a8ac2 100644 --- a/src/database/globals.rs +++ b/src/database/globals.rs @@ -273,7 +273,11 @@ impl Globals { /// Remove the outdated keys and insert the new ones. /// /// This doesn't actually check that the keys provided are newer than the old set. - pub fn add_signing_key(&self, origin: &ServerName, new_keys: ServerSigningKeys) -> Result<()> { + pub fn add_signing_key( + &self, + origin: &ServerName, + new_keys: ServerSigningKeys, + ) -> Result> { // Not atomic, but this is not critical let signingkeys = self.server_signingkeys.get(origin.as_bytes())?; @@ -298,7 +302,14 @@ impl Globals { &serde_json::to_vec(&keys).expect("serversigningkeys can be serialized"), )?; - Ok(()) + let mut tree = keys.verify_keys; + tree.extend( + keys.old_verify_keys + .into_iter() + .map(|old| (old.0, VerifyKey::new(old.1.key))), + ); + + Ok(tree) } /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found for the server. diff --git a/src/server_server.rs b/src/server_server.rs index 6908518..f034fa2 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -2793,13 +2793,18 @@ pub async fn fetch_required_signing_keys( Ok(()) } -pub fn get_missing_signing_keys_for_pdus( +// Gets a list of servers for which we don't have the signing key yet. We go over +// the PDUs and either cache the key or add it to the list that needs to be retrieved. +fn get_missing_servers_for_pdus( pdus: &Vec>, servers: &mut BTreeMap, BTreeMap>, room_version: &RoomVersionId, pub_key_map: &RwLock>>, db: &Database, ) -> Result<()> { + let mut pkm = pub_key_map + .write() + .map_err(|_| Error::bad_database("RwLock is poisoned."))?; for pdu in pdus { let value = serde_json::from_str::(pdu.json().get()).map_err(|e| { error!("Invalid PDU in server response: {:?}: {:?}", pdu, e); @@ -2856,6 +2861,10 @@ pub fn get_missing_signing_keys_for_pdus( Error::BadServerResponse("Invalid servername in signatures of server response pdu.") })?; + if servers.contains_key(origin) { + continue; + } + trace!("Loading signing keys for {}", origin); let result = db @@ -2873,10 +2882,7 @@ pub fn get_missing_signing_keys_for_pdus( ); } - pub_key_map - .write() - .map_err(|_| Error::bad_database("RwLock is poisoned."))? - .insert(origin.to_string(), result); + pkm.insert(origin.to_string(), result); } } @@ -2892,14 +2898,14 @@ pub async fn fetch_join_signing_keys( let mut servers = BTreeMap::, BTreeMap>::new(); - get_missing_signing_keys_for_pdus( + get_missing_servers_for_pdus( &event.room_state.state, &mut servers, &room_version, &pub_key_map, &db, )?; - get_missing_signing_keys_for_pdus( + get_missing_servers_for_pdus( &event.room_state.auth_chain, &mut servers, &room_version, @@ -2938,23 +2944,19 @@ pub async fn fetch_join_signing_keys( .await { trace!("Got signing keys: {:?}", keys); + let mut pkm = pub_key_map + .write() + .map_err(|_| Error::bad_database("RwLock is poisoned."))?; for k in keys.server_keys { // TODO: Check signature servers.remove(&k.server_name); - db.globals.add_signing_key(&k.server_name, k.clone())?; - - let result = db - .globals - .signing_keys_for(&k.server_name)? + let result = db.globals.add_signing_key(&k.server_name, k.clone())? .into_iter() .map(|(k, v)| (k.to_string(), v.key)) .collect::>(); - pub_key_map - .write() - .map_err(|_| Error::bad_database("RwLock is poisoned."))? - .insert(k.server_name.to_string(), result); + pkm.insert(k.server_name.to_string(), result); } } if servers.is_empty() { @@ -2971,12 +2973,8 @@ pub async fn fetch_join_signing_keys( if let Ok(get_keys_response) = result { // TODO: We should probably not trust the server_name in the response. let server = &get_keys_response.server_key.server_name; - db.globals - .add_signing_key(server, get_keys_response.server_key.clone())?; - - let result = db - .globals - .signing_keys_for(server)? + let result = db.globals + .add_signing_key(server, get_keys_response.server_key.clone())? .into_iter() .map(|(k, v)| (k.to_string(), v.key)) .collect::>();