Browse Source

use tracing and DatabaseGuard

merge-requests/80/head
Bruno Bigras 4 years ago
parent
commit
e275d2cd49
  1. 10
      src/client_server/session.rs
  2. 18
      src/client_server/sso.rs

10
src/client_server/session.rs

@ -57,15 +57,11 @@ use rocket::{get, post};
/// Get the supported login types of this server. One of these should be used as the `type` field /// Get the supported login types of this server. One of these should be used as the `type` field
/// when logging in. /// when logging in.
#[cfg_attr(feature = "conduit_bin", get("/_matrix/client/r0/login"))] #[cfg_attr(feature = "conduit_bin", get("/_matrix/client/r0/login"))]
// #[tracing::instrument] // TODO: need Debug on Database #[tracing::instrument(skip(db))]
pub async fn get_login_types_route( pub async fn get_login_types_route(db: DatabaseGuard) -> ConduitResult<get_login_types::Response> {
db: &rocket::State<Arc<RwLock<Database>>>,
) -> ConduitResult<get_login_types::Response> {
let mut flows = vec![get_login_types::LoginType::Password(Default::default())]; let mut flows = vec![get_login_types::LoginType::Password(Default::default())];
let db_lock = db.read().await; if db.globals.openid_client.is_some() {
if db_lock.globals.openid_client.is_some() {
flows.push(get_login_types::LoginType::Sso( flows.push(get_login_types::LoginType::Sso(
get_login_types::SsoLoginType::default(), get_login_types::SsoLoginType::default(),
)); ));

18
src/client_server/sso.rs

@ -1,7 +1,7 @@
use std::sync::Arc; use std::sync::Arc;
// use super::State; // use super::State;
use crate::{server_server, ConduitResult, Database, Error, Ruma}; use crate::{database::DatabaseGuard, server_server, ConduitResult, Database, Error, Ruma};
use http::status; use http::status;
use macaroon::Macaroon; use macaroon::Macaroon;
use openid::{Token, Userinfo}; use openid::{Token, Userinfo};
@ -22,13 +22,11 @@ const MAC_VALID_SECS: i64 = 10;
get("/_matrix/client/r0/login/sso/redirect?<redirectUrl>") get("/_matrix/client/r0/login/sso/redirect?<redirectUrl>")
)] )]
pub async fn get_sso_redirect( pub async fn get_sso_redirect(
db: &rocket::State<Arc<RwLock<Database>>>, db: DatabaseGuard,
redirectUrl: &str, redirectUrl: &str,
mut cookies: &CookieJar<'_>, mut cookies: &CookieJar<'_>,
) -> Redirect { ) -> Redirect {
let db_lock = db.read().await; let (_key, client) = db.globals.openid_client.as_ref().unwrap();
let (_key, client) = db_lock.globals.openid_client.as_ref().unwrap();
let state = "value"; // TODO: random let state = "value"; // TODO: random
@ -102,9 +100,9 @@ pub enum ExampleResponse<'a> {
feature = "conduit_bin", feature = "conduit_bin",
get("/sso_return?<session_state>&<state>&<code>") get("/sso_return?<session_state>&<state>&<code>")
)] )]
// #[tracing::instrument] #[tracing::instrument(skip(db))]
pub async fn get_sso_return<'a>( pub async fn get_sso_return<'a>(
db: &rocket::State<Arc<RwLock<Database>>>, db: DatabaseGuard,
session_state: &str, session_state: &str,
state: &str, state: &str,
code: &str, code: &str,
@ -119,9 +117,7 @@ pub async fn get_sso_return<'a>(
))); )));
} }
let db_lock = db.read().await; let (_key, client) = db.globals.openid_client.as_ref().unwrap();
let (_key, client) = db_lock.globals.openid_client.as_ref().unwrap();
let username; let username;
match request_token(client, code).await { match request_token(client, code).await {
@ -161,7 +157,7 @@ pub async fn get_sso_return<'a>(
} }
} }
let (key, _client) = db_lock.globals.openid_client.as_ref().unwrap(); let (key, _client) = db.globals.openid_client.as_ref().unwrap();
// Create our macaroon // Create our macaroon
let mut macaroon = match Macaroon::create(Some("location".into()), &key, username.into()) { let mut macaroon = match Macaroon::create(Some("location".into()), &key, username.into()) {

Loading…
Cancel
Save