1use async_trait::async_trait;
7use chrono::{DateTime, Utc};
8use mas_data_model::{
9 Clock,
10 personal::{PersonalAccessToken, session::PersonalSession},
11};
12use mas_storage::personal::PersonalAccessTokenRepository;
13use rand::RngCore;
14use sha2::{Digest, Sha256};
15use sqlx::PgConnection;
16use ulid::Ulid;
17use uuid::Uuid;
18
19use crate::{DatabaseError, tracing::ExecuteExt as _};
20
/// An implementation of [`PersonalAccessTokenRepository`] backed by a
/// PostgreSQL connection.
pub struct PgPersonalAccessTokenRepository<'c> {
    // Mutably borrowed connection on which every query of this repository runs.
    conn: &'c mut PgConnection,
}
26
27impl<'c> PgPersonalAccessTokenRepository<'c> {
28 pub fn new(conn: &'c mut PgConnection) -> Self {
31 Self { conn }
32 }
33}
34
/// Row shape produced by the `SELECT` queries on `personal_access_tokens`.
///
/// Field names must match the selected column names, because `query_as!`
/// maps columns onto fields by name. Rows are turned into a
/// [`PersonalAccessToken`] through the `From` conversion.
struct PersonalAccessTokenLookup {
    // Primary key; converted to a `Ulid` in the domain model.
    personal_access_token_id: Uuid,
    // Owning personal session; converted to a `Ulid` in the domain model.
    personal_session_id: Uuid,
    created_at: DateTime<Utc>,
    // `None` when the token was inserted without an `expires_after` duration.
    expires_at: Option<DateTime<Utc>>,
    // Stays `None` until `revoke` sets it.
    revoked_at: Option<DateTime<Utc>>,
}
42
43impl From<PersonalAccessTokenLookup> for PersonalAccessToken {
44 fn from(value: PersonalAccessTokenLookup) -> Self {
45 Self {
46 id: Ulid::from(value.personal_access_token_id),
47 session_id: Ulid::from(value.personal_session_id),
48 created_at: value.created_at,
49 expires_at: value.expires_at,
50 revoked_at: value.revoked_at,
51 }
52 }
53}
54
55#[async_trait]
56impl PersonalAccessTokenRepository for PgPersonalAccessTokenRepository<'_> {
57 type Error = DatabaseError;
58
59 #[tracing::instrument(
60 name = "db.personal_access_token.lookup",
61 skip_all,
62 fields(
63 db.query.text,
64 personal_access_token.id = %id,
65 ),
66 err,
67 )]
68 async fn lookup(&mut self, id: Ulid) -> Result<Option<PersonalAccessToken>, Self::Error> {
69 let res = sqlx::query_as!(
70 PersonalAccessTokenLookup,
71 r#"
72 SELECT personal_access_token_id
73 , personal_session_id
74 , created_at
75 , expires_at
76 , revoked_at
77
78 FROM personal_access_tokens
79
80 WHERE personal_access_token_id = $1
81 "#,
82 Uuid::from(id),
83 )
84 .traced()
85 .fetch_optional(&mut *self.conn)
86 .await?;
87
88 let Some(res) = res else { return Ok(None) };
89
90 Ok(Some(res.into()))
91 }
92
93 #[tracing::instrument(
94 name = "db.personal_access_token.find_by_token",
95 skip_all,
96 fields(
97 db.query.text,
98 ),
99 err,
100 )]
101 async fn find_by_token(
102 &mut self,
103 access_token: &str,
104 ) -> Result<Option<PersonalAccessToken>, Self::Error> {
105 let token_sha256 = Sha256::digest(access_token.as_bytes()).to_vec();
106
107 let res = sqlx::query_as!(
108 PersonalAccessTokenLookup,
109 r#"
110 SELECT personal_access_token_id
111 , personal_session_id
112 , created_at
113 , expires_at
114 , revoked_at
115
116 FROM personal_access_tokens
117
118 WHERE access_token_sha256 = $1
119 "#,
120 &token_sha256,
121 )
122 .traced()
123 .fetch_optional(&mut *self.conn)
124 .await?;
125
126 let Some(res) = res else { return Ok(None) };
127
128 Ok(Some(res.into()))
129 }
130
131 #[tracing::instrument(
132 name = "db.personal_access_token.add",
133 skip_all,
134 fields(
135 db.query.text,
136 personal_access_token.id,
137 %session.id,
138 ),
139 err,
140 )]
141 async fn add(
142 &mut self,
143 rng: &mut (dyn RngCore + Send),
144 clock: &dyn Clock,
145 session: &PersonalSession,
146 access_token: &str,
147 expires_after: Option<chrono::Duration>,
148 ) -> Result<PersonalAccessToken, Self::Error> {
149 let created_at = clock.now();
150 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
151 tracing::Span::current().record("personal_access_token.id", tracing::field::display(id));
152
153 let token_sha256 = Sha256::digest(access_token.as_bytes()).to_vec();
154
155 let expires_at = expires_after.map(|expires_after| created_at + expires_after);
156
157 sqlx::query!(
158 r#"
159 INSERT INTO personal_access_tokens
160 (personal_access_token_id, personal_session_id, access_token_sha256, created_at, expires_at)
161 VALUES ($1, $2, $3, $4, $5)
162 "#,
163 Uuid::from(id),
164 Uuid::from(session.id),
165 &token_sha256,
166 created_at,
167 expires_at,
168 )
169 .traced()
170 .execute(&mut *self.conn)
171 .await?;
172
173 Ok(PersonalAccessToken {
174 id,
175 session_id: session.id,
176 created_at,
177 expires_at,
178 revoked_at: None,
179 })
180 }
181
182 #[tracing::instrument(
183 name = "db.personal_access_token.revoke",
184 skip_all,
185 fields(
186 db.query.text,
187 %access_token.id,
188 personal_session.id = %access_token.session_id,
189 ),
190 err,
191 )]
192 async fn revoke(
193 &mut self,
194 clock: &dyn Clock,
195 mut access_token: PersonalAccessToken,
196 ) -> Result<PersonalAccessToken, Self::Error> {
197 let revoked_at = clock.now();
198 let res = sqlx::query!(
199 r#"
200 UPDATE personal_access_tokens
201 SET revoked_at = $2
202 WHERE personal_access_token_id = $1
203 "#,
204 Uuid::from(access_token.id),
205 revoked_at,
206 )
207 .traced()
208 .execute(&mut *self.conn)
209 .await?;
210
211 DatabaseError::ensure_affected_rows(&res, 1)?;
212
213 access_token.revoked_at = Some(revoked_at);
214 Ok(access_token)
215 }
216}