1use std::net::IpAddr;
7
8use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use mas_data_model::{
11 Clock, User,
12 personal::{
13 PersonalAccessToken,
14 session::{PersonalSession, PersonalSessionOwner, SessionState},
15 },
16};
17use mas_storage::{
18 Page, Pagination,
19 pagination::Node,
20 personal::{PersonalSessionFilter, PersonalSessionRepository, PersonalSessionState},
21};
22use oauth2_types::scope::Scope;
23use opentelemetry_semantic_conventions::trace::DB_QUERY_TEXT;
24use rand::RngCore;
25use sea_query::{
26 Cond, Condition, Expr, PgFunc, PostgresQueryBuilder, Query, SimpleExpr, enum_def,
27 extension::postgres::PgExpr as _,
28};
29use sea_query_binder::SqlxBinder as _;
30use sqlx::PgConnection;
31use tracing::{Instrument as _, info_span};
32use ulid::Ulid;
33use uuid::Uuid;
34
35use crate::{
36 DatabaseError,
37 errors::DatabaseInconsistencyError,
38 filter::{Filter, StatementExt as _},
39 iden::{PersonalAccessTokens, PersonalSessions},
40 pagination::QueryBuilderExt as _,
41 tracing::ExecuteExt as _,
42};
43
/// An implementation of [`PersonalSessionRepository`] backed by a PostgreSQL
/// connection.
pub struct PgPersonalSessionRepository<'c> {
    /// The connection on which every query of this repository is executed.
    conn: &'c mut PgConnection,
}
49
50impl<'c> PgPersonalSessionRepository<'c> {
51 pub fn new(conn: &'c mut PgConnection) -> Self {
54 Self { conn }
55 }
56}
57
/// One row of a `personal_sessions` query, decoded by `sqlx`.
///
/// `#[enum_def]` additionally generates a `PersonalSessionLookupIden` enum with
/// one variant per field.
#[derive(sqlx::FromRow)]
#[enum_def]
struct PersonalSessionLookup {
    personal_session_id: Uuid,
    // Exactly one of the two owner columns is expected to be set; any other
    // combination is rejected as an inconsistency when converting.
    owner_user_id: Option<Uuid>,
    owner_oauth2_client_id: Option<Uuid>,
    actor_user_id: Uuid,
    human_name: String,
    // Individual scope tokens, parsed into a `Scope` on conversion.
    scope_list: Vec<String>,
    created_at: DateTime<Utc>,
    // `None` while the session is still valid.
    revoked_at: Option<DateTime<Utc>>,
    last_active_at: Option<DateTime<Utc>>,
    last_active_ip: Option<IpAddr>,
}
72
73impl Node<Ulid> for PersonalSessionLookup {
74 fn cursor(&self) -> Ulid {
75 self.personal_session_id.into()
76 }
77}
78
79impl TryFrom<PersonalSessionLookup> for PersonalSession {
80 type Error = DatabaseInconsistencyError;
81
82 fn try_from(value: PersonalSessionLookup) -> Result<Self, Self::Error> {
83 let id = Ulid::from(value.personal_session_id);
84 let scope: Result<Scope, _> = value.scope_list.iter().map(|s| s.parse()).collect();
85 let scope = scope.map_err(|e| {
86 DatabaseInconsistencyError::on("personal_sessions")
87 .column("scope")
88 .row(id)
89 .source(e)
90 })?;
91
92 let state = match value.revoked_at {
93 None => SessionState::Valid,
94 Some(revoked_at) => SessionState::Revoked { revoked_at },
95 };
96
97 let owner = match (value.owner_user_id, value.owner_oauth2_client_id) {
98 (Some(owner_user_id), None) => PersonalSessionOwner::User(Ulid::from(owner_user_id)),
99 (None, Some(owner_oauth2_client_id)) => {
100 PersonalSessionOwner::OAuth2Client(Ulid::from(owner_oauth2_client_id))
101 }
102 _ => {
103 return Err(DatabaseInconsistencyError::on("personal_sessions")
105 .column("owner_user_id, owner_oauth2_client_id")
106 .row(id));
107 }
108 };
109
110 Ok(PersonalSession {
111 id,
112 state,
113 owner,
114 actor_user_id: Ulid::from(value.actor_user_id),
115 human_name: value.human_name,
116 scope,
117 created_at: value.created_at,
118 last_active_at: value.last_active_at,
119 last_active_ip: value.last_active_ip,
120 })
121 }
122}
123
/// One row of a `personal_sessions` query LEFT JOINed with
/// `personal_access_tokens`, decoded by `sqlx`.
///
/// `#[enum_def]` generates the `PersonalSessionAndAccessTokenLookupIden` enum
/// used as column aliases when building the `list` query.
#[derive(sqlx::FromRow)]
#[enum_def]
struct PersonalSessionAndAccessTokenLookup {
    // Session columns, mirroring `PersonalSessionLookup`.
    personal_session_id: Uuid,
    owner_user_id: Option<Uuid>,
    owner_oauth2_client_id: Option<Uuid>,
    actor_user_id: Uuid,
    human_name: String,
    scope_list: Vec<String>,
    created_at: DateTime<Utc>,
    revoked_at: Option<DateTime<Utc>>,
    last_active_at: Option<DateTime<Utc>>,
    last_active_ip: Option<IpAddr>,

    // Token columns: all `None` when the LEFT JOIN matched no token.
    personal_access_token_id: Option<Uuid>,
    // Aliased from `personal_access_tokens.created_at`.
    token_created_at: Option<DateTime<Utc>>,
    // Aliased from `personal_access_tokens.expires_at`.
    token_expires_at: Option<DateTime<Utc>>,
}
143
144impl Node<Ulid> for PersonalSessionAndAccessTokenLookup {
145 fn cursor(&self) -> Ulid {
146 self.personal_session_id.into()
147 }
148}
149
150impl TryFrom<PersonalSessionAndAccessTokenLookup>
151 for (PersonalSession, Option<PersonalAccessToken>)
152{
153 type Error = DatabaseInconsistencyError;
154
155 fn try_from(value: PersonalSessionAndAccessTokenLookup) -> Result<Self, Self::Error> {
156 let session = PersonalSession::try_from(PersonalSessionLookup {
157 personal_session_id: value.personal_session_id,
158 owner_user_id: value.owner_user_id,
159 owner_oauth2_client_id: value.owner_oauth2_client_id,
160 actor_user_id: value.actor_user_id,
161 human_name: value.human_name,
162 scope_list: value.scope_list,
163 created_at: value.created_at,
164 revoked_at: value.revoked_at,
165 last_active_at: value.last_active_at,
166 last_active_ip: value.last_active_ip,
167 })?;
168
169 let token_opt = if let Some(id) = value.personal_access_token_id {
170 let id = Ulid::from(id);
171 Some(PersonalAccessToken {
172 id,
173 session_id: session.id,
174 created_at: value.token_created_at.ok_or(
176 DatabaseInconsistencyError::on("personal_sessions")
177 .column("created_at")
178 .row(id),
179 )?,
180 expires_at: value.token_expires_at,
181 revoked_at: None,
182 })
183 } else {
184 None
185 };
186
187 Ok((session, token_opt))
188 }
189}
190
#[async_trait]
impl PersonalSessionRepository for PgPersonalSessionRepository<'_> {
    type Error = DatabaseError;

    /// Look up a personal session by its ID.
    ///
    /// Returns `Ok(None)` when no row has this ID. Fails if the row cannot be
    /// converted into a [`PersonalSession`] (unparseable scope or invalid owner
    /// columns).
    #[tracing::instrument(
        name = "db.personal_session.lookup",
        skip_all,
        fields(
            db.query.text,
            session.id = %id,
        ),
        err,
    )]
    async fn lookup(&mut self, id: Ulid) -> Result<Option<PersonalSession>, Self::Error> {
        let res = sqlx::query_as!(
            PersonalSessionLookup,
            r#"
                SELECT personal_session_id
                     , owner_user_id
                     , owner_oauth2_client_id
                     , actor_user_id
                     , scope_list
                     , created_at
                     , revoked_at
                     , human_name
                     , last_active_at
                     , last_active_ip as "last_active_ip: IpAddr"
                FROM personal_sessions

                WHERE personal_session_id = $1
            "#,
            Uuid::from(id),
        )
        .traced()
        .fetch_optional(&mut *self.conn)
        .await?;

        let Some(session) = res else { return Ok(None) };

        Ok(Some(session.try_into()?))
    }

    /// Create a new, valid personal session and return it.
    ///
    /// The ID is a ULID derived from `clock.now()` and `rng`. The session has no
    /// recorded activity yet. Exactly one of the two owner columns is filled,
    /// depending on the `owner` variant.
    #[tracing::instrument(
        name = "db.personal_session.add",
        skip_all,
        fields(
            db.query.text,
            session.id,
            session.scope = %scope,
        ),
        err,
    )]
    async fn add(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        owner: PersonalSessionOwner,
        actor_user: &User,
        human_name: String,
        scope: Scope,
    ) -> Result<PersonalSession, Self::Error> {
        let created_at = clock.now();
        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
        // The ID is only known now, so fill the span field declared empty above.
        tracing::Span::current().record("session.id", tracing::field::display(id));

        // Scope is stored as a text array of individual tokens.
        let scope_list: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();

        let (owner_user_id, owner_oauth2_client_id) = match owner {
            PersonalSessionOwner::User(ulid) => (Some(Uuid::from(ulid)), None),
            PersonalSessionOwner::OAuth2Client(ulid) => (None, Some(Uuid::from(ulid))),
        };

        sqlx::query!(
            r#"
                INSERT INTO personal_sessions
                    ( personal_session_id
                    , owner_user_id
                    , owner_oauth2_client_id
                    , actor_user_id
                    , human_name
                    , scope_list
                    , created_at
                    )
                VALUES ($1, $2, $3, $4, $5, $6, $7)
            "#,
            Uuid::from(id),
            owner_user_id,
            owner_oauth2_client_id,
            Uuid::from(actor_user.id),
            &human_name,
            &scope_list,
            created_at,
        )
        .traced()
        .execute(&mut *self.conn)
        .await?;

        Ok(PersonalSession {
            id,
            state: SessionState::Valid,
            owner,
            actor_user_id: actor_user.id,
            human_name,
            scope,
            created_at,
            last_active_at: None,
            last_active_ip: None,
        })
    }

    /// Revoke a personal session together with its still-active access tokens.
    ///
    /// Both the tokens and the session get the same `revoked_at` timestamp,
    /// taken from `clock`. Fails if the session UPDATE does not affect the
    /// expected single row, or if `session.finish` rejects the transition
    /// (presumably when the session is already revoked — confirm in the data
    /// model).
    #[tracing::instrument(
        name = "db.personal_session.revoke",
        skip_all,
        fields(
            db.query.text,
            %session.id,
            %session.scope,
        ),
        err,
    )]
    async fn revoke(
        &mut self,
        clock: &dyn Clock,
        session: PersonalSession,
    ) -> Result<PersonalSession, Self::Error> {
        let revoked_at = clock.now();

        // Revoke the session's tokens first, in a dedicated child span so this
        // query text is recorded separately from the session UPDATE below.
        {
            let span = info_span!(
                "db.personal_session.revoke.tokens",
                { DB_QUERY_TEXT } = tracing::field::Empty,
            );

            sqlx::query!(
                r#"
                    UPDATE personal_access_tokens
                    SET revoked_at = $2
                    WHERE personal_session_id = $1 AND revoked_at IS NULL
                "#,
                Uuid::from(session.id),
                revoked_at,
            )
            .record(&span)
            .execute(&mut *self.conn)
            .instrument(span)
            .await?;
        }

        // NOTE(review): unlike the token UPDATE above, this one has no
        // `revoked_at IS NULL` guard, so it overwrites an existing timestamp.
        let res = sqlx::query!(
            r#"
                UPDATE personal_sessions
                SET revoked_at = $2
                WHERE personal_session_id = $1
            "#,
            Uuid::from(session.id),
            revoked_at,
        )
        .traced()
        .execute(&mut *self.conn)
        .await?;

        DatabaseError::ensure_affected_rows(&res, 1)?;

        session
            .finish(revoked_at)
            .map_err(DatabaseError::to_invalid_operation)
    }

    /// Mark every session matching `filter` as revoked.
    ///
    /// Returns the number of rows updated, saturating at `usize::MAX`. Unlike
    /// `revoke`, this only updates `personal_sessions`; rows in
    /// `personal_access_tokens` are not touched.
    #[tracing::instrument(
        name = "db.personal_session.revoke_bulk",
        skip_all,
        fields(
            db.query.text,
        ),
        err,
    )]
    async fn revoke_bulk(
        &mut self,
        clock: &dyn Clock,
        filter: PersonalSessionFilter<'_>,
    ) -> Result<usize, Self::Error> {
        let revoked_at = clock.now();

        let (sql, arguments) = Query::update()
            .table(PersonalSessions::Table)
            .value(PersonalSessions::RevokedAt, revoked_at)
            .and_where(
                Expr::col((PersonalSessions::Table, PersonalSessions::PersonalSessionId))
                    .in_subquery(
                        // The filter may reference `personal_access_tokens` columns
                        // (the `expires*` criteria), so candidate IDs are selected
                        // through a subquery that performs the same LEFT JOIN as
                        // `list` and `count`.
                        Query::select()
                            .expr(Expr::col((
                                PersonalSessions::Table,
                                PersonalSessions::PersonalSessionId,
                            )))
                            .from(PersonalSessions::Table)
                            .left_join(
                                PersonalAccessTokens::Table,
                                Cond::all()
                                    .add(
                                        // Join each session with its tokens...
                                        Expr::col((
                                            PersonalSessions::Table,
                                            PersonalSessions::PersonalSessionId,
                                        ))
                                        .eq(Expr::col((
                                            PersonalAccessTokens::Table,
                                            PersonalAccessTokens::PersonalSessionId,
                                        ))),
                                    )
                                    .add(
                                        // ...keeping only tokens that are not revoked.
                                        Expr::col((
                                            PersonalAccessTokens::Table,
                                            PersonalAccessTokens::RevokedAt,
                                        ))
                                        .is_null(),
                                    ),
                            )
                            .apply_filter(filter)
                            .take(),
                    ),
            )
            .build_sqlx(PostgresQueryBuilder);

        let res = sqlx::query_with(&sql, arguments)
            .traced()
            .execute(&mut *self.conn)
            .await?;

        Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
    }

    /// List sessions matching `filter`, each paired with a joined non-revoked
    /// access token (or `None`), paginated by session ID.
    ///
    /// NOTE(review): the LEFT JOIN yields one row per non-revoked token, so this
    /// presumably relies on a session having at most one such token — confirm.
    #[tracing::instrument(
        name = "db.personal_session.list",
        skip_all,
        fields(
            db.query.text,
        ),
        err,
    )]
    async fn list(
        &mut self,
        filter: PersonalSessionFilter<'_>,
        pagination: Pagination,
    ) -> Result<Page<(PersonalSession, Option<PersonalAccessToken>)>, Self::Error> {
        // Every column is aliased to a `PersonalSessionAndAccessTokenLookup`
        // field name so the row decodes via `FromRow`. The token's timestamps
        // become `token_created_at` / `token_expires_at`.
        let (sql, arguments) = Query::select()
            .expr_as(
                Expr::col((PersonalSessions::Table, PersonalSessions::PersonalSessionId)),
                PersonalSessionAndAccessTokenLookupIden::PersonalSessionId,
            )
            .expr_as(
                Expr::col((PersonalSessions::Table, PersonalSessions::OwnerUserId)),
                PersonalSessionAndAccessTokenLookupIden::OwnerUserId,
            )
            .expr_as(
                Expr::col((
                    PersonalSessions::Table,
                    PersonalSessions::OwnerOAuth2ClientId,
                )),
                PersonalSessionAndAccessTokenLookupIden::OwnerOauth2ClientId,
            )
            .expr_as(
                Expr::col((PersonalSessions::Table, PersonalSessions::ActorUserId)),
                PersonalSessionAndAccessTokenLookupIden::ActorUserId,
            )
            .expr_as(
                Expr::col((PersonalSessions::Table, PersonalSessions::HumanName)),
                PersonalSessionAndAccessTokenLookupIden::HumanName,
            )
            .expr_as(
                Expr::col((PersonalSessions::Table, PersonalSessions::ScopeList)),
                PersonalSessionAndAccessTokenLookupIden::ScopeList,
            )
            .expr_as(
                Expr::col((PersonalSessions::Table, PersonalSessions::CreatedAt)),
                PersonalSessionAndAccessTokenLookupIden::CreatedAt,
            )
            .expr_as(
                Expr::col((PersonalSessions::Table, PersonalSessions::RevokedAt)),
                PersonalSessionAndAccessTokenLookupIden::RevokedAt,
            )
            .expr_as(
                Expr::col((PersonalSessions::Table, PersonalSessions::LastActiveAt)),
                PersonalSessionAndAccessTokenLookupIden::LastActiveAt,
            )
            .expr_as(
                Expr::col((PersonalSessions::Table, PersonalSessions::LastActiveIp)),
                PersonalSessionAndAccessTokenLookupIden::LastActiveIp,
            )
            .expr_as(
                Expr::col((
                    PersonalAccessTokens::Table,
                    PersonalAccessTokens::PersonalAccessTokenId,
                )),
                PersonalSessionAndAccessTokenLookupIden::PersonalAccessTokenId,
            )
            .expr_as(
                Expr::col((PersonalAccessTokens::Table, PersonalAccessTokens::CreatedAt)),
                PersonalSessionAndAccessTokenLookupIden::TokenCreatedAt,
            )
            .expr_as(
                Expr::col((PersonalAccessTokens::Table, PersonalAccessTokens::ExpiresAt)),
                PersonalSessionAndAccessTokenLookupIden::TokenExpiresAt,
            )
            .from(PersonalSessions::Table)
            .left_join(
                PersonalAccessTokens::Table,
                Cond::all()
                    .add(
                        // Join each session with its tokens...
                        Expr::col((PersonalSessions::Table, PersonalSessions::PersonalSessionId))
                            .eq(Expr::col((
                                PersonalAccessTokens::Table,
                                PersonalAccessTokens::PersonalSessionId,
                            ))),
                    )
                    .add(
                        // ...keeping only tokens that are not revoked.
                        Expr::col((PersonalAccessTokens::Table, PersonalAccessTokens::RevokedAt))
                            .is_null(),
                    ),
            )
            .apply_filter(filter)
            .generate_pagination(
                (PersonalSessions::Table, PersonalSessions::PersonalSessionId),
                pagination,
            )
            .build_sqlx(PostgresQueryBuilder);

        let edges: Vec<PersonalSessionAndAccessTokenLookup> = sqlx::query_as_with(&sql, arguments)
            .traced()
            .fetch_all(&mut *self.conn)
            .await?;

        let page = pagination.process(edges).try_map(TryFrom::try_from)?;

        Ok(page)
    }

    /// Count the sessions matching `filter`.
    ///
    /// Uses the same LEFT JOIN on non-revoked tokens as `list`, so
    /// token-related criteria of the filter apply here too.
    #[tracing::instrument(
        name = "db.personal_session.count",
        skip_all,
        fields(
            db.query.text,
        ),
        err,
    )]
    async fn count(&mut self, filter: PersonalSessionFilter<'_>) -> Result<usize, Self::Error> {
        let (sql, arguments) = Query::select()
            .expr(Expr::col((PersonalSessions::Table, PersonalSessions::PersonalSessionId)).count())
            .from(PersonalSessions::Table)
            .left_join(
                PersonalAccessTokens::Table,
                Cond::all()
                    .add(
                        // Join each session with its tokens...
                        Expr::col((PersonalSessions::Table, PersonalSessions::PersonalSessionId))
                            .eq(Expr::col((
                                PersonalAccessTokens::Table,
                                PersonalAccessTokens::PersonalSessionId,
                            ))),
                    )
                    .add(
                        // ...keeping only tokens that are not revoked.
                        Expr::col((PersonalAccessTokens::Table, PersonalAccessTokens::RevokedAt))
                            .is_null(),
                    ),
            )
            .apply_filter(filter)
            .build_sqlx(PostgresQueryBuilder);

        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
            .traced()
            .fetch_one(&mut *self.conn)
            .await?;

        // COUNT() comes back as i64; a negative value cannot become a usize.
        count
            .try_into()
            .map_err(DatabaseError::to_invalid_operation)
    }

    /// Record last-activity data for many sessions in a single UPDATE.
    ///
    /// Each entry is `(session id, activity time, optional IP)`.
    /// `last_active_at` keeps the later of the stored and provided timestamps.
    /// `last_active_ip` takes the provided IP unless it is `None`, in which
    /// case the stored value is kept. Fails unless every entry updated a row.
    #[tracing::instrument(
        name = "db.personal_session.record_batch_activity",
        skip_all,
        fields(
            db.query.text,
        ),
        err,
    )]
    async fn record_batch_activity(
        &mut self,
        mut activities: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
    ) -> Result<(), Self::Error> {
        // Sort by session ID (the tuple's first element). Presumably this makes
        // concurrent batches lock rows in a consistent order, avoiding
        // deadlocks — confirm.
        activities.sort_unstable();
        // Split into parallel arrays, to be passed to UNNEST below.
        let mut ids = Vec::with_capacity(activities.len());
        let mut last_activities = Vec::with_capacity(activities.len());
        let mut ips = Vec::with_capacity(activities.len());

        for (id, last_activity, ip) in activities {
            ids.push(Uuid::from(id));
            last_activities.push(last_activity);
            ips.push(ip);
        }

        let res = sqlx::query!(
            r#"
                UPDATE personal_sessions
                SET last_active_at = GREATEST(t.last_active_at, personal_sessions.last_active_at)
                  , last_active_ip = COALESCE(t.last_active_ip, personal_sessions.last_active_ip)
                FROM (
                    SELECT *
                    FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
                        AS t(personal_session_id, last_active_at, last_active_ip)
                ) AS t
                WHERE personal_sessions.personal_session_id = t.personal_session_id
            "#,
            &ids,
            &last_activities,
            &ips as &[Option<IpAddr>],
        )
        .traced()
        .execute(&mut *self.conn)
        .await?;

        // NOTE(review): expects one affected row per entry. If `activities`
        // contained the same session twice, Postgres would update that row only
        // once and this check would fail — callers presumably deduplicate.
        DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?;

        Ok(())
    }
}
626
627impl Filter for PersonalSessionFilter<'_> {
628 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
629 sea_query::Condition::all()
630 .add_option(self.owner_user().map(|user| {
631 Expr::col((PersonalSessions::Table, PersonalSessions::OwnerUserId))
632 .eq(Uuid::from(user.id))
633 }))
634 .add_option(self.owner_oauth2_client().map(|client| {
635 Expr::col((
636 PersonalSessions::Table,
637 PersonalSessions::OwnerOAuth2ClientId,
638 ))
639 .eq(Uuid::from(client.id))
640 }))
641 .add_option(self.actor_user().map(|user| {
642 Expr::col((PersonalSessions::Table, PersonalSessions::ActorUserId))
643 .eq(Uuid::from(user.id))
644 }))
645 .add_option(self.device().map(|device| -> SimpleExpr {
646 if let Ok([stable_scope_token, unstable_scope_token]) = device.to_scope_token() {
647 Condition::any()
648 .add(
649 Expr::val(stable_scope_token.to_string()).eq(PgFunc::any(Expr::col((
650 PersonalSessions::Table,
651 PersonalSessions::ScopeList,
652 )))),
653 )
654 .add(Expr::val(unstable_scope_token.to_string()).eq(PgFunc::any(
655 Expr::col((PersonalSessions::Table, PersonalSessions::ScopeList)),
656 )))
657 .into()
658 } else {
659 Expr::val(false).into()
661 }
662 }))
663 .add_option(self.state().map(|state| match state {
664 PersonalSessionState::Active => {
665 Expr::col((PersonalSessions::Table, PersonalSessions::RevokedAt)).is_null()
666 }
667 PersonalSessionState::Revoked => {
668 Expr::col((PersonalSessions::Table, PersonalSessions::RevokedAt)).is_not_null()
669 }
670 }))
671 .add_option(self.scope().map(|scope| {
672 let scope: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
673 Expr::col((PersonalSessions::Table, PersonalSessions::ScopeList)).contains(scope)
674 }))
675 .add_option(self.last_active_before().map(|last_active_before| {
676 Expr::col((PersonalSessions::Table, PersonalSessions::LastActiveAt))
677 .lt(last_active_before)
678 }))
679 .add_option(self.last_active_after().map(|last_active_after| {
680 Expr::col((PersonalSessions::Table, PersonalSessions::LastActiveAt))
681 .gt(last_active_after)
682 }))
683 .add_option(self.expires_before().map(|expires_before| {
684 Expr::col((PersonalAccessTokens::Table, PersonalAccessTokens::ExpiresAt))
685 .lt(expires_before)
686 }))
687 .add_option(self.expires_after().map(|expires_after| {
688 Expr::col((PersonalAccessTokens::Table, PersonalAccessTokens::ExpiresAt))
689 .gt(expires_after)
690 }))
691 .add_option(self.expires().map(|expires| {
692 let column =
693 Expr::col((PersonalAccessTokens::Table, PersonalAccessTokens::ExpiresAt));
694
695 if expires {
696 column.is_not_null()
697 } else {
698 column.is_null()
699 }
700 }))
701 }
702}