@ -6,7 +6,10 @@ use crate::{
@@ -6,7 +6,10 @@ use crate::{
use get_profile_information ::v1 ::ProfileField ;
use http ::header ::{ HeaderValue , AUTHORIZATION } ;
use regex ::Regex ;
use rocket ::{ futures , response ::content ::Json } ;
use rocket ::{
futures ::{ prelude ::* , stream ::FuturesUnordered } ,
response ::content ::Json ,
} ;
use ruma ::{
api ::{
client ::error ::{ Error as RumaError , ErrorKind } ,
@ -61,7 +64,7 @@ use std::{
@@ -61,7 +64,7 @@ use std::{
net ::{ IpAddr , SocketAddr } ,
pin ::Pin ,
result ::Result as StdResult ,
sync ::{ Arc , RwLock } ,
sync ::{ Arc , RwLock , RwLockWriteGuard } ,
time ::{ Duration , Instant , SystemTime } ,
} ;
use tokio ::sync ::{ MutexGuard , Semaphore } ;
@ -3281,101 +3284,96 @@ pub(crate) async fn fetch_required_signing_keys(
@@ -3281,101 +3284,96 @@ pub(crate) async fn fetch_required_signing_keys(
// Gets a list of servers for which we don't have the signing key yet. We go over
// the PDUs and either cache the key or add it to the list that needs to be retrieved.
fn get_missing_servers_for_pdus (
pdus : & Vec < Raw < Pdu > > ,
fn get_server_keys_from_cache (
pdu : & Raw < Pdu > ,
servers : & mut BTreeMap < Box < ServerName > , BTreeMap < ServerSigningKeyId , QueryCriteria > > ,
room_version : & RoomVersionId ,
pub_key_map : & RwLock < BTreeMap < String , BTreeMap < String , String > > > ,
pub_key_map : & mut RwLockWriteGuard < ' _ , BTreeMap < String , BTreeMap < String , String > > > ,
db : & Database ,
) -> Result < ( ) > {
let mut pkm = pub_key_map
. write ( )
. map_err ( | _ | Error ::bad_database ( "RwLock is poisoned." ) ) ? ;
for pdu in pdus {
let value = serde_json ::from_str ::< CanonicalJsonObject > ( pdu . json ( ) . get ( ) ) . map_err ( | e | {
error ! ( "Invalid PDU in server response: {:?}: {:?}" , pdu , e ) ;
Error ::BadServerResponse ( "Invalid PDU in server response" )
} ) ? ;
let event_id = EventId ::try_from ( & * format! (
"${}" ,
ruma ::signatures ::reference_hash ( & value , & room_version )
. expect ( "ruma can calculate reference hashes" )
) )
. expect ( "ruma's reference hashes are valid event ids" ) ;
if let Some ( ( time , tries ) ) = db
. globals
. bad_event_ratelimiter
. read ( )
. unwrap ( )
. get ( & event_id )
{
// Exponential backoff
let mut min_elapsed_duration = Duration ::from_secs ( 30 ) * ( * tries ) * ( * tries ) ;
if min_elapsed_duration > Duration ::from_secs ( 60 * 60 * 24 ) {
min_elapsed_duration = Duration ::from_secs ( 60 * 60 * 24 ) ;
}
let value = serde_json ::from_str ::< CanonicalJsonObject > ( pdu . json ( ) . get ( ) ) . map_err ( | e | {
error ! ( "Invalid PDU in server response: {:?}: {:?}" , pdu , e ) ;
Error ::BadServerResponse ( "Invalid PDU in server response" )
} ) ? ;
if time . elapsed ( ) < min_elapsed_duration {
debug ! ( "Backing off from {}" , event_id ) ;
return Err ( Error ::BadServerResponse ( "bad event, still backing off" ) ) ;
}
let event_id = EventId ::try_from ( & * format! (
"${}" ,
ruma ::signatures ::reference_hash ( & value , & room_version )
. expect ( "ruma can calculate reference hashes" )
) )
. expect ( "ruma's reference hashes are valid event ids" ) ;
if let Some ( ( time , tries ) ) = db
. globals
. bad_event_ratelimiter
. read ( )
. unwrap ( )
. get ( & event_id )
{
// Exponential backoff
let mut min_elapsed_duration = Duration ::from_secs ( 30 ) * ( * tries ) * ( * tries ) ;
if min_elapsed_duration > Duration ::from_secs ( 60 * 60 * 24 ) {
min_elapsed_duration = Duration ::from_secs ( 60 * 60 * 24 ) ;
}
let signatures = value
. get ( "signatures" )
. ok_or ( Error ::BadServerResponse (
"No signatures in server response pdu." ,
) ) ?
. as_object ( )
. ok_or ( Error ::BadServerResponse (
"Invalid signatures object in server response pdu." ,
) ) ? ;
if time . elapsed ( ) < min_elapsed_duration {
debug ! ( "Backing off from {}" , event_id ) ;
return Err ( Error ::BadServerResponse ( "bad event, still backing off" ) ) ;
}
}
for ( signature_server , signature ) in signatures {
let signature_object = signature . as_object ( ) . ok_or ( Error ::BadServerResponse (
"Invalid signatures content object in server response pdu." ,
) ) ? ;
let signatures = value
. get ( "signatures" )
. ok_or ( Error ::BadServerResponse (
"No signatures in server response pdu." ,
) ) ?
. as_object ( )
. ok_or ( Error ::BadServerResponse (
"Invalid signatures object in server response pdu." ,
) ) ? ;
let signature_ids = signature_object . keys ( ) . cloned ( ) . collect ::< Vec < _ > > ( ) ;
for ( signature_server , signature ) in signatures {
let signature_object = signature . as_object ( ) . ok_or ( Error ::BadServerResponse (
"Invalid signatures content object in server response pdu." ,
) ) ? ;
let contains_all_ids = | keys : & BTreeMap < String , String > | {
signature_ids . iter ( ) . all ( | id | keys . contains_key ( id ) )
} ;
let signature_ids = signature_object . keys ( ) . cloned ( ) . collect ::< Vec < _ > > ( ) ;
let origin = & Box ::< ServerName > ::try_from ( & * * signature_server ) . map_err ( | _ | {
Error ::BadServerResponse ( "Invalid servername in signatures of server response pdu." )
} ) ? ;
let contains_all_ids =
| keys : & BTreeMap < String , String > | signature_ids . iter ( ) . all ( | id | keys . contains_key ( id ) ) ;
if servers . contains_key ( origin ) {
continue ;
}
let origin = & Box ::< ServerName > ::try_from ( & * * signature_server ) . map_err ( | _ | {
Error ::BadServerResponse ( "Invalid servername in signatures of server response pdu." )
} ) ? ;
trace ! ( "Loading signing keys for {}" , origin ) ;
if servers . contains_key ( origin ) | | pub_key_map . contains_key ( origin . as_str ( ) ) {
continue ;
}
let result = db
. globals
. signing_keys_for ( origin ) ?
. into_iter ( )
. map ( | ( k , v ) | ( k . to_string ( ) , v . key ) )
. collect ::< BTreeMap < _ , _ > > ( ) ;
trace ! ( "Loading signing keys for {}" , origin ) ;
if ! contains_all_ids ( & result ) {
trace ! ( "Signing key not loaded for {}" , origin ) ;
servers . insert (
origin . clone ( ) ,
BTreeMap ::< ServerSigningKeyId , QueryCriteria > ::new ( ) ,
) ;
}
let result = db
. globals
. signing_keys_for ( origin ) ?
. into_iter ( )
. map ( | ( k , v ) | ( k . to_string ( ) , v . key ) )
. collect ::< BTreeMap < _ , _ > > ( ) ;
pkm . insert ( origin . to_string ( ) , result ) ;
if ! contains_all_ids ( & result ) {
trace ! ( "Signing key not loaded for {}" , origin ) ;
servers . insert (
origin . clone ( ) ,
BTreeMap ::< ServerSigningKeyId , QueryCriteria > ::new ( ) ,
) ;
}
pub_key_map . insert ( origin . to_string ( ) , result ) ;
}
Ok ( ( ) )
}
pub async fn fetch_join_signing_keys (
pub ( crate ) async fn fetch_join_signing_keys (
event : & create_join_event ::v2 ::Response ,
room_version : & RoomVersionId ,
pub_key_map : & RwLock < BTreeMap < String , BTreeMap < String , String > > > ,
@ -3384,32 +3382,26 @@ pub async fn fetch_join_signing_keys(
@@ -3384,32 +3382,26 @@ pub async fn fetch_join_signing_keys(
let mut servers =
BTreeMap ::< Box < ServerName > , BTreeMap < ServerSigningKeyId , QueryCriteria > > ::new ( ) ;
get_missing_servers_for_pdus (
& event . room_state . state ,
& mut servers ,
& room_version ,
& pub_key_map ,
& db ,
) ? ;
get_missing_servers_for_pdus (
& event . room_state . auth_chain ,
& mut servers ,
& room_version ,
& pub_key_map ,
& db ,
) ? ;
{
let mut pkm = pub_key_map
. write ( )
. map_err ( | _ | Error ::bad_database ( "RwLock is poisoned." ) ) ? ;
if servers . is_empty ( ) {
return Ok ( ( ) ) ;
// Try to fetch keys, failure is okay
// Servers we couldn't find in the cache will be added to `servers`
for pdu in & event . room_state . state {
let _ = get_server_keys_from_cache ( pdu , & mut servers , & room_version , & mut pkm , & db ) ;
}
for pdu in & event . room_state . auth_chain {
let _ = get_server_keys_from_cache ( pdu , & mut servers , & room_version , & mut pkm , & db ) ;
}
drop ( pkm ) ;
}
for server in db . globals . trusted_servers ( ) {
if db . globals . signing_keys_for ( server ) ? . is_empty ( ) {
servers . insert (
server . clone ( ) ,
BTreeMap ::< ServerSigningKeyId , QueryCriteria > ::new ( ) ,
) ;
}
if servers . is_empty ( ) {
// We had all keys locally
return Ok ( ( ) ) ;
}
for server in db . globals . trusted_servers ( ) {
@ -3434,7 +3426,7 @@ pub async fn fetch_join_signing_keys(
@@ -3434,7 +3426,7 @@ pub async fn fetch_join_signing_keys(
. write ( )
. map_err ( | _ | Error ::bad_database ( "RwLock is poisoned." ) ) ? ;
for k in keys . server_keys {
// TODO: Check signature
// TODO: Check signature from trusted server?
servers . remove ( & k . server_name ) ;
let result = db
@ -3447,23 +3439,33 @@ pub async fn fetch_join_signing_keys(
@@ -3447,23 +3439,33 @@ pub async fn fetch_join_signing_keys(
pkm . insert ( k . server_name . to_string ( ) , result ) ;
}
}
if servers . is_empty ( ) {
return Ok ( ( ) ) ;
}
}
for result in futures ::future ::join_all ( servers . iter ( ) . map ( | ( server , _ ) | {
db . sending
. send_federation_request ( & db . globals , server , get_server_keys ::v2 ::Request ::new ( ) )
} ) )
. await
{
if let Ok ( get_keys_response ) = result {
// TODO: We should probably not trust the server_name in the response.
let server = & get_keys_response . server_key . server_name ;
let mut futures = servers
. into_iter ( )
. map ( | ( server , _ ) | async move {
(
db . sending
. send_federation_request (
& db . globals ,
& server ,
get_server_keys ::v2 ::Request ::new ( ) ,
)
. await ,
server ,
)
} )
. collect ::< FuturesUnordered < _ > > ( ) ;
while let Some ( result ) = futures . next ( ) . await {
if let ( Ok ( get_keys_response ) , origin ) = result {
let result = db
. globals
. add_signing_key ( server , get_keys_response . server_key . clone ( ) ) ?
. add_signing_key ( & origin , get_keys_response . server_key . clone ( ) ) ?
. into_iter ( )
. map ( | ( k , v ) | ( k . to_string ( ) , v . key ) )
. collect ::< BTreeMap < _ , _ > > ( ) ;
@ -3471,7 +3473,7 @@ pub async fn fetch_join_signing_keys(
@@ -3471,7 +3473,7 @@ pub async fn fetch_join_signing_keys(
pub_key_map
. write ( )
. map_err ( | _ | Error ::bad_database ( "RwLock is poisoned." ) ) ?
. insert ( server . to_string ( ) , result ) ;
. insert ( origin . to_string ( ) , result ) ;
}
}