@ -6,7 +6,10 @@ use crate::{
@@ -6,7 +6,10 @@ use crate::{
use get_profile_information ::v1 ::ProfileField ;
use http ::header ::{ HeaderValue , AUTHORIZATION } ;
use regex ::Regex ;
use rocket ::response ::content ::Json ;
use rocket ::{
futures ::{ prelude ::* , stream ::FuturesUnordered } ,
response ::content ::Json ,
} ;
use ruma ::{
api ::{
client ::error ::{ Error as RumaError , ErrorKind } ,
@ -15,8 +18,9 @@ use ruma::{
@@ -15,8 +18,9 @@ use ruma::{
device ::get_devices ::{ self , v1 ::UserDevice } ,
directory ::{ get_public_rooms , get_public_rooms_filtered } ,
discovery ::{
get_remote_server_keys , get_server_keys , get_server_version , ServerSigningKeys ,
VerifyKey ,
get_remote_server_keys , get_remote_server_keys_batch ,
get_remote_server_keys_batch ::v2 ::QueryCriteria , get_server_keys ,
get_server_version , ServerSigningKeys , VerifyKey ,
} ,
event ::{ get_event , get_missing_events , get_room_state , get_room_state_ids } ,
keys ::{ claim_keys , get_keys } ,
@ -35,6 +39,7 @@ use ruma::{
@@ -35,6 +39,7 @@ use ruma::{
} ,
directory ::{ IncomingFilter , IncomingRoomNetwork } ,
events ::{
pdu ::Pdu ,
receipt ::{ ReceiptEvent , ReceiptEventContent } ,
room ::{
create ::CreateEventContent ,
@ -59,7 +64,7 @@ use std::{
@@ -59,7 +64,7 @@ use std::{
net ::{ IpAddr , SocketAddr } ,
pin ::Pin ,
result ::Result as StdResult ,
sync ::{ Arc , RwLock } ,
sync ::{ Arc , RwLock , RwLockWriteGuard } ,
time ::{ Duration , Instant , SystemTime } ,
} ;
use tokio ::sync ::{ MutexGuard , Semaphore } ;
@ -566,7 +571,7 @@ pub fn get_server_keys_route(db: DatabaseGuard) -> Json<String> {
@@ -566,7 +571,7 @@ pub fn get_server_keys_route(db: DatabaseGuard) -> Json<String> {
old_verify_keys : BTreeMap ::new ( ) ,
signatures : BTreeMap ::new ( ) ,
valid_until_ts : MilliSecondsSinceUnixEpoch ::from_system_time (
SystemTime ::now ( ) + Duration ::from_secs ( 60 * 2 ) ,
SystemTime ::now ( ) + Duration ::from_secs ( 8 640 0 * 7 ) ,
)
. expect ( "time is valid" ) ,
} ,
@ -3277,6 +3282,204 @@ pub(crate) async fn fetch_required_signing_keys(
@@ -3277,6 +3282,204 @@ pub(crate) async fn fetch_required_signing_keys(
Ok ( ( ) )
}
// Gets a list of servers for which we don't have the signing key yet. We go over
// the PDUs and either cache the key or add it to the list that needs to be retrieved.
fn get_server_keys_from_cache (
pdu : & Raw < Pdu > ,
servers : & mut BTreeMap < Box < ServerName > , BTreeMap < ServerSigningKeyId , QueryCriteria > > ,
room_version : & RoomVersionId ,
pub_key_map : & mut RwLockWriteGuard < ' _ , BTreeMap < String , BTreeMap < String , String > > > ,
db : & Database ,
) -> Result < ( ) > {
let value = serde_json ::from_str ::< CanonicalJsonObject > ( pdu . json ( ) . get ( ) ) . map_err ( | e | {
error ! ( "Invalid PDU in server response: {:?}: {:?}" , pdu , e ) ;
Error ::BadServerResponse ( "Invalid PDU in server response" )
} ) ? ;
let event_id = EventId ::try_from ( & * format! (
"${}" ,
ruma ::signatures ::reference_hash ( & value , & room_version )
. expect ( "ruma can calculate reference hashes" )
) )
. expect ( "ruma's reference hashes are valid event ids" ) ;
if let Some ( ( time , tries ) ) = db
. globals
. bad_event_ratelimiter
. read ( )
. unwrap ( )
. get ( & event_id )
{
// Exponential backoff
let mut min_elapsed_duration = Duration ::from_secs ( 30 ) * ( * tries ) * ( * tries ) ;
if min_elapsed_duration > Duration ::from_secs ( 60 * 60 * 24 ) {
min_elapsed_duration = Duration ::from_secs ( 60 * 60 * 24 ) ;
}
if time . elapsed ( ) < min_elapsed_duration {
debug ! ( "Backing off from {}" , event_id ) ;
return Err ( Error ::BadServerResponse ( "bad event, still backing off" ) ) ;
}
}
let signatures = value
. get ( "signatures" )
. ok_or ( Error ::BadServerResponse (
"No signatures in server response pdu." ,
) ) ?
. as_object ( )
. ok_or ( Error ::BadServerResponse (
"Invalid signatures object in server response pdu." ,
) ) ? ;
for ( signature_server , signature ) in signatures {
let signature_object = signature . as_object ( ) . ok_or ( Error ::BadServerResponse (
"Invalid signatures content object in server response pdu." ,
) ) ? ;
let signature_ids = signature_object . keys ( ) . cloned ( ) . collect ::< Vec < _ > > ( ) ;
let contains_all_ids =
| keys : & BTreeMap < String , String > | signature_ids . iter ( ) . all ( | id | keys . contains_key ( id ) ) ;
let origin = & Box ::< ServerName > ::try_from ( & * * signature_server ) . map_err ( | _ | {
Error ::BadServerResponse ( "Invalid servername in signatures of server response pdu." )
} ) ? ;
if servers . contains_key ( origin ) | | pub_key_map . contains_key ( origin . as_str ( ) ) {
continue ;
}
trace ! ( "Loading signing keys for {}" , origin ) ;
let result = db
. globals
. signing_keys_for ( origin ) ?
. into_iter ( )
. map ( | ( k , v ) | ( k . to_string ( ) , v . key ) )
. collect ::< BTreeMap < _ , _ > > ( ) ;
if ! contains_all_ids ( & result ) {
trace ! ( "Signing key not loaded for {}" , origin ) ;
servers . insert (
origin . clone ( ) ,
BTreeMap ::< ServerSigningKeyId , QueryCriteria > ::new ( ) ,
) ;
}
pub_key_map . insert ( origin . to_string ( ) , result ) ;
}
Ok ( ( ) )
}
pub ( crate ) async fn fetch_join_signing_keys (
event : & create_join_event ::v2 ::Response ,
room_version : & RoomVersionId ,
pub_key_map : & RwLock < BTreeMap < String , BTreeMap < String , String > > > ,
db : & Database ,
) -> Result < ( ) > {
let mut servers =
BTreeMap ::< Box < ServerName > , BTreeMap < ServerSigningKeyId , QueryCriteria > > ::new ( ) ;
{
let mut pkm = pub_key_map
. write ( )
. map_err ( | _ | Error ::bad_database ( "RwLock is poisoned." ) ) ? ;
// Try to fetch keys, failure is okay
// Servers we couldn't find in the cache will be added to `servers`
for pdu in & event . room_state . state {
let _ = get_server_keys_from_cache ( pdu , & mut servers , & room_version , & mut pkm , & db ) ;
}
for pdu in & event . room_state . auth_chain {
let _ = get_server_keys_from_cache ( pdu , & mut servers , & room_version , & mut pkm , & db ) ;
}
drop ( pkm ) ;
}
if servers . is_empty ( ) {
// We had all keys locally
return Ok ( ( ) ) ;
}
for server in db . globals . trusted_servers ( ) {
trace ! ( "Asking batch signing keys from trusted server {}" , server ) ;
if let Ok ( keys ) = db
. sending
. send_federation_request (
& db . globals ,
server ,
get_remote_server_keys_batch ::v2 ::Request {
server_keys : servers . clone ( ) ,
minimum_valid_until_ts : MilliSecondsSinceUnixEpoch ::from_system_time (
SystemTime ::now ( ) + Duration ::from_secs ( 60 ) ,
)
. expect ( "time is valid" ) ,
} ,
)
. await
{
trace ! ( "Got signing keys: {:?}" , keys ) ;
let mut pkm = pub_key_map
. write ( )
. map_err ( | _ | Error ::bad_database ( "RwLock is poisoned." ) ) ? ;
for k in keys . server_keys {
// TODO: Check signature from trusted server?
servers . remove ( & k . server_name ) ;
let result = db
. globals
. add_signing_key ( & k . server_name , k . clone ( ) ) ?
. into_iter ( )
. map ( | ( k , v ) | ( k . to_string ( ) , v . key ) )
. collect ::< BTreeMap < _ , _ > > ( ) ;
pkm . insert ( k . server_name . to_string ( ) , result ) ;
}
}
if servers . is_empty ( ) {
return Ok ( ( ) ) ;
}
}
let mut futures = servers
. into_iter ( )
. map ( | ( server , _ ) | async move {
(
db . sending
. send_federation_request (
& db . globals ,
& server ,
get_server_keys ::v2 ::Request ::new ( ) ,
)
. await ,
server ,
)
} )
. collect ::< FuturesUnordered < _ > > ( ) ;
while let Some ( result ) = futures . next ( ) . await {
if let ( Ok ( get_keys_response ) , origin ) = result {
let result = db
. globals
. add_signing_key ( & origin , get_keys_response . server_key . clone ( ) ) ?
. into_iter ( )
. map ( | ( k , v ) | ( k . to_string ( ) , v . key ) )
. collect ::< BTreeMap < _ , _ > > ( ) ;
pub_key_map
. write ( )
. map_err ( | _ | Error ::bad_database ( "RwLock is poisoned." ) ) ?
. insert ( origin . to_string ( ) , result ) ;
}
}
Ok ( ( ) )
}
#[ cfg(test) ]
mod tests {
use super ::{ add_port_to_hostname , get_ip_with_port , FedDest } ;