1mod mock;
8mod readonly;
9
10use std::{collections::HashSet, sync::Arc};
11
12use ruma_common::UserId;
13
14pub use self::{
15    mock::HomeserverConnection as MockHomeserverConnection, readonly::ReadOnlyHomeserverConnection,
16};
17
/// A user as reported by the homeserver (see `HomeserverConnection::query_user`).
///
/// Besides `Debug`, this derives `Clone`, `PartialEq` and `Eq` so that query results can be
/// compared (e.g. in tests) and cached without re-querying.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MatrixUser {
    /// The user's display name, if any.
    pub displayname: Option<String>,
    /// The user's avatar URL, if any.
    pub avatar_url: Option<String>,
    /// Whether the account is deactivated.
    pub deactivated: bool,
}
24
/// What a provisioning request wants done with one optional field.
///
/// A freshly defaulted value means "leave the field alone" ([`FieldAction::DoNothing`]).
#[derive(Debug, Default)]
enum FieldAction<T> {
    /// Replace the field with the carried value.
    Set(T),
    /// Clear the field.
    Unset,
    /// Do not touch the field at all.
    #[default]
    DoNothing,
}
32
33pub struct ProvisionRequest {
34    localpart: String,
35    sub: String,
36    displayname: FieldAction<String>,
37    avatar_url: FieldAction<String>,
38    emails: FieldAction<Vec<String>>,
39}
40
41impl ProvisionRequest {
42    #[must_use]
49    pub fn new(localpart: impl Into<String>, sub: impl Into<String>) -> Self {
50        Self {
51            localpart: localpart.into(),
52            sub: sub.into(),
53            displayname: FieldAction::DoNothing,
54            avatar_url: FieldAction::DoNothing,
55            emails: FieldAction::DoNothing,
56        }
57    }
58
59    #[must_use]
61    pub fn sub(&self) -> &str {
62        &self.sub
63    }
64
65    #[must_use]
67    pub fn localpart(&self) -> &str {
68        &self.localpart
69    }
70
71    #[must_use]
77    pub fn set_displayname(mut self, displayname: String) -> Self {
78        self.displayname = FieldAction::Set(displayname);
79        self
80    }
81
82    #[must_use]
84    pub fn unset_displayname(mut self) -> Self {
85        self.displayname = FieldAction::Unset;
86        self
87    }
88
89    pub fn on_displayname<F>(&self, callback: F) -> &Self
95    where
96        F: FnOnce(Option<&str>),
97    {
98        match &self.displayname {
99            FieldAction::Unset => callback(None),
100            FieldAction::Set(displayname) => callback(Some(displayname)),
101            FieldAction::DoNothing => {}
102        }
103
104        self
105    }
106
107    #[must_use]
113    pub fn set_avatar_url(mut self, avatar_url: String) -> Self {
114        self.avatar_url = FieldAction::Set(avatar_url);
115        self
116    }
117
118    #[must_use]
120    pub fn unset_avatar_url(mut self) -> Self {
121        self.avatar_url = FieldAction::Unset;
122        self
123    }
124
125    pub fn on_avatar_url<F>(&self, callback: F) -> &Self
131    where
132        F: FnOnce(Option<&str>),
133    {
134        match &self.avatar_url {
135            FieldAction::Unset => callback(None),
136            FieldAction::Set(avatar_url) => callback(Some(avatar_url)),
137            FieldAction::DoNothing => {}
138        }
139
140        self
141    }
142
143    #[must_use]
149    pub fn set_emails(mut self, emails: Vec<String>) -> Self {
150        self.emails = FieldAction::Set(emails);
151        self
152    }
153
154    #[must_use]
156    pub fn unset_emails(mut self) -> Self {
157        self.emails = FieldAction::Unset;
158        self
159    }
160
161    pub fn on_emails<F>(&self, callback: F) -> &Self
167    where
168        F: FnOnce(Option<&[String]>),
169    {
170        match &self.emails {
171            FieldAction::Unset => callback(None),
172            FieldAction::Set(emails) => callback(Some(emails)),
173            FieldAction::DoNothing => {}
174        }
175
176        self
177    }
178}
179
180#[async_trait::async_trait]
181pub trait HomeserverConnection: Send + Sync {
182    fn homeserver(&self) -> &str;
184
185    fn mxid(&self, localpart: &str) -> String {
191        format!("@{}:{}", localpart, self.homeserver())
192    }
193
194    fn localpart<'a>(&self, mxid: &'a str) -> Option<&'a str> {
203        let mxid = <&UserId>::try_from(mxid).ok()?;
204        if mxid.server_name() != self.homeserver() {
205            return None;
206        }
207        Some(mxid.localpart())
208    }
209
210    async fn verify_token(&self, token: &str) -> Result<bool, anyhow::Error>;
223
224    async fn query_user(&self, localpart: &str) -> Result<MatrixUser, anyhow::Error>;
235
236    async fn provision_user(&self, request: &ProvisionRequest) -> Result<bool, anyhow::Error>;
248
249    async fn is_localpart_available(&self, localpart: &str) -> Result<bool, anyhow::Error>;
259
260    async fn upsert_device(
272        &self,
273        localpart: &str,
274        device_id: &str,
275        initial_display_name: Option<&str>,
276    ) -> Result<(), anyhow::Error>;
277
278    async fn update_device_display_name(
291        &self,
292        localpart: &str,
293        device_id: &str,
294        display_name: &str,
295    ) -> Result<(), anyhow::Error>;
296
297    async fn delete_device(&self, localpart: &str, device_id: &str) -> Result<(), anyhow::Error>;
309
310    async fn sync_devices(
322        &self,
323        localpart: &str,
324        devices: HashSet<String>,
325    ) -> Result<(), anyhow::Error>;
326
327    async fn delete_user(&self, localpart: &str, erase: bool) -> Result<(), anyhow::Error>;
339
340    async fn reactivate_user(&self, localpart: &str) -> Result<(), anyhow::Error>;
351
352    async fn set_displayname(
364        &self,
365        localpart: &str,
366        displayname: &str,
367    ) -> Result<(), anyhow::Error>;
368
369    async fn unset_displayname(&self, localpart: &str) -> Result<(), anyhow::Error>;
380
381    async fn allow_cross_signing_reset(&self, localpart: &str) -> Result<(), anyhow::Error>;
393}
394
395#[async_trait::async_trait]
396impl<T: HomeserverConnection + Send + Sync + ?Sized> HomeserverConnection for &T {
397    fn homeserver(&self) -> &str {
398        (**self).homeserver()
399    }
400
401    async fn verify_token(&self, token: &str) -> Result<bool, anyhow::Error> {
402        (**self).verify_token(token).await
403    }
404
405    async fn query_user(&self, localpart: &str) -> Result<MatrixUser, anyhow::Error> {
406        (**self).query_user(localpart).await
407    }
408
409    async fn provision_user(&self, request: &ProvisionRequest) -> Result<bool, anyhow::Error> {
410        (**self).provision_user(request).await
411    }
412
413    async fn is_localpart_available(&self, localpart: &str) -> Result<bool, anyhow::Error> {
414        (**self).is_localpart_available(localpart).await
415    }
416
417    async fn upsert_device(
418        &self,
419        localpart: &str,
420        device_id: &str,
421        initial_display_name: Option<&str>,
422    ) -> Result<(), anyhow::Error> {
423        (**self)
424            .upsert_device(localpart, device_id, initial_display_name)
425            .await
426    }
427
428    async fn update_device_display_name(
429        &self,
430        localpart: &str,
431        device_id: &str,
432        display_name: &str,
433    ) -> Result<(), anyhow::Error> {
434        (**self)
435            .update_device_display_name(localpart, device_id, display_name)
436            .await
437    }
438
439    async fn delete_device(&self, localpart: &str, device_id: &str) -> Result<(), anyhow::Error> {
440        (**self).delete_device(localpart, device_id).await
441    }
442
443    async fn sync_devices(
444        &self,
445        localpart: &str,
446        devices: HashSet<String>,
447    ) -> Result<(), anyhow::Error> {
448        (**self).sync_devices(localpart, devices).await
449    }
450
451    async fn delete_user(&self, localpart: &str, erase: bool) -> Result<(), anyhow::Error> {
452        (**self).delete_user(localpart, erase).await
453    }
454
455    async fn reactivate_user(&self, localpart: &str) -> Result<(), anyhow::Error> {
456        (**self).reactivate_user(localpart).await
457    }
458
459    async fn set_displayname(
460        &self,
461        localpart: &str,
462        displayname: &str,
463    ) -> Result<(), anyhow::Error> {
464        (**self).set_displayname(localpart, displayname).await
465    }
466
467    async fn unset_displayname(&self, localpart: &str) -> Result<(), anyhow::Error> {
468        (**self).unset_displayname(localpart).await
469    }
470
471    async fn allow_cross_signing_reset(&self, localpart: &str) -> Result<(), anyhow::Error> {
472        (**self).allow_cross_signing_reset(localpart).await
473    }
474}
475
476#[async_trait::async_trait]
478impl<T: HomeserverConnection + ?Sized> HomeserverConnection for Arc<T> {
479    fn homeserver(&self) -> &str {
480        (**self).homeserver()
481    }
482
483    async fn verify_token(&self, token: &str) -> Result<bool, anyhow::Error> {
484        (**self).verify_token(token).await
485    }
486
487    async fn query_user(&self, localpart: &str) -> Result<MatrixUser, anyhow::Error> {
488        (**self).query_user(localpart).await
489    }
490
491    async fn provision_user(&self, request: &ProvisionRequest) -> Result<bool, anyhow::Error> {
492        (**self).provision_user(request).await
493    }
494
495    async fn is_localpart_available(&self, localpart: &str) -> Result<bool, anyhow::Error> {
496        (**self).is_localpart_available(localpart).await
497    }
498
499    async fn upsert_device(
500        &self,
501        localpart: &str,
502        device_id: &str,
503        initial_display_name: Option<&str>,
504    ) -> Result<(), anyhow::Error> {
505        (**self)
506            .upsert_device(localpart, device_id, initial_display_name)
507            .await
508    }
509
510    async fn update_device_display_name(
511        &self,
512        localpart: &str,
513        device_id: &str,
514        display_name: &str,
515    ) -> Result<(), anyhow::Error> {
516        (**self)
517            .update_device_display_name(localpart, device_id, display_name)
518            .await
519    }
520
521    async fn delete_device(&self, localpart: &str, device_id: &str) -> Result<(), anyhow::Error> {
522        (**self).delete_device(localpart, device_id).await
523    }
524
525    async fn sync_devices(
526        &self,
527        localpart: &str,
528        devices: HashSet<String>,
529    ) -> Result<(), anyhow::Error> {
530        (**self).sync_devices(localpart, devices).await
531    }
532
533    async fn delete_user(&self, localpart: &str, erase: bool) -> Result<(), anyhow::Error> {
534        (**self).delete_user(localpart, erase).await
535    }
536
537    async fn reactivate_user(&self, localpart: &str) -> Result<(), anyhow::Error> {
538        (**self).reactivate_user(localpart).await
539    }
540
541    async fn set_displayname(
542        &self,
543        localpart: &str,
544        displayname: &str,
545    ) -> Result<(), anyhow::Error> {
546        (**self).set_displayname(localpart, displayname).await
547    }
548
549    async fn unset_displayname(&self, localpart: &str) -> Result<(), anyhow::Error> {
550        (**self).unset_displayname(localpart).await
551    }
552
553    async fn allow_cross_signing_reset(&self, localpart: &str) -> Result<(), anyhow::Error> {
554        (**self).allow_cross_signing_reset(localpart).await
555    }
556}