Skip to content

Commit

Permalink
Merge default-domain feature
Browse files Browse the repository at this point in the history
* Fix default-domain api running on tests
* Fix merge conflicts
* Remove print lines
* Adjust end lines in resource files
  • Loading branch information
spoonman01 committed Oct 8, 2024
2 parents bae2b46 + 00b299e commit 968dc21
Show file tree
Hide file tree
Showing 27 changed files with 380 additions and 136 deletions.
6 changes: 6 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,12 @@
<artifactId>simpleclient_servlet</artifactId>
<version>${prometheus.version}</version>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<version>5.14.1</version>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
Expand Down
20 changes: 20 additions & 0 deletions src/main/java/com/wire/bots/hold/DAO/MetadataDAO.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package com.wire.bots.hold.DAO;

import com.wire.bots.hold.model.Metadata;
import org.jdbi.v3.sqlobject.config.RegisterColumnMapper;
import org.jdbi.v3.sqlobject.customizer.Bind;
import org.jdbi.v3.sqlobject.statement.SqlQuery;
import org.jdbi.v3.sqlobject.statement.SqlUpdate;

public interface MetadataDAO {
String FALLBACK_DOMAIN_KEY = "FALLBACK_DOMAIN_KEY";

@SqlUpdate("INSERT INTO Metadata (key, value)" +
"VALUES (:key, :value)" +
"ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value")
int insert(@Bind("key") String key, @Bind("value") String value);

@SqlQuery("SELECT key, value FROM Metadata WHERE key = :key LIMIT 1")
@RegisterColumnMapper(MetadataResultSetMapper.class)
Metadata get(@Bind("key") String key);
}
19 changes: 19 additions & 0 deletions src/main/java/com/wire/bots/hold/DAO/MetadataResultSetMapper.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package com.wire.bots.hold.DAO;

import com.wire.bots.hold.model.Metadata;
import org.jdbi.v3.core.mapper.ColumnMapper;
import org.jdbi.v3.core.statement.StatementContext;

import java.sql.ResultSet;
import java.sql.SQLException;

public class MetadataResultSetMapper implements ColumnMapper<Metadata> {

@Override
public Metadata map(ResultSet rs, int columnNumber, StatementContext ctx) throws SQLException {
return new Metadata(
rs.getString("key"),
rs.getString("value")
);
}
}
69 changes: 69 additions & 0 deletions src/main/java/com/wire/bots/hold/FallbackDomainFetcher.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package com.wire.bots.hold;

import com.wire.bots.hold.DAO.MetadataDAO;
import com.wire.bots.hold.model.Metadata;
import com.wire.bots.hold.utils.Cache;
import com.wire.helium.LoginClient;
import com.wire.helium.models.BackendConfiguration;
import com.wire.xenon.exceptions.HttpException;
import com.wire.xenon.tools.Logger;

import javax.ws.rs.ProcessingException;

public class FallbackDomainFetcher implements Runnable {

private final LoginClient loginClient;
private final MetadataDAO metadataDAO;

/**
* Fetcher and handler for fallback domain.
* <p>
* Fetches from API and compares against database value (if any), then inserts into database and updates cache value.
* If value received from the API is different from what is saved in the database, a [RuntimeException] is thrown.
* </p>
* <p>
* This fallback domain is necessary for LegalHold to work with Federation (as it needs id@domain) and not just the ID anymore.
* In case there is a mismatch we are throwing a RuntimeException so it stops the execution of this app, so in an event
* of already having a defined default domain saved in the database and this app restarts with a different domain
* we don't get mismatching domains.
* </p>
* @param loginClient [{@link LoginClient}] as API to get backend configuration containing default domain.
* @param metadataDAO [{@link MetadataDAO}] as DAO to get/insert default domain to database.
*
* @throws RuntimeException if received domain from API is different from the one saved in the database.
*/
FallbackDomainFetcher(LoginClient loginClient, MetadataDAO metadataDAO) {
this.loginClient = loginClient;
this.metadataDAO = metadataDAO;
}

@Override
public void run() {
if (Cache.getFallbackDomain() != null) { return; }

Metadata metadata = metadataDAO.get(MetadataDAO.FALLBACK_DOMAIN_KEY);
try {
BackendConfiguration apiVersionResponse = loginClient.getBackendConfiguration();

if (metadata == null) {
metadataDAO.insert(MetadataDAO.FALLBACK_DOMAIN_KEY, apiVersionResponse.domain);
Cache.setFallbackDomain(apiVersionResponse.domain);
} else {
if (metadata.value.equals(apiVersionResponse.domain)) {
Cache.setFallbackDomain(apiVersionResponse.domain);
} else {
String formattedExceptionMessage = String.format(
"Database already has a default domain as %s and instead we got %s from the Backend API.",
metadata.value,
apiVersionResponse.domain
);
throw new RuntimeException(formattedExceptionMessage);
}
}
} catch (HttpException exception) {
Logger.exception(exception, "FallbackDomainFetcher.run, exception: %s", exception.getMessage());
} catch (ProcessingException pexception) {
Logger.info("FallbackDomainFetcher.run, ignoring test exceptions");
}
}
}
80 changes: 26 additions & 54 deletions src/main/java/com/wire/bots/hold/NotificationProcessor.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,38 +3,34 @@
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.wire.bots.hold.DAO.AccessDAO;
import com.wire.bots.hold.model.Notification;
import com.wire.bots.hold.model.NotificationList;
import com.wire.bots.hold.model.database.LHAccess;
import com.wire.helium.API;
import com.wire.helium.LoginClient;
import com.wire.helium.models.Access;
import com.wire.helium.models.Event;
import com.wire.helium.models.NotificationList;
import com.wire.xenon.backend.models.Payload;
import com.wire.xenon.exceptions.AuthException;
import com.wire.xenon.exceptions.HttpException;
import com.wire.xenon.tools.Logger;

import javax.ws.rs.client.Client;
import javax.ws.rs.client.WebTarget;
import javax.ws.rs.core.Cookie;
import javax.ws.rs.core.HttpHeaders;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import java.util.List;
import java.util.UUID;
import java.util.logging.Level;

public class NotificationProcessor implements Runnable {
private static final int DEFAULT_NOTIFICATION_SIZE = 100;

private final Client client;
private final AccessDAO accessDAO;
private final HoldMessageResource messageResource;
private final WebTarget api;

NotificationProcessor(Client client, AccessDAO accessDAO, Config config, HoldMessageResource messageResource) {
NotificationProcessor(Client client, AccessDAO accessDAO, HoldMessageResource messageResource) {
this.client = client;
this.accessDAO = accessDAO;
this.messageResource = messageResource;

api = client.target(config.apiHost);
}

@Override
Expand All @@ -58,7 +54,10 @@ private Access getAccess(Cookie cookie) throws HttpException {

private void process(LHAccess device) {
UUID userId = device.userId;

try {
final API api = new API(client, null, device.token);

Logger.debug("`GET /notifications`: user: %s, last: %s", userId, device.last);

String cookieValue = device.cookie;
Expand All @@ -76,39 +75,38 @@ private void process(LHAccess device) {

device.token = access.getAccessToken();

NotificationList notificationList = retrieveNotifications(device);
NotificationList notificationList = api.retrieveNotifications(
device.clientId,
device.last,
DEFAULT_NOTIFICATION_SIZE
);

process(userId, notificationList);

} catch (AuthException e) {
accessDAO.disable(userId);
Logger.info("Disabled LH device for user: %s, error: %s", userId, e.getMessage());
} catch (HttpException e) {
Logger.exception(e, "NotificationProcessor: Couldn't retrieve notifications, error: %s", e.getMessage());
} catch (Exception e) {
Logger.exception(e, "NotificationProcessor: user: %s, last: %s, error: %s", userId, device.last, e.getMessage());
}
}

private static String bearer(String token) {
return token == null ? null : String.format("Bearer %s", token);
}

private void process(UUID userId, NotificationList notificationList) {

for (Notification notif : notificationList.notifications) {
for (Payload payload : notif.payload) {
if (!process(userId, payload, notif.id)) {
Logger.error("Failed to process: user: %s, notif: %s", userId, notif.id);
//return;
for (Event event : notificationList.notifications) {
for (Payload payload : event.payload) {
if (!process(userId, payload, event.id)) {
Logger.error("Failed to process: user: %s, event: %s", userId, event.id);
} else {
Logger.debug("Processed: `%s` conv: %s, user: %s, notifId: %s",
Logger.debug("Processed: `%s` conv: %s, user: %s, eventId: %s",
payload.type,
payload.conversation,
userId,
notif.id);
event.id);
}
}

accessDAO.updateLast(userId, notif.id);
accessDAO.updateLast(userId, event.id);
}
}

Expand All @@ -122,13 +120,13 @@ private boolean process(UUID userId, Payload payload, UUID id) {

if (payload.from == null || payload.data == null) return true;

final boolean b = messageResource.onNewMessage(userId, id, payload);
final boolean wasMessageSent = messageResource.onNewMessage(userId, id, payload);

if (!b) {
if (!wasMessageSent) {
Logger.error("process: `%s` user: %s, from: %s:%s, error: %s", payload.type, userId, payload.from, payload.data.sender);
}

return b;
return wasMessageSent;
}

private void trace(Payload payload) {
Expand All @@ -141,30 +139,4 @@ private void trace(Payload payload) {
}
}
}

//TODO remove this and use retrieveNotifications provided by Helium
private NotificationList retrieveNotifications(LHAccess access) throws HttpException {
Response response = api.path("notifications").queryParam("client", access.clientId).queryParam("since", access.last).queryParam("size", 100).request(MediaType.APPLICATION_JSON).accept(MediaType.APPLICATION_JSON).header(HttpHeaders.AUTHORIZATION, bearer(access.token)).get();

int status = response.getStatus();

if (status == 200) {
return response.readEntity(NotificationList.class);
}

if (status == 404) { //todo what???
return response.readEntity(NotificationList.class);
}

if (status == 401) { //todo nginx returns text/html for 401. Cannot deserialize as json
response.readEntity(String.class);
throw new AuthException(status);
}

if (status == 403) {
throw response.readEntity(AuthException.class);
}

throw response.readEntity(HttpException.class);
}
}
32 changes: 23 additions & 9 deletions src/main/java/com/wire/bots/hold/Service.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.wire.bots.cryptobox.CryptoException;
import com.wire.bots.hold.DAO.AccessDAO;
import com.wire.bots.hold.DAO.EventsDAO;
import com.wire.bots.hold.DAO.MetadataDAO;
import com.wire.bots.hold.filters.ServiceAuthenticationFilter;
import com.wire.bots.hold.healthchecks.SanityCheck;
import com.wire.bots.hold.monitoring.RequestMdcFactoryFilter;
Expand Down Expand Up @@ -63,12 +64,13 @@

import javax.ws.rs.client.Client;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;

public class Service extends Application<Config> {
public static Service instance;
public static MetricRegistry metrics;
public static String API_DOMAIN;
protected Config config;
protected Environment environment;
protected Jdbi jdbi;
Expand Down Expand Up @@ -101,7 +103,7 @@ protected SwaggerBundleConfiguration getSwaggerBundleConfiguration(Config config
}

@Override
public void run(Config config, Environment environment) {
public void run(Config config, Environment environment) throws ExecutionException, InterruptedException {
this.config = config;
this.environment = environment;
Service.metrics = environment.metrics();
Expand All @@ -118,6 +120,7 @@ public void run(Config config, Environment environment) {

final AccessDAO accessDAO = jdbi.onDemand(AccessDAO.class);
final EventsDAO eventsDAO = jdbi.onDemand(EventsDAO.class);
final MetadataDAO metadataDAO = jdbi.onDemand(MetadataDAO.class);

final DeviceManagementService deviceManagementService = new DeviceManagementService(accessDAO, cf);

Expand All @@ -142,13 +145,28 @@ public void run(Config config, Environment environment) {

addResource(ServiceAuthenticationFilter.ServiceAuthenticationFeature.class);

environment.healthChecks().register("SanityCheck", new SanityCheck(accessDAO, httpClient));
final Future<?> fallbackDomainFetcher = environment
.lifecycle()
.executorService("fallback_domain_fetcher")
.build()
.submit(
new FallbackDomainFetcher(
new LoginClient(httpClient),
metadataDAO
)
);

fallbackDomainFetcher.get();

environment.healthChecks().register(
"SanityCheck",
new SanityCheck(accessDAO, httpClient)
);

final HoldClientRepo repo = new HoldClientRepo(jdbi, cf, httpClient);
final LoginClient loginClient = new LoginClient(httpClient);

final HoldMessageResource holdMessageResource = new HoldMessageResource(new MessageHandler(jdbi), repo);
final NotificationProcessor notificationProcessor = new NotificationProcessor(httpClient, accessDAO, config, holdMessageResource);
final NotificationProcessor notificationProcessor = new NotificationProcessor(httpClient, accessDAO, holdMessageResource);

environment.lifecycle()
.scheduledExecutorService("notifications")
Expand All @@ -158,10 +176,6 @@ public void run(Config config, Environment environment) {
CollectorRegistry.defaultRegistry.register(new DropwizardExports(metrics));

environment.getApplicationContext().addServlet(MetricsServlet.class, "/metrics");

// todo here
// String res = loginClient.getBackendConfiguration();
// handle res ^
}

public Config getConfig() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.codahale.metrics.health.HealthCheck;
import com.wire.bots.hold.DAO.AccessDAO;
import com.wire.bots.hold.model.database.LHAccess;
import com.wire.bots.hold.utils.Cache;
import com.wire.helium.API;
import com.wire.xenon.backend.models.QualifiedId;
import com.wire.xenon.tools.Logger;
Expand Down Expand Up @@ -36,8 +37,10 @@ protected Result check() {
while (!accessList.isEmpty()) {
Logger.info("SanityCheck: checking %d devices, created: %s", accessList.size(), created);
for (LHAccess access : accessList) {
// TODO(WPB-11287): Use user domain if exists, otherwise default
// TODO: String domain = (access.domain != null) ? access.domain : Cache.DEFAULT_DOMAIN;
boolean hasDevice = api.hasDevice(
new QualifiedId(access.userId, null), // TODO(WPB-11287): Change null to default domain
new QualifiedId(access.userId, Cache.getFallbackDomain()), // TODO(WPB-11287): Change null to default domain
access.clientId
);

Expand Down
11 changes: 11 additions & 0 deletions src/main/java/com/wire/bots/hold/model/Metadata.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package com.wire.bots.hold.model;

public class Metadata {
public String key;
public String value;

public Metadata(String key, String value) {
this.key = key;
this.value = value;
}
}
Loading

0 comments on commit 968dc21

Please sign in to comment.