@ -1,7 +1,7 @@
@@ -1,7 +1,7 @@
use rocket ::{
data ::{ FromDataSimple , Outcom e} ,
data ::{ Data , FromData , FromDataFuture , Transform , Transformed , TransformFutur e} ,
http ::Status ,
response ::Responder ,
response ::{ self , Responder } ,
Outcome ::* ,
Request , State ,
} ;
@ -13,9 +13,10 @@ use ruma_client_api::error::Error;
@@ -13,9 +13,10 @@ use ruma_client_api::error::Error;
use ruma_identifiers ::UserId ;
use std ::{
convert ::{ TryFrom , TryInto } ,
io ::{ Cursor , Read } ,
io ::Cursor ,
ops ::Deref ,
} ;
use tokio ::io ::AsyncReadExt ;
const MESSAGE_LIMIT : u64 = 65535 ;
@ -27,7 +28,7 @@ pub struct Ruma<T: Outgoing> {
@@ -27,7 +28,7 @@ pub struct Ruma<T: Outgoing> {
pub json_body : serde_json ::Value ,
}
impl < T : Endpoint > FromDataSimple for Ruma < T >
impl < ' a , T : Endpoint > FromData < ' a > for Ruma < T >
where
// We need to duplicate Endpoint's where clauses because the compiler is not smart enough yet.
// See https://github.com/rust-lang/rust/issues/54149
@ -38,63 +39,76 @@ where
@@ -38,63 +39,76 @@ where
> ,
{
type Error = ( ) ; // TODO: Better error handling
type Owned = Data ;
type Borrowed = Self ::Owned ;
fn from_data ( request : & Request , data : rocket ::Data ) -> Outcome < Self , Self ::Error > {
let user_id = if T ::METADATA . requires_authentication {
let data = request . guard ::< State < crate ::Data > > ( ) . unwrap ( ) ;
// Get token from header or query value
let token = match request
. headers ( )
. get_one ( "Authorization" )
. map ( | s | s . to_owned ( ) )
. or_else ( | | request . get_query_value ( "access_token" ) . and_then ( | r | r . ok ( ) ) )
{
// TODO: M_MISSING_TOKEN
None = > return Failure ( ( Status ::Unauthorized , ( ) ) ) ,
Some ( token ) = > token ,
fn transform < ' r > ( _req : & ' r Request , data : Data ) -> TransformFuture < ' r , Self ::Owned , Self ::Error > {
Box ::pin ( async move { Transform ::Owned ( Success ( data ) ) } )
}
fn from_data (
request : & ' a Request ,
outcome : Transformed < ' a , Self > ,
) -> FromDataFuture < ' a , Self , Self ::Error > {
Box ::pin ( async move {
let data = rocket ::try_outcome ! ( outcome . owned ( ) ) ;
let user_id = if T ::METADATA . requires_authentication {
let data = request . guard ::< State < crate ::Data > > ( ) . await . unwrap ( ) ;
// Get token from header or query value
let token = match request
. headers ( )
. get_one ( "Authorization" )
. map ( | s | s . to_owned ( ) )
. or_else ( | | request . get_query_value ( "access_token" ) . and_then ( | r | r . ok ( ) ) )
{
// TODO: M_MISSING_TOKEN
None = > return Failure ( ( Status ::Unauthorized , ( ) ) ) ,
Some ( token ) = > token ,
} ;
// Check if token is valid
match data . user_from_token ( & token ) {
// TODO: M_UNKNOWN_TOKEN
None = > return Failure ( ( Status ::Unauthorized , ( ) ) ) ,
Some ( user_id ) = > Some ( user_id ) ,
}
} else {
None
} ;
// Check if token is valid
match data . user_from_token ( & token ) {
// TODO: M_UNKNOWN_TOKEN
None = > return Failure ( ( Status ::Unauthorized , ( ) ) ) ,
Some ( user_id ) = > Some ( user_id ) ,
let mut http_request = http ::Request ::builder ( )
. uri ( request . uri ( ) . to_string ( ) )
. method ( & * request . method ( ) . to_string ( ) ) ;
for header in request . headers ( ) . iter ( ) {
http_request = http_request . header ( header . name . as_str ( ) , & * header . value ) ;
}
} else {
None
} ;
let mut http_request = http ::Request ::builder ( )
. uri ( request . uri ( ) . to_string ( ) )
. method ( & * request . method ( ) . to_string ( ) ) ;
for header in request . headers ( ) . iter ( ) {
http_request = http_request . header ( header . name . as_str ( ) , & * header . value ) ;
}
let mut handle = data . open ( ) . take ( MESSAGE_LIMIT ) ;
let mut body = Vec ::new ( ) ;
handle . read_to_end ( & mut body ) . unwrap ( ) ;
let http_request = http_request . body ( body . clone ( ) ) . unwrap ( ) ;
log ::info ! ( "{:?}" , http_request ) ;
match T ::Incoming ::try_from ( http_request ) {
Ok ( t ) = > Success ( Ruma {
body : t ,
user_id ,
// TODO: Can we avoid parsing it again?
json_body : if ! body . is_empty ( ) {
serde_json ::from_slice ( & body ) . expect ( "Ruma already parsed it successfully" )
} else {
serde_json ::Value ::default ( )
} ,
} ) ,
Err ( e ) = > {
log ::error ! ( "{:?}" , e ) ;
Failure ( ( Status ::InternalServerError , ( ) ) )
let mut handle = data . open ( ) . take ( MESSAGE_LIMIT ) ;
let mut body = Vec ::new ( ) ;
handle . read_to_end ( & mut body ) . await . unwrap ( ) ;
let http_request = http_request . body ( body . clone ( ) ) . unwrap ( ) ;
log ::info ! ( "{:?}" , http_request ) ;
match T ::Incoming ::try_from ( http_request ) {
Ok ( t ) = > Success ( Ruma {
body : t ,
user_id ,
// TODO: Can we avoid parsing it again?
json_body : if ! body . is_empty ( ) {
serde_json ::from_slice ( & body ) . expect ( "Ruma already parsed it successfully" )
} else {
serde_json ::Value ::default ( )
} ,
} ) ,
Err ( e ) = > {
log ::error ! ( "{:?}" , e ) ;
Failure ( ( Status ::InternalServerError , ( ) ) )
}
}
}
} )
}
}
@ -108,7 +122,9 @@ impl<T: Outgoing> Deref for Ruma<T> {
@@ -108,7 +122,9 @@ impl<T: Outgoing> Deref for Ruma<T> {
/// This struct converts ruma responses into rocket http responses.
pub struct MatrixResult < T > ( pub std ::result ::Result < T , Error > ) ;
impl < T : TryInto < http ::Response < Vec < u8 > > > > TryInto < http ::Response < Vec < u8 > > > for MatrixResult < T > {
impl < T : TryInto < http ::Response < Vec < u8 > > > > TryInto < http ::Response < Vec < u8 > > > for MatrixResult < T >
{
type Error = T ::Error ;
fn try_into ( self ) -> Result < http ::Response < Vec < u8 > > , T ::Error > {
@ -119,13 +135,14 @@ impl<T: TryInto<http::Response<Vec<u8>>>> TryInto<http::Response<Vec<u8>>> for M
@@ -119,13 +135,14 @@ impl<T: TryInto<http::Response<Vec<u8>>>> TryInto<http::Response<Vec<u8>>> for M
}
}
impl < ' r , T : TryInto < http ::Response < Vec < u8 > > > > Responder < ' r > for MatrixResult < T > {
fn respond_to ( self , _ : & Request ) -> rocket ::response ::Result < ' r > {
#[ rocket::async_trait ]
impl < ' r , T : Send + TryInto < http ::Response < Vec < u8 > > > > Responder < ' r > for MatrixResult < T > where T ::Error : Send {
async fn respond_to ( self , _ : & ' r Request < ' _ > ) -> response ::Result < ' r > {
let http_response : Result < http ::Response < _ > , _ > = self . try_into ( ) ;
match http_response {
Ok ( http_response ) = > {
let mut response = rocket ::response ::Response ::build ( ) ;
response . sized_body ( Cursor ::new ( http_response . body ( ) . clone ( ) ) ) ;
response . sized_body ( Cursor ::new ( http_response . body ( ) . clone ( ) ) ) . await ;
for header in http_response . headers ( ) {
response