Skip to content

Commit 7de710b

Browse files
committed
fix
1 parent bb16588 commit 7de710b

File tree

9 files changed

+59
-212
lines changed

9 files changed

+59
-212
lines changed

src/admin/query/sync.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ pub(super) async fn show_connection(
4949
// Admin CLI/backend calls don't have an X-Forwarded-For header context,
5050
// so pass None for the forwarded-for field.
5151
let key = into_connection_key(user_id, device_id, conn_id, None::<String>);
52-
let cache = self.services.sync.find_connection(&key)?;
52+
let cache = self.services.sync.find_connection(&key).await?;
5353

5454
let out;
5555
{

src/api/client/read_marker.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ pub(crate) async fn set_read_marker_route(
2929
if body.private_read_receipt.is_some() || body.read_receipt.is_some() {
3030
services
3131
.pusher
32-
.reset_notification_counts(sender_user, &body.room_id);
32+
.reset_notification_counts(sender_user, &body.room_id)
33+
.await;
3334
}
3435

3536
if let Some(event) = &body.fully_read {
@@ -120,7 +121,8 @@ pub(crate) async fn create_receipt_route(
120121
) {
121122
services
122123
.pusher
123-
.reset_notification_counts(sender_user, &body.room_id);
124+
.reset_notification_counts(sender_user, &body.room_id)
125+
.await;
124126
}
125127

126128
match body.receipt_type {

src/api/client/session/oauth.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ fn consume_redirect(state_token: &str) -> Option<String> {
5454

5555
/// OAuth 2.0 token response
5656
#[derive(Clone, Debug, Serialize, Deserialize)]
57-
pub(super) struct OAuthTokenResponse {
57+
pub(crate) struct OAuthTokenResponse {
5858
pub access_token: String,
5959
pub token_type: String,
6060
#[serde(skip_serializing_if = "Option::is_none")]
@@ -69,7 +69,7 @@ pub(super) struct OAuthTokenResponse {
6969

7070
/// OAuth 2.0 userinfo response
7171
#[derive(Clone, Debug, Serialize, Deserialize)]
72-
pub(super) struct OAuthUserInfo {
72+
pub(crate) struct OAuthUserInfo {
7373
pub sub: String,
7474
#[serde(skip_serializing_if = "Option::is_none")]
7575
pub name: Option<String>,
@@ -326,7 +326,7 @@ pub(crate) async fn oauth_callback_route(
326326
use ruma::UserId;
327327
let user_id_ref: &UserId = user_id.as_ref();
328328
let token = utils::random_string(TOKEN_LENGTH);
329-
services
329+
let _ = services
330330
.users
331331
.create_login_token(user_id_ref, &token);
332332
token

src/api/client/sync/v5.rs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ pub(crate) async fn sync_events_v5_route(
8080
InsecureClientIp(client): InsecureClientIp,
8181
body: Ruma<Request>,
8282
) -> Result<Response> {
83-
let (sender_user, sender_device) = body.sender();
83+
let (_sender_user, _sender_device) = body.sender();
8484
let request = &body.body;
8585
let since = request
8686
.pos
@@ -113,11 +113,12 @@ pub(crate) async fn sync_events_v5_route(
113113
request.conn_id.as_deref(),
114114
x_forwarded_for,
115115
);
116-
let conn_val = since
117-
.ne(&0)
118-
.then(|| services.sync.find_connection(&conn_key))
119-
.unwrap_or_else(|| Ok(services.sync.init_connection(&conn_key)))
120-
.map_err(|_| err!(Request(UnknownPos("Connection lost; restarting sync stream."))))?;
116+
let conn_val = if since.ne(&0) {
117+
services.sync.find_connection(&conn_key).boxed()
118+
} else {
119+
async { Ok(services.sync.init_connection(&conn_key).await) }.boxed()
120+
}
121+
.await?;
121122

122123
let conn = conn_val.lock();
123124
let ping_presence = services

src/service/pusher/mod.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,25 @@ impl crate::Service for Service {
6161
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
6262
}
6363

64+
#[implement(Service)]
65+
#[tracing::instrument(level = "debug", skip(self))]
66+
pub async fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) {
67+
let count = self.services.globals.next_count();
68+
69+
let userroom_id = (user_id, room_id);
70+
self.db
71+
.userroomid_highlightcount
72+
.put(userroom_id, 0_u64);
73+
self.db
74+
.userroomid_notificationcount
75+
.put(userroom_id, 0_u64);
76+
77+
let roomuser_id = (room_id, user_id);
78+
self.db
79+
.roomuserid_lastnotificationread
80+
.put(roomuser_id, *count);
81+
}
82+
6483
#[implement(Service)]
6584
pub async fn set_pusher(
6685
&self,

src/service/pusher/notification.rs

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -5,51 +5,9 @@ use tuwunel_core::{
55
};
66
use tuwunel_database::{Deserialized, Interfix};
77

8-
pub struct Service {
9-
db: Data,
10-
services: Arc<crate::services::OnceServices>,
11-
}
12-
13-
struct Data {
14-
userroomid_notificationcount: Arc<Map>,
15-
userroomid_highlightcount: Arc<Map>,
16-
roomuserid_lastnotificationread: Arc<Map>,
17-
}
188

19-
impl crate::Service for Service {
20-
fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
21-
Ok(Arc::new(Self {
22-
db: Data {
23-
userroomid_notificationcount: args.db["userroomid_notificationcount"].clone(),
24-
userroomid_highlightcount: args.db["userroomid_highlightcount"].clone(),
25-
roomuserid_lastnotificationread: args.db["roomuserid_lastnotificationread"]
26-
.clone(),
27-
},
28-
services: args.services.clone(),
29-
}))
30-
}
319

32-
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
33-
}
3410

35-
#[implement(Service)]
36-
#[tracing::instrument(level = "debug", skip(self))]
37-
pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) {
38-
let count = self.services.globals.next_count();
39-
40-
let userroom_id = (user_id, room_id);
41-
self.db
42-
.userroomid_highlightcount
43-
.put(userroom_id, 0_u64);
44-
self.db
45-
.userroomid_notificationcount
46-
.put(userroom_id, 0_u64);
47-
48-
let roomuser_id = (room_id, user_id);
49-
self.db
50-
.roomuserid_lastnotificationread
51-
.put(roomuser_id, *count);
52-
}
5311

5412
#[implement(super::Service)]
5513
#[tracing::instrument(level = "debug", skip(self), ret(level = "trace"))]

src/service/rooms/timeline/append.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,8 @@ where
169169

170170
self.services
171171
.pusher
172-
.reset_notification_counts(pdu.sender(), pdu.room_id());
172+
.reset_notification_counts(pdu.sender(), pdu.room_id())
173+
.await;
173174

174175
let count = PduCount::Normal(*next_count1);
175176
let pdu_id: RawPduId = PduId { shortroomid, count }.into();

src/service/sync/mod.rs

Lines changed: 19 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -111,15 +111,15 @@ pub async fn clear_connections(
111111
self.connections
112112
.lock()
113113
.await
114-
.retain(|(conn_user_id, conn_device_id, conn_conn_id), _| {
114+
.retain(|(conn_user_id, conn_device_id, conn_conn_id, conn_xff), _| {
115115
let retain = user_id.is_none_or(is_equal_to!(conn_user_id))
116116
&& device_id.is_none_or(is_equal_to!(conn_device_id))
117117
&& (conn_id.is_none() || conn_id == conn_conn_id.as_ref());
118118

119119
if !retain {
120120
self.db
121121
.userdeviceconnid_conn
122-
.del((conn_user_id, conn_device_id, conn_conn_id));
122+
.del((conn_user_id, conn_device_id, conn_conn_id, conn_xff));
123123
}
124124

125125
retain
@@ -129,10 +129,10 @@ pub async fn clear_connections(
129129
#[implement(Service)]
130130
#[tracing::instrument(level = "debug", skip(self))]
131131
pub async fn drop_connection(&self, key: &ConnectionKey) {
132-
let mut cache = self.connections.lock().await;
133-
134-
self.db.userdeviceconnid_conn.del(key);
135-
cache.remove(key);
132+
self.connections
133+
.lock()
134+
.await
135+
.remove(key);
136136
}
137137

138138
#[implement(Service)]
@@ -220,15 +220,7 @@ pub async fn is_connection_stored(&self, key: &ConnectionKey) -> bool {
220220
self.db.userdeviceconnid_conn.contains(key).await
221221
}
222222

223-
#[inline]
224-
pub fn into_connection_key<U, D, C>(user_id: U, device_id: D, conn_id: Option<C>) -> ConnectionKey
225-
where
226-
U: Into<OwnedUserId>,
227-
D: Into<OwnedDeviceId>,
228-
C: Into<ConnectionId>,
229-
{
230-
(user_id.into(), device_id.into(), conn_id.map(Into::into))
231-
}
223+
232224

233225
#[implement(Connection)]
234226
#[tracing::instrument(level = "debug", skip(self, service))]
@@ -373,66 +365,48 @@ fn some_or_sticky<T: Clone>(target: Option<&T>, cached: &mut Option<T>) {
373365
}
374366
}
375367

376-
#[implement(Service)]
377-
pub fn clear_connections(
378-
&self,
379-
user_id: Option<&UserId>,
380-
device_id: Option<&DeviceId>,
381-
conn_id: Option<&ConnectionId>,
382-
) {
383-
self.connections.lock().expect("locked").retain(
384-
|(conn_user_id, conn_device_id, conn_conn_id, _conn_xff), _| {
385-
!(user_id.is_none_or(is_equal_to!(conn_user_id))
386-
&& device_id.is_none_or(is_equal_to!(conn_device_id))
387-
&& (conn_id.is_none() || conn_id == conn_conn_id.as_ref()))
388-
},
389-
);
390-
}
391368

392-
#[implement(Service)]
393-
pub fn drop_connection(&self, key: &ConnectionKey) {
394-
self.connections
395-
.lock()
396-
.expect("locked")
397-
.remove(key);
398-
}
399369

400370
#[implement(Service)]
401-
pub fn list_connections(&self) -> Vec<ConnectionKey> {
371+
#[tracing::instrument(level = "trace", skip(self))]
372+
pub async fn list_connections(&self) -> Vec<ConnectionKey> {
402373
self.connections
403374
.lock()
404-
.expect("locked")
375+
.await
405376
.keys()
406377
.cloned()
407378
.collect()
408379
}
409380

410381
#[implement(Service)]
411-
pub fn init_connection(&self, key: &ConnectionKey) -> ConnectionVal {
382+
#[tracing::instrument(level = "debug", skip(self))]
383+
pub async fn init_connection(&self, key: &ConnectionKey) -> ConnectionVal {
412384
self.connections
413385
.lock()
414-
.expect("locked")
386+
.await
415387
.entry(key.clone())
416388
.and_modify(|existing| *existing = ConnectionVal::default())
417389
.or_default()
418390
.clone()
419391
}
420392

421393
#[implement(Service)]
422-
pub fn find_connection(&self, key: &ConnectionKey) -> Result<ConnectionVal> {
394+
#[tracing::instrument(level = "debug", skip(self))]
395+
pub async fn find_connection(&self, key: &ConnectionKey) -> Result<ConnectionVal> {
423396
self.connections
424397
.lock()
425-
.expect("locked")
398+
.await
426399
.get(key)
427400
.cloned()
428401
.ok_or_else(|| err!(Request(NotFound("Connection not found."))))
429402
}
430403

431404
#[implement(Service)]
432-
pub fn contains_connection(&self, key: &ConnectionKey) -> bool {
405+
#[tracing::instrument(level = "trace", skip(self))]
406+
pub async fn contains_connection(&self, key: &ConnectionKey) -> bool {
433407
self.connections
434408
.lock()
435-
.expect("locked")
409+
.await
436410
.contains_key(key)
437411
}
438412

0 commit comments

Comments
 (0)