1use std::{collections::HashMap, sync::Arc};
8
9use mas_context::LogContext;
10use mas_data_model::{
11    UpstreamOAuthProvider, UpstreamOAuthProviderDiscoveryMode, UpstreamOAuthProviderPkceMode,
12};
13use mas_iana::oauth::PkceCodeChallengeMethod;
14use mas_oidc_client::error::DiscoveryError;
15use mas_storage::{RepositoryAccess, upstream_oauth2::UpstreamOAuthProviderRepository};
16use oauth2_types::oidc::VerifiedProviderMetadata;
17use tokio::sync::RwLock;
18use url::Url;
19
20pub struct LazyProviderInfos<'a> {
23    cache: &'a MetadataCache,
24    provider: &'a UpstreamOAuthProvider,
25    client: &'a reqwest::Client,
26    loaded_metadata: Option<Arc<VerifiedProviderMetadata>>,
27}
28
29impl<'a> LazyProviderInfos<'a> {
30    pub fn new(
31        cache: &'a MetadataCache,
32        provider: &'a UpstreamOAuthProvider,
33        client: &'a reqwest::Client,
34    ) -> Self {
35        Self {
36            cache,
37            provider,
38            client,
39            loaded_metadata: None,
40        }
41    }
42
43    pub async fn maybe_discover(
46        &mut self,
47    ) -> Result<Option<&VerifiedProviderMetadata>, DiscoveryError> {
48        match self.load().await {
49            Ok(metadata) => Ok(Some(metadata)),
50            Err(DiscoveryError::Disabled) => Ok(None),
51            Err(e) => Err(e),
52        }
53    }
54
55    async fn load(&mut self) -> Result<&VerifiedProviderMetadata, DiscoveryError> {
56        if self.loaded_metadata.is_none() {
57            let verify = match self.provider.discovery_mode {
58                UpstreamOAuthProviderDiscoveryMode::Oidc => true,
59                UpstreamOAuthProviderDiscoveryMode::Insecure => false,
60                UpstreamOAuthProviderDiscoveryMode::Disabled => {
61                    return Err(DiscoveryError::Disabled);
62                }
63            };
64
65            let Some(issuer) = &self.provider.issuer else {
66                return Err(DiscoveryError::MissingIssuer);
67            };
68
69            let metadata = self.cache.get(self.client, issuer, verify).await?;
70
71            self.loaded_metadata = Some(metadata);
72        }
73
74        Ok(self.loaded_metadata.as_ref().unwrap())
75    }
76
77    pub async fn jwks_uri(&mut self) -> Result<&Url, DiscoveryError> {
82        if let Some(jwks_uri) = &self.provider.jwks_uri_override {
83            return Ok(jwks_uri);
84        }
85
86        Ok(self.load().await?.jwks_uri())
87    }
88
89    pub async fn authorization_endpoint(&mut self) -> Result<&Url, DiscoveryError> {
94        if let Some(authorization_endpoint) = &self.provider.authorization_endpoint_override {
95            return Ok(authorization_endpoint);
96        }
97
98        Ok(self.load().await?.authorization_endpoint())
99    }
100
101    pub async fn token_endpoint(&mut self) -> Result<&Url, DiscoveryError> {
106        if let Some(token_endpoint) = &self.provider.token_endpoint_override {
107            return Ok(token_endpoint);
108        }
109
110        Ok(self.load().await?.token_endpoint())
111    }
112
113    pub async fn userinfo_endpoint(&mut self) -> Result<&Url, DiscoveryError> {
118        if let Some(userinfo_endpoint) = &self.provider.userinfo_endpoint_override {
119            return Ok(userinfo_endpoint);
120        }
121
122        Ok(self.load().await?.userinfo_endpoint())
123    }
124
125    pub async fn pkce_methods(
130        &mut self,
131    ) -> Result<Option<Vec<PkceCodeChallengeMethod>>, DiscoveryError> {
132        let methods = match self.provider.pkce_mode {
133            UpstreamOAuthProviderPkceMode::Auto => self
134                .maybe_discover()
135                .await?
136                .and_then(|metadata| metadata.code_challenge_methods_supported.clone()),
137            UpstreamOAuthProviderPkceMode::S256 => Some(vec![PkceCodeChallengeMethod::S256]),
138            UpstreamOAuthProviderPkceMode::Disabled => None,
139        };
140
141        Ok(methods)
142    }
143}
144
145#[allow(clippy::module_name_repetitions)]
151#[derive(Debug, Clone, Default)]
152pub struct MetadataCache {
153    cache: Arc<RwLock<HashMap<String, Arc<VerifiedProviderMetadata>>>>,
154    insecure_cache: Arc<RwLock<HashMap<String, Arc<VerifiedProviderMetadata>>>>,
155}
156
157impl MetadataCache {
158    #[must_use]
159    pub fn new() -> Self {
160        Self::default()
161    }
162
163    #[tracing::instrument(name = "metadata_cache.warm_up_and_run", skip_all)]
173    pub async fn warm_up_and_run<R: RepositoryAccess>(
174        &self,
175        client: &reqwest::Client,
176        interval: std::time::Duration,
177        repository: &mut R,
178    ) -> Result<tokio::task::JoinHandle<()>, R::Error> {
179        let providers = repository.upstream_oauth_provider().all_enabled().await?;
180
181        for provider in providers {
182            let verify = match provider.discovery_mode {
183                UpstreamOAuthProviderDiscoveryMode::Oidc => true,
184                UpstreamOAuthProviderDiscoveryMode::Insecure => false,
185                UpstreamOAuthProviderDiscoveryMode::Disabled => continue,
186            };
187
188            let Some(issuer) = &provider.issuer else {
189                tracing::error!(%provider.id, "Provider doesn't have an issuer set, but discovery is enabled!");
190                continue;
191            };
192
193            if let Err(e) = self.fetch(client, issuer, verify).await {
194                tracing::error!(%issuer, error = &e as &dyn std::error::Error, "Failed to fetch provider metadata");
195            }
196        }
197
198        let cache = self.clone();
200        let client = client.clone();
201        Ok(tokio::spawn(async move {
202            loop {
203                tokio::time::sleep(interval).await;
205                LogContext::new("metadata-cache-refresh")
206                    .run(|| cache.refresh_all(&client))
207                    .await;
208            }
209        }))
210    }
211
212    #[tracing::instrument(name = "metadata_cache.fetch", fields(%issuer), skip_all)]
213    async fn fetch(
214        &self,
215        client: &reqwest::Client,
216        issuer: &str,
217        verify: bool,
218    ) -> Result<Arc<VerifiedProviderMetadata>, DiscoveryError> {
219        if verify {
220            let metadata = mas_oidc_client::requests::discovery::discover(client, issuer).await?;
221            let metadata = Arc::new(metadata);
222
223            self.cache
224                .write()
225                .await
226                .insert(issuer.to_owned(), metadata.clone());
227
228            Ok(metadata)
229        } else {
230            let metadata =
231                mas_oidc_client::requests::discovery::insecure_discover(client, issuer).await?;
232            let metadata = Arc::new(metadata);
233
234            self.insecure_cache
235                .write()
236                .await
237                .insert(issuer.to_owned(), metadata.clone());
238
239            Ok(metadata)
240        }
241    }
242
243    #[tracing::instrument(name = "metadata_cache.get", fields(%issuer), skip_all)]
249    pub async fn get(
250        &self,
251        client: &reqwest::Client,
252        issuer: &str,
253        verify: bool,
254    ) -> Result<Arc<VerifiedProviderMetadata>, DiscoveryError> {
255        let cache = if verify {
256            self.cache.read().await
257        } else {
258            self.insecure_cache.read().await
259        };
260
261        if let Some(metadata) = cache.get(issuer) {
262            return Ok(Arc::clone(metadata));
263        }
264        drop(cache);
266
267        let metadata = self.fetch(client, issuer, verify).await?;
268        Ok(metadata)
269    }
270
271    #[tracing::instrument(name = "metadata_cache.refresh_all", skip_all)]
272    async fn refresh_all(&self, client: &reqwest::Client) {
273        let keys: Vec<String> = {
275            let cache = self.cache.read().await;
276            cache.keys().cloned().collect()
277        };
278
279        for issuer in keys {
280            if let Err(e) = self.fetch(client, &issuer, true).await {
281                tracing::error!(issuer = %issuer, error = &e as &dyn std::error::Error, "Failed to refresh provider metadata");
282            }
283        }
284
285        let keys: Vec<String> = {
287            let cache = self.insecure_cache.read().await;
288            cache.keys().cloned().collect()
289        };
290
291        for issuer in keys {
292            if let Err(e) = self.fetch(client, &issuer, false).await {
293                tracing::error!(issuer = %issuer, error = &e as &dyn std::error::Error, "Failed to refresh provider metadata");
294            }
295        }
296    }
297}
298
#[cfg(test)]
mod tests {
    use mas_data_model::{
        Clock, UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderOnBackchannelLogout,
        UpstreamOAuthProviderTokenAuthMethod, clock::MockClock,
    };
    use mas_iana::jose::JsonWebSignatureAlg;
    use oauth2_types::scope::{OPENID, Scope};
    use ulid::Ulid;
    use wiremock::{
        Mock, MockServer, ResponseTemplate,
        matchers::{method, path},
    };

    use super::*;
    use crate::test_utils::setup;

    /// `MetadataCache` should hit the network only on cache misses and in
    /// `refresh_all`, and it should keep verified and insecure entries separate.
    ///
    /// `calls` tallies the discovery requests each step is expected to make.
    /// The mock's `.expect(expected_calls)` has wiremock check the same count
    /// on the server side (presumably when the mock server is dropped at the
    /// end of the test — confirm against the wiremock docs).
    #[tokio::test]
    async fn test_metadata_cache() {
        setup();
        let mock_server = MockServer::start().await;
        let http_client = mas_http::reqwest_client();

        let cache = MetadataCache::new();

        // No mock is mounted yet, so discovery must fail. This request happens
        // before the counted mock exists, so it is not part of `expected_calls`.
        cache
            .get(&http_client, &mock_server.uri(), false)
            .await
            .unwrap_err();

        let expected_calls = 3;
        let mut calls = 0;
        let _mock_guard = Mock::given(method("GET"))
            .and(path("/.well-known/openid-configuration"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "issuer": mock_server.uri(),
                "authorization_endpoint": "https://example.com/authorize",
                "token_endpoint": "https://example.com/token",
                "jwks_uri": "https://example.com/jwks",
                "userinfo_endpoint": "https://example.com/userinfo",
                "scopes_supported": ["openid"],
                "response_types_supported": ["code"],
                "response_modes_supported": ["query", "fragment"],
                "grant_types_supported": ["authorization_code"],
                "subject_types_supported": ["public"],
                "id_token_signing_alg_values_supported": ["RS256"],
            })))
            .expect(expected_calls)
            .mount(&mock_server)
            .await;

        // Insecure cache miss: fetches the document and stores it.
        cache
            .get(&http_client, &mock_server.uri(), false)
            .await
            .unwrap();
        calls += 1;

        // Insecure cache hit: no request.
        cache
            .get(&http_client, &mock_server.uri(), false)
            .await
            .unwrap();
        calls += 0;

        // The verified map is separate, so this is a miss and makes a request.
        // Strict discovery is expected to fail against this mock (presumably
        // because the issuer is a plain-http URL — TODO confirm), so nothing
        // is stored in the verified map.
        cache
            .get(&http_client, &mock_server.uri(), true)
            .await
            .unwrap_err();
        calls += 1;

        // Only the insecure map holds an entry, so refreshing makes one request.
        cache.refresh_all(&http_client).await;
        calls += 1;

        assert_eq!(calls, expected_calls);
    }

    /// `LazyProviderInfos` should combine endpoint overrides, discovery modes,
    /// and metadata loading correctly. Each scope below uses a fresh
    /// `MetadataCache`, so cached entries never leak between scenarios.
    #[tokio::test]
    async fn test_lazy_provider_infos() {
        setup();

        let mock_server = MockServer::start().await;
        let http_client = mas_http::reqwest_client();

        let expected_calls = 2;
        let mut calls = 0;
        let _mock_guard = Mock::given(method("GET"))
            .and(path("/.well-known/openid-configuration"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "issuer": mock_server.uri(),
                "authorization_endpoint": "https://example.com/authorize",
                "token_endpoint": "https://example.com/token",
                "jwks_uri": "https://example.com/jwks",
                "userinfo_endpoint": "https://example.com/userinfo",
                "scopes_supported": ["openid"],
                "response_types_supported": ["code"],
                "response_modes_supported": ["query", "fragment"],
                "grant_types_supported": ["authorization_code"],
                "subject_types_supported": ["public"],
                "id_token_signing_alg_values_supported": ["RS256"],
            })))
            .expect(expected_calls)
            .mount(&mock_server)
            .await;

        // Base provider: insecure discovery against the mock server, no overrides.
        let clock = MockClock::default();
        let provider = UpstreamOAuthProvider {
            id: Ulid::nil(),
            issuer: Some(mock_server.uri()),
            human_name: Some("Example Ltd.".to_owned()),
            brand_name: None,
            discovery_mode: UpstreamOAuthProviderDiscoveryMode::Insecure,
            pkce_mode: UpstreamOAuthProviderPkceMode::Auto,
            fetch_userinfo: false,
            userinfo_signed_response_alg: None,
            jwks_uri_override: None,
            authorization_endpoint_override: None,
            scope: Scope::from_iter([OPENID]),
            userinfo_endpoint_override: None,
            token_endpoint_override: None,
            client_id: "client_id".to_owned(),
            encrypted_client_secret: None,
            token_endpoint_signing_alg: None,
            token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
            id_token_signed_response_alg: JsonWebSignatureAlg::Rs256,
            response_mode: None,
            created_at: clock.now(),
            disabled_at: None,
            claims_imports: UpstreamOAuthProviderClaimsImports::default(),
            additional_authorization_parameters: Vec::new(),
            forward_login_hint: false,
            on_backchannel_logout: UpstreamOAuthProviderOnBackchannelLogout::DoNothing,
        };

        // Insecure discovery: `maybe_discover` fetches once, and the following
        // accessor reuses the metadata memoized in `lazy_metadata`.
        {
            let cache = MetadataCache::new();
            let mut lazy_metadata = LazyProviderInfos::new(&cache, &provider, &http_client);
            lazy_metadata.maybe_discover().await.unwrap();
            assert_eq!(
                lazy_metadata
                    .authorization_endpoint()
                    .await
                    .unwrap()
                    .as_str(),
                "https://example.com/authorize"
            );
            calls += 1;
        }

        // Every accessor used here has an override, so nothing is fetched.
        {
            let provider = UpstreamOAuthProvider {
                jwks_uri_override: Some("https://example.com/jwks_override".parse().unwrap()),
                authorization_endpoint_override: Some(
                    "https://example.com/authorize_override".parse().unwrap(),
                ),
                token_endpoint_override: Some(
                    "https://example.com/token_override".parse().unwrap(),
                ),
                ..provider.clone()
            };
            let cache = MetadataCache::new();
            let mut lazy_metadata = LazyProviderInfos::new(&cache, &provider, &http_client);
            assert_eq!(
                lazy_metadata.jwks_uri().await.unwrap().as_str(),
                "https://example.com/jwks_override"
            );
            assert_eq!(
                lazy_metadata
                    .authorization_endpoint()
                    .await
                    .unwrap()
                    .as_str(),
                "https://example.com/authorize_override"
            );
            assert_eq!(
                lazy_metadata.token_endpoint().await.unwrap().as_str(),
                "https://example.com/token_override"
            );
            calls += 0;
        }

        // Strict (OIDC) discovery makes one request. It is expected to fail
        // against this mock (presumably because the issuer is plain http —
        // TODO confirm).
        {
            let provider = UpstreamOAuthProvider {
                discovery_mode: UpstreamOAuthProviderDiscoveryMode::Oidc,
                ..provider.clone()
            };
            let cache = MetadataCache::new();
            let mut lazy_metadata = LazyProviderInfos::new(&cache, &provider, &http_client);
            lazy_metadata.authorization_endpoint().await.unwrap_err();
            calls += 1;
        }

        // Discovery disabled: `maybe_discover` yields `None` without a request.
        // The overridden endpoint still resolves, and the non-overridden one
        // reports `DiscoveryError::Disabled`.
        {
            let provider = UpstreamOAuthProvider {
                discovery_mode: UpstreamOAuthProviderDiscoveryMode::Disabled,
                authorization_endpoint_override: Some(
                    Url::parse("https://example.com/authorize_override").unwrap(),
                ),
                token_endpoint_override: None,
                ..provider.clone()
            };
            let cache = MetadataCache::new();
            let mut lazy_metadata = LazyProviderInfos::new(&cache, &provider, &http_client);
            assert!(lazy_metadata.maybe_discover().await.unwrap().is_none());
            assert_eq!(
                lazy_metadata
                    .authorization_endpoint()
                    .await
                    .unwrap()
                    .as_str(),
                "https://example.com/authorize_override"
            );
            assert!(matches!(
                lazy_metadata.token_endpoint().await,
                Err(DiscoveryError::Disabled),
            ));
            calls += 0;
        }

        assert_eq!(calls, expected_calls);
    }
}