diff --git a/pom.xml b/pom.xml index b5e93dc..33772e3 100644 --- a/pom.xml +++ b/pom.xml @@ -135,6 +135,12 @@ simpleclient_servlet ${prometheus.version} + + org.mockito + mockito-core + 5.14.1 + test + diff --git a/src/main/java/com/wire/bots/hold/DAO/MetadataDAO.java b/src/main/java/com/wire/bots/hold/DAO/MetadataDAO.java new file mode 100644 index 0000000..933d83e --- /dev/null +++ b/src/main/java/com/wire/bots/hold/DAO/MetadataDAO.java @@ -0,0 +1,20 @@ +package com.wire.bots.hold.DAO; + +import com.wire.bots.hold.model.Metadata; +import org.jdbi.v3.sqlobject.config.RegisterColumnMapper; +import org.jdbi.v3.sqlobject.customizer.Bind; +import org.jdbi.v3.sqlobject.statement.SqlQuery; +import org.jdbi.v3.sqlobject.statement.SqlUpdate; + +public interface MetadataDAO { + String FALLBACK_DOMAIN_KEY = "FALLBACK_DOMAIN_KEY"; + + @SqlUpdate("INSERT INTO Metadata (key, value)" + + "VALUES (:key, :value)" + + "ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value") + int insert(@Bind("key") String key, @Bind("value") String value); + + @SqlQuery("SELECT key, value FROM Metadata WHERE key = :key LIMIT 1") + @RegisterColumnMapper(MetadataResultSetMapper.class) + Metadata get(@Bind("key") String key); +} diff --git a/src/main/java/com/wire/bots/hold/DAO/MetadataResultSetMapper.java b/src/main/java/com/wire/bots/hold/DAO/MetadataResultSetMapper.java new file mode 100644 index 0000000..526985e --- /dev/null +++ b/src/main/java/com/wire/bots/hold/DAO/MetadataResultSetMapper.java @@ -0,0 +1,19 @@ +package com.wire.bots.hold.DAO; + +import com.wire.bots.hold.model.Metadata; +import org.jdbi.v3.core.mapper.ColumnMapper; +import org.jdbi.v3.core.statement.StatementContext; + +import java.sql.ResultSet; +import java.sql.SQLException; + +public class MetadataResultSetMapper implements ColumnMapper { + + @Override + public Metadata map(ResultSet rs, int columnNumber, StatementContext ctx) throws SQLException { + return new Metadata( + rs.getString("key"), + rs.getString("value") + ); + } +} diff --git a/src/main/java/com/wire/bots/hold/FallbackDomainFetcher.java b/src/main/java/com/wire/bots/hold/FallbackDomainFetcher.java new file mode 100644 index 0000000..7c7f0fc --- /dev/null +++ b/src/main/java/com/wire/bots/hold/FallbackDomainFetcher.java @@ -0,0 +1,69 @@ +package com.wire.bots.hold; + +import com.wire.bots.hold.DAO.MetadataDAO; +import com.wire.bots.hold.model.Metadata; +import com.wire.bots.hold.utils.Cache; +import com.wire.helium.LoginClient; +import com.wire.helium.models.BackendConfiguration; +import com.wire.xenon.exceptions.HttpException; +import com.wire.xenon.tools.Logger; + +import javax.ws.rs.ProcessingException; + +public class FallbackDomainFetcher implements Runnable { + + private final LoginClient loginClient; + private final MetadataDAO metadataDAO; + + /** + * Fetcher and handler for fallback domain. + *

+ * Fetches from API and compares against database value (if any), then inserts into database and updates cache value. + * If value received from the API is different from what is saved in the database, a [RuntimeException] is thrown. + *

+ *

+ * This fallback domain is necessary for LegalHold to work with Federation (as it needs id@domain) and not just the ID anymore. + * In case there is a mismatch we are throwing a RuntimeException so it stops the execution of this app, so in an event + * of already having a defined default domain saved in the database and this app restarts with a different domain + * we don't get mismatching domains. + *

+ * @param loginClient [{@link LoginClient}] as API to get backend configuration containing default domain. + * @param metadataDAO [{@link MetadataDAO}] as DAO to get/insert default domain to database. + * + * @throws RuntimeException if received domain from API is different from the one saved in the database. + */ + FallbackDomainFetcher(LoginClient loginClient, MetadataDAO metadataDAO) { + this.loginClient = loginClient; + this.metadataDAO = metadataDAO; + } + + @Override + public void run() { + if (Cache.getFallbackDomain() != null) { return; } + + Metadata metadata = metadataDAO.get(MetadataDAO.FALLBACK_DOMAIN_KEY); + try { + BackendConfiguration apiVersionResponse = loginClient.getBackendConfiguration(); + + if (metadata == null) { + metadataDAO.insert(MetadataDAO.FALLBACK_DOMAIN_KEY, apiVersionResponse.domain); + Cache.setFallbackDomain(apiVersionResponse.domain); + } else { + if (metadata.value.equals(apiVersionResponse.domain)) { + Cache.setFallbackDomain(apiVersionResponse.domain); + } else { + String formattedExceptionMessage = String.format( + "Database already has a default domain as %s and instead we got %s from the Backend API.", + metadata.value, + apiVersionResponse.domain + ); + throw new RuntimeException(formattedExceptionMessage); + } + } + } catch (HttpException exception) { + Logger.exception(exception, "FallbackDomainFetcher.run, exception: %s", exception.getMessage()); + } catch (ProcessingException pexception) { + Logger.info("FallbackDomainFetcher.run, ignoring test exceptions"); + } + } +} diff --git a/src/main/java/com/wire/bots/hold/NotificationProcessor.java b/src/main/java/com/wire/bots/hold/NotificationProcessor.java index 6a8bc86..b28534f 100644 --- a/src/main/java/com/wire/bots/hold/NotificationProcessor.java +++ b/src/main/java/com/wire/bots/hold/NotificationProcessor.java @@ -3,38 +3,34 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import com.wire.bots.hold.DAO.AccessDAO; -import com.wire.bots.hold.model.Notification; -import com.wire.bots.hold.model.NotificationList; import com.wire.bots.hold.model.database.LHAccess; +import com.wire.helium.API; import com.wire.helium.LoginClient; import com.wire.helium.models.Access; +import com.wire.helium.models.Event; +import com.wire.helium.models.NotificationList; import com.wire.xenon.backend.models.Payload; import com.wire.xenon.exceptions.AuthException; import com.wire.xenon.exceptions.HttpException; import com.wire.xenon.tools.Logger; import javax.ws.rs.client.Client; -import javax.ws.rs.client.WebTarget; import javax.ws.rs.core.Cookie; -import javax.ws.rs.core.HttpHeaders; -import javax.ws.rs.core.MediaType; -import javax.ws.rs.core.Response; import java.util.List; import java.util.UUID; import java.util.logging.Level; public class NotificationProcessor implements Runnable { + private static final int DEFAULT_NOTIFICATION_SIZE = 100; + private final Client client; private final AccessDAO accessDAO; private final HoldMessageResource messageResource; - private final WebTarget api; - NotificationProcessor(Client client, AccessDAO accessDAO, Config config, HoldMessageResource messageResource) { + NotificationProcessor(Client client, AccessDAO accessDAO, HoldMessageResource messageResource) { this.client = client; this.accessDAO = accessDAO; this.messageResource = messageResource; - - api = client.target(config.apiHost); } @Override @@ -58,7 +54,10 @@ private Access getAccess(Cookie cookie) throws HttpException { private void process(LHAccess device) { UUID userId = device.userId; + try { + final API api = new API(client, null, device.token); + Logger.debug("`GET /notifications`: user: %s, last: %s", userId, device.last); String cookieValue = device.cookie; @@ -76,39 +75,38 @@ private void process(LHAccess device) { device.token = access.getAccessToken(); - NotificationList notificationList = retrieveNotifications(device); + NotificationList notificationList = api.retrieveNotifications( + device.clientId, + device.last, + DEFAULT_NOTIFICATION_SIZE + ); process(userId, notificationList); - } catch (AuthException e) { accessDAO.disable(userId); Logger.info("Disabled LH device for user: %s, error: %s", userId, e.getMessage()); + } catch (HttpException e) { + Logger.exception(e, "NotificationProcessor: Couldn't retrieve notifications, error: %s", e.getMessage()); } catch (Exception e) { Logger.exception(e, "NotificationProcessor: user: %s, last: %s, error: %s", userId, device.last, e.getMessage()); } } - private static String bearer(String token) { - return token == null ? null : String.format("Bearer %s", token); - } - private void process(UUID userId, NotificationList notificationList) { - - for (Notification notif : notificationList.notifications) { - for (Payload payload : notif.payload) { - if (!process(userId, payload, notif.id)) { - Logger.error("Failed to process: user: %s, notif: %s", userId, notif.id); - //return; + for (Event event : notificationList.notifications) { + for (Payload payload : event.payload) { + if (!process(userId, payload, event.id)) { + Logger.error("Failed to process: user: %s, event: %s", userId, event.id); } else { - Logger.debug("Processed: `%s` conv: %s, user: %s, notifId: %s", + Logger.debug("Processed: `%s` conv: %s, user: %s, eventId: %s", payload.type, payload.conversation, userId, - notif.id); + event.id); } } - accessDAO.updateLast(userId, notif.id); + accessDAO.updateLast(userId, event.id); } } @@ -122,13 +120,13 @@ private boolean process(UUID userId, Payload payload, UUID id) { if (payload.from == null || payload.data == null) return true; - final boolean b = messageResource.onNewMessage(userId, id, payload); + final boolean wasMessageSent = messageResource.onNewMessage(userId, id, payload); - if (!b) { + if (!wasMessageSent) { Logger.error("process: `%s` user: %s, from: %s:%s, error: %s", payload.type, userId, payload.from, payload.data.sender); } - return b; + return wasMessageSent; } private void trace(Payload payload) { @@ -141,30 +139,4 @@ private void trace(Payload payload) { } } } - - //TODO remove this and use retrieveNotifications provided by Helium - private NotificationList retrieveNotifications(LHAccess access) throws HttpException { - Response response = api.path("notifications").queryParam("client", access.clientId).queryParam("since", access.last).queryParam("size", 100).request(MediaType.APPLICATION_JSON).accept(MediaType.APPLICATION_JSON).header(HttpHeaders.AUTHORIZATION, bearer(access.token)).get(); - - int status = response.getStatus(); - - if (status == 200) { - return response.readEntity(NotificationList.class); - } - - if (status == 404) { //todo what??? - return response.readEntity(NotificationList.class); - } - - if (status == 401) { //todo nginx returns text/html for 401. Cannot deserialize as json - response.readEntity(String.class); - throw new AuthException(status); - } - - if (status == 403) { - throw response.readEntity(AuthException.class); - } - - throw response.readEntity(HttpException.class); - } } diff --git a/src/main/java/com/wire/bots/hold/Service.java b/src/main/java/com/wire/bots/hold/Service.java index 67d3b3a..aa470cd 100644 --- a/src/main/java/com/wire/bots/hold/Service.java +++ b/src/main/java/com/wire/bots/hold/Service.java @@ -22,6 +22,7 @@ import com.wire.bots.cryptobox.CryptoException; import com.wire.bots.hold.DAO.AccessDAO; import com.wire.bots.hold.DAO.EventsDAO; +import com.wire.bots.hold.DAO.MetadataDAO; import com.wire.bots.hold.filters.ServiceAuthenticationFilter; import com.wire.bots.hold.healthchecks.SanityCheck; import com.wire.bots.hold.monitoring.RequestMdcFactoryFilter; @@ -63,12 +64,13 @@ import javax.ws.rs.client.Client; import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; public class Service extends Application { public static Service instance; public static MetricRegistry metrics; - public static String API_DOMAIN; protected Config config; protected Environment environment; protected Jdbi jdbi; @@ -101,7 +103,7 @@ protected SwaggerBundleConfiguration getSwaggerBundleConfiguration(Config config } @Override - public void run(Config config, Environment environment) { + public void run(Config config, Environment environment) throws ExecutionException, InterruptedException { this.config = config; this.environment = environment; Service.metrics = environment.metrics(); @@ -118,6 +120,7 @@ public void run(Config config, Environment environment) { final AccessDAO accessDAO = jdbi.onDemand(AccessDAO.class); final EventsDAO eventsDAO = jdbi.onDemand(EventsDAO.class); + final MetadataDAO metadataDAO = jdbi.onDemand(MetadataDAO.class); final DeviceManagementService deviceManagementService = new DeviceManagementService(accessDAO, cf); @@ -142,13 +145,28 @@ public void run(Config config, Environment environment) { addResource(ServiceAuthenticationFilter.ServiceAuthenticationFeature.class); - environment.healthChecks().register("SanityCheck", new SanityCheck(accessDAO, httpClient)); + final Future fallbackDomainFetcher = environment + .lifecycle() + .executorService("fallback_domain_fetcher") + .build() + .submit( + new FallbackDomainFetcher( + new LoginClient(httpClient), + metadataDAO + ) + ); + + fallbackDomainFetcher.get(); + + environment.healthChecks().register( + "SanityCheck", + new SanityCheck(accessDAO, httpClient) + ); final HoldClientRepo repo = new HoldClientRepo(jdbi, cf, httpClient); - final LoginClient loginClient = new LoginClient(httpClient); final HoldMessageResource holdMessageResource = new HoldMessageResource(new MessageHandler(jdbi), repo); - final NotificationProcessor notificationProcessor = new NotificationProcessor(httpClient, accessDAO, config, holdMessageResource); + final NotificationProcessor notificationProcessor = new NotificationProcessor(httpClient, accessDAO, holdMessageResource); environment.lifecycle() .scheduledExecutorService("notifications") @@ -158,10 +176,6 @@ public void run(Config config, Environment environment) { CollectorRegistry.defaultRegistry.register(new DropwizardExports(metrics)); environment.getApplicationContext().addServlet(MetricsServlet.class, "/metrics"); - - // todo here - // String res = loginClient.getBackendConfiguration(); - // handle res ^ } public Config getConfig() { diff --git a/src/main/java/com/wire/bots/hold/healthchecks/SanityCheck.java b/src/main/java/com/wire/bots/hold/healthchecks/SanityCheck.java index 0408aed..d6e6047 100644 --- a/src/main/java/com/wire/bots/hold/healthchecks/SanityCheck.java +++ b/src/main/java/com/wire/bots/hold/healthchecks/SanityCheck.java @@ -3,6 +3,7 @@ import com.codahale.metrics.health.HealthCheck; import com.wire.bots.hold.DAO.AccessDAO; import com.wire.bots.hold.model.database.LHAccess; +import com.wire.bots.hold.utils.Cache; import com.wire.helium.API; import com.wire.xenon.backend.models.QualifiedId; import com.wire.xenon.tools.Logger; @@ -36,8 +37,10 @@ protected Result check() { while (!accessList.isEmpty()) { Logger.info("SanityCheck: checking %d devices, created: %s", accessList.size(), created); for (LHAccess access : accessList) { + // TODO(WPB-11287): Use user domain if exists, otherwise default + // TODO: String domain = (access.domain != null) ? access.domain : Cache.DEFAULT_DOMAIN; boolean hasDevice = api.hasDevice( - new QualifiedId(access.userId, null), // TODO(WPB-11287): Change null to default domain + new QualifiedId(access.userId, Cache.getFallbackDomain()), // TODO(WPB-11287): Change null to default domain access.clientId ); diff --git a/src/main/java/com/wire/bots/hold/model/Metadata.java b/src/main/java/com/wire/bots/hold/model/Metadata.java new file mode 100644 index 0000000..f145827 --- /dev/null +++ b/src/main/java/com/wire/bots/hold/model/Metadata.java @@ -0,0 +1,11 @@ +package com.wire.bots.hold.model; + +public class Metadata { + public String key; + public String value; + + public Metadata(String key, String value) { + this.key = key; + this.value = value; + } +} diff --git a/src/main/java/com/wire/bots/hold/model/Notification.java b/src/main/java/com/wire/bots/hold/model/Notification.java deleted file mode 100644 index 5bfba19..0000000 --- a/src/main/java/com/wire/bots/hold/model/Notification.java +++ /dev/null @@ -1,20 +0,0 @@ -package com.wire.bots.hold.model; - -import com.fasterxml.jackson.annotation.JsonIgnoreProperties; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.wire.xenon.backend.models.Payload; - -import javax.validation.constraints.NotNull; -import java.util.List; -import java.util.UUID; - -@JsonIgnoreProperties(ignoreUnknown = true) -public class Notification { - @JsonProperty - @NotNull - public List payload; - - @JsonProperty - @NotNull - public UUID id; -} diff --git a/src/main/java/com/wire/bots/hold/model/NotificationList.java b/src/main/java/com/wire/bots/hold/model/NotificationList.java deleted file mode 100644 index 76a31f8..0000000 --- a/src/main/java/com/wire/bots/hold/model/NotificationList.java +++ /dev/null @@ -1,18 +0,0 @@ -package com.wire.bots.hold.model; - -import com.fasterxml.jackson.annotation.JsonIgnoreProperties; -import com.fasterxml.jackson.annotation.JsonProperty; - -import javax.validation.constraints.NotNull; -import java.util.ArrayList; - -@JsonIgnoreProperties(ignoreUnknown = true) -public class NotificationList { - @JsonProperty("has_more") - @NotNull - public Boolean hasMore; - - @JsonProperty - @NotNull - public ArrayList notifications; -} diff --git a/src/main/java/com/wire/bots/hold/model/api/shared/ApiVersionResponse.java b/src/main/java/com/wire/bots/hold/model/api/shared/ApiVersionResponse.java index 68328ee..5ede5ab 100644 --- a/src/main/java/com/wire/bots/hold/model/api/shared/ApiVersionResponse.java +++ b/src/main/java/com/wire/bots/hold/model/api/shared/ApiVersionResponse.java @@ -11,4 +11,4 @@ public class ApiVersionResponse { public ApiVersionResponse(List supported) { this.supported = supported; } -} \ No newline at end of file +} diff --git a/src/main/java/com/wire/bots/hold/model/dto/InitializedDeviceDTO.java b/src/main/java/com/wire/bots/hold/model/dto/InitializedDeviceDTO.java index 567062d..4266ec4 100644 --- a/src/main/java/com/wire/bots/hold/model/dto/InitializedDeviceDTO.java +++ b/src/main/java/com/wire/bots/hold/model/dto/InitializedDeviceDTO.java @@ -26,4 +26,4 @@ public PreKey getLastPreKey() { public String getFingerprint() { return fingerprint; } -} \ No newline at end of file +} diff --git a/src/main/java/com/wire/bots/hold/monitoring/ApiVersionResource.java b/src/main/java/com/wire/bots/hold/monitoring/ApiVersionResource.java index 5a0c771..16f0dc9 100644 --- a/src/main/java/com/wire/bots/hold/monitoring/ApiVersionResource.java +++ b/src/main/java/com/wire/bots/hold/monitoring/ApiVersionResource.java @@ -39,7 +39,7 @@ public Response apiVersion() { ApiVersionResponse response = new ApiVersionResponse(List.of(0,1)); return Response - .ok(response) - .build(); + .ok(response) + .build(); } } diff --git a/src/main/java/com/wire/bots/hold/monitoring/StatusResource.java b/src/main/java/com/wire/bots/hold/monitoring/StatusResource.java index 724fee7..c44e87d 100644 --- a/src/main/java/com/wire/bots/hold/monitoring/StatusResource.java +++ b/src/main/java/com/wire/bots/hold/monitoring/StatusResource.java @@ -35,7 +35,7 @@ public class StatusResource { @ApiOperation(value = "Status") public Response statusEmpty() { return Response - .ok() - .build(); + .ok() + .build(); } } diff --git a/src/main/java/com/wire/bots/hold/resource/v0/audit/ConversationResource.java b/src/main/java/com/wire/bots/hold/resource/v0/audit/ConversationResource.java index a95e697..7637839 100644 --- a/src/main/java/com/wire/bots/hold/resource/v0/audit/ConversationResource.java +++ b/src/main/java/com/wire/bots/hold/resource/v0/audit/ConversationResource.java @@ -67,6 +67,8 @@ public Response list(@ApiParam @PathParam("conversationId") UUID conversationId, try { List events = eventsDAO.listAllAsc(conversationId); + // TODO(WPB-11287) Verify default domain + testAPI(); Cache cache = new Cache(api, assetsDAO); diff --git a/src/main/java/com/wire/bots/hold/resource/v0/backend/ConfirmResourceV0.java b/src/main/java/com/wire/bots/hold/resource/v0/backend/ConfirmResourceV0.java index 6d82c84..0f91e83 100644 --- a/src/main/java/com/wire/bots/hold/resource/v0/backend/ConfirmResourceV0.java +++ b/src/main/java/com/wire/bots/hold/resource/v0/backend/ConfirmResourceV0.java @@ -29,9 +29,9 @@ public ConfirmResourceV0(DeviceManagementService deviceManagementService) { @ServiceAuthorization @ApiOperation(value = "Confirm legal hold device") @ApiResponses(value = { - @ApiResponse(code = 400, message = "Bad request. Invalid Payload"), - @ApiResponse(code = 500, message = "Something went wrong"), - @ApiResponse(code = 200, message = "Legal Hold Device enabled")}) + @ApiResponse(code = 400, message = "Bad request. Invalid Payload"), + @ApiResponse(code = 500, message = "Something went wrong"), + @ApiResponse(code = 200, message = "Legal Hold Device enabled")}) public Response confirm(@ApiParam @Valid @NotNull ConfirmPayloadV0 payload) { try { deviceManagementService.confirmDevice( @@ -52,6 +52,4 @@ public Response confirm(@ApiParam @Valid @NotNull ConfirmPayloadV0 payload) { .build(); } } - - } diff --git a/src/main/java/com/wire/bots/hold/resource/v0/backend/InitiateResourceV0.java b/src/main/java/com/wire/bots/hold/resource/v0/backend/InitiateResourceV0.java index 6f444f4..38dfcb9 100644 --- a/src/main/java/com/wire/bots/hold/resource/v0/backend/InitiateResourceV0.java +++ b/src/main/java/com/wire/bots/hold/resource/v0/backend/InitiateResourceV0.java @@ -31,9 +31,9 @@ public InitiateResourceV0(DeviceManagementService deviceManagementService) { @ServiceAuthorization @ApiOperation(value = "Initiate", response = InitResponse.class) @ApiResponses(value = { - @ApiResponse(code = 400, message = "Bad request. Invalid Payload"), - @ApiResponse(code = 500, message = "Something went wrong"), - @ApiResponse(code = 200, message = "CryptoBox initiated")}) + @ApiResponse(code = 400, message = "Bad request. Invalid Payload"), + @ApiResponse(code = 500, message = "Something went wrong"), + @ApiResponse(code = 200, message = "CryptoBox initiated")}) public Response initiate(@ApiParam @Valid @NotNull InitPayloadV0 init) { try { final InitializedDeviceDTO initializedDeviceDTO = diff --git a/src/main/java/com/wire/bots/hold/resource/v0/backend/RemoveResourceV0.java b/src/main/java/com/wire/bots/hold/resource/v0/backend/RemoveResourceV0.java index 805c6b6..7068233 100644 --- a/src/main/java/com/wire/bots/hold/resource/v0/backend/RemoveResourceV0.java +++ b/src/main/java/com/wire/bots/hold/resource/v0/backend/RemoveResourceV0.java @@ -28,9 +28,9 @@ public RemoveResourceV0(DeviceManagementService deviceManagementService) { @ServiceAuthorization @ApiOperation(value = "Remove legal hold device") @ApiResponses(value = { - @ApiResponse(code = 400, message = "Bad request. Invalid Payload"), - @ApiResponse(code = 500, message = "Something went wrong"), - @ApiResponse(code = 200, message = "Legal Hold Device was removed")}) + @ApiResponse(code = 400, message = "Bad request. Invalid Payload"), + @ApiResponse(code = 500, message = "Something went wrong"), + @ApiResponse(code = 200, message = "Legal Hold Device was removed")}) public Response remove(@ApiParam @Valid InitPayloadV0 payload) { try { deviceManagementService.removeDevice( diff --git a/src/main/java/com/wire/bots/hold/resource/v1/backend/ConfirmResourceV1.java b/src/main/java/com/wire/bots/hold/resource/v1/backend/ConfirmResourceV1.java index 7395785..a32883d 100644 --- a/src/main/java/com/wire/bots/hold/resource/v1/backend/ConfirmResourceV1.java +++ b/src/main/java/com/wire/bots/hold/resource/v1/backend/ConfirmResourceV1.java @@ -28,9 +28,9 @@ public ConfirmResourceV1(DeviceManagementService deviceManagementService) { @ServiceAuthorization @ApiOperation(value = "Confirm legal hold device") @ApiResponses(value = { - @ApiResponse(code = 400, message = "Bad request. Invalid Payload"), - @ApiResponse(code = 500, message = "Something went wrong"), - @ApiResponse(code = 200, message = "Legal Hold Device enabled")}) + @ApiResponse(code = 400, message = "Bad request. Invalid Payload"), + @ApiResponse(code = 500, message = "Something went wrong"), + @ApiResponse(code = 200, message = "Legal Hold Device enabled")}) public Response confirm(@ApiParam @Valid @NotNull ConfirmPayloadV1 payload) { try { deviceManagementService.confirmDevice( diff --git a/src/main/java/com/wire/bots/hold/resource/v1/backend/RemoveResourceV1.java b/src/main/java/com/wire/bots/hold/resource/v1/backend/RemoveResourceV1.java index e70bd34..d38ebb9 100644 --- a/src/main/java/com/wire/bots/hold/resource/v1/backend/RemoveResourceV1.java +++ b/src/main/java/com/wire/bots/hold/resource/v1/backend/RemoveResourceV1.java @@ -27,9 +27,9 @@ public RemoveResourceV1(DeviceManagementService deviceManagementService) { @ServiceAuthorization @ApiOperation(value = "Remove legal hold device") @ApiResponses(value = { - @ApiResponse(code = 400, message = "Bad request. Invalid Payload"), - @ApiResponse(code = 500, message = "Something went wrong"), - @ApiResponse(code = 200, message = "Legal Hold Device was removed")}) + @ApiResponse(code = 400, message = "Bad request. Invalid Payload"), + @ApiResponse(code = 500, message = "Something went wrong"), + @ApiResponse(code = 200, message = "Legal Hold Device was removed")}) public Response remove(@ApiParam @Valid InitPayloadV1 payload) { try { deviceManagementService.removeDevice(payload.userId, payload.teamId); @@ -44,6 +44,4 @@ public Response remove(@ApiParam @Valid InitPayloadV1 payload) { .build(); } } - - } diff --git a/src/main/java/com/wire/bots/hold/utils/Cache.java b/src/main/java/com/wire/bots/hold/utils/Cache.java index e0e55c2..95b4ee4 100644 --- a/src/main/java/com/wire/bots/hold/utils/Cache.java +++ b/src/main/java/com/wire/bots/hold/utils/Cache.java @@ -13,10 +13,11 @@ import java.util.concurrent.ConcurrentHashMap; public class Cache { - // TODO(WPB-11287): Add default domain here - private static final ConcurrentHashMap assets = new ConcurrentHashMap<>();// - private static final ConcurrentHashMap users = new ConcurrentHashMap<>();// - private static final ConcurrentHashMap profiles = new ConcurrentHashMap<>();// + private static String FALLBACK_DOMAIN = null; + + private static final ConcurrentHashMap assets = new ConcurrentHashMap<>(); // + private static final ConcurrentHashMap users = new ConcurrentHashMap<>(); // + private static final ConcurrentHashMap profiles = new ConcurrentHashMap<>(); // private final API api; private final AssetsDAO assetsDAO; @@ -25,6 +26,14 @@ public Cache(API api, AssetsDAO assetsDAO) { this.assetsDAO = assetsDAO; } + public static void setFallbackDomain(String domain) { + FALLBACK_DOMAIN = domain; + } + + public static String getFallbackDomain() { + return FALLBACK_DOMAIN; + } + @Nullable public File getAssetFile(UUID messageId) { @@ -57,7 +66,6 @@ public File getProfileImage(User user) { } public User getUser(QualifiedId userId) { - // TODO(WPB-11287): Fetch first in map then check API return users.computeIfAbsent(userId, k -> { try { return api.getUser(userId); diff --git a/src/main/resources/db/migration/V107__add_metadata_table.sql b/src/main/resources/db/migration/V107__add_metadata_table.sql new file mode 100644 index 0000000..ef255d4 --- /dev/null +++ b/src/main/resources/db/migration/V107__add_metadata_table.sql @@ -0,0 +1,4 @@ +CREATE TABLE Metadata ( + key VARCHAR(255) PRIMARY KEY, + value VARCHAR(255) NOT NULL +); diff --git a/src/test/java/com/wire/bots/hold/ConfirmRemoveResourceV1Test.java b/src/test/java/com/wire/bots/hold/ConfirmRemoveResourceV1Test.java index 08ed44b..41474f1 100644 --- a/src/test/java/com/wire/bots/hold/ConfirmRemoveResourceV1Test.java +++ b/src/test/java/com/wire/bots/hold/ConfirmRemoveResourceV1Test.java @@ -24,7 +24,8 @@ public class ConfirmRemoveResourceV1Test { private static final String TOKEN = "dummy"; private static final DropwizardTestSupport SUPPORT = new DropwizardTestSupport<>( Service.class, "hold.yaml", - ConfigOverride.config("token", TOKEN)); + ConfigOverride.config("token", TOKEN), + ConfigOverride.config("apiHost", "dummy")); private static Client client; private static AccessDAO accessDAO; diff --git a/src/test/java/com/wire/bots/hold/DatabaseTest.java b/src/test/java/com/wire/bots/hold/DatabaseTest.java index 684239e..3bb7f88 100644 --- a/src/test/java/com/wire/bots/hold/DatabaseTest.java +++ b/src/test/java/com/wire/bots/hold/DatabaseTest.java @@ -5,6 +5,8 @@ import com.wire.bots.hold.DAO.AccessDAO; import com.wire.bots.hold.DAO.AssetsDAO; import com.wire.bots.hold.DAO.EventsDAO; +import com.wire.bots.hold.DAO.MetadataDAO; +import com.wire.bots.hold.model.Metadata; import com.wire.bots.hold.model.database.Event; import com.wire.bots.hold.model.database.LHAccess; import com.wire.xenon.backend.models.QualifiedId; @@ -19,13 +21,15 @@ public class DatabaseTest { private static final DropwizardTestSupport SUPPORT = new DropwizardTestSupport<>( - Service.class, "hold.yaml", - ConfigOverride.config("token", "dummy")); + Service.class, "hold.yaml", + ConfigOverride.config("token", "dummy"), + ConfigOverride.config("apiHost", "dummy")); private static final ObjectMapper mapper = new ObjectMapper(); private static AssetsDAO assetsDAO; private static EventsDAO eventsDAO; private static AccessDAO accessDAO; + private static MetadataDAO metadataDAO; @BeforeClass public static void init() throws Exception { @@ -35,6 +39,7 @@ public static void init() throws Exception { eventsDAO = app.getJdbi().onDemand(EventsDAO.class); assetsDAO = app.getJdbi().onDemand(AssetsDAO.class); accessDAO = app.getJdbi().onDemand(AccessDAO.class); + metadataDAO = app.getJdbi().onDemand(MetadataDAO.class); } @AfterClass @@ -114,4 +119,18 @@ public void accessTests() { assert lhAccess2 != null; assert lhAccess2.created.equals(lhAccess.created); } + + @Test + public void metadataTests() { + String dummyKey = MetadataDAO.FALLBACK_DOMAIN_KEY + UUID.randomUUID(); + + Metadata nullMetadata = metadataDAO.get(dummyKey); + assert nullMetadata == null; + + int insert = metadataDAO.insert(dummyKey, "dummy_domain"); + Metadata metadata = metadataDAO.get(dummyKey); + + assert metadata != null; + assert metadata.value.equals("dummy_domain"); + } } diff --git a/src/test/java/com/wire/bots/hold/FallbackDomainFetcherTest.java b/src/test/java/com/wire/bots/hold/FallbackDomainFetcherTest.java new file mode 100644 index 0000000..d35a8b2 --- /dev/null +++ b/src/test/java/com/wire/bots/hold/FallbackDomainFetcherTest.java @@ -0,0 +1,99 @@ +package com.wire.bots.hold; + +import com.wire.bots.hold.DAO.MetadataDAO; +import com.wire.bots.hold.model.Metadata; +import com.wire.bots.hold.utils.Cache; +import com.wire.helium.LoginClient; +import com.wire.helium.models.BackendConfiguration; +import com.wire.xenon.exceptions.HttpException; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +public class FallbackDomainFetcherTest { + + private MetadataDAO metadataDAO; + private LoginClient loginClient; + private FallbackDomainFetcher fallbackDomainFetcher; + + @Before + public void before() { + // Clears cached domain + Cache.setFallbackDomain(null); + metadataDAO = mock(MetadataDAO.class); + loginClient = mock(LoginClient.class); + fallbackDomainFetcher = new FallbackDomainFetcher(loginClient, metadataDAO); + } + + @After + public void after() { + // Clears cached domain + Cache.setFallbackDomain(null); + } + + @Test + public void givenNoFallbackDomainInCacheAndDatabase_whenExecuting_thenFetchFromApiAndStoreInCacheAndDatabase() throws HttpException { + // given + BackendConfiguration backendConfiguration = new BackendConfiguration(); + backendConfiguration.domain = "dummy_domain_3"; + + when(metadataDAO.get(MetadataDAO.FALLBACK_DOMAIN_KEY)).thenReturn(null); + when(metadataDAO.insert(any(), any())).thenReturn(1); + when(loginClient.getBackendConfiguration()).thenReturn(backendConfiguration); + + // when + fallbackDomainFetcher.run(); + + // then + assert Cache.getFallbackDomain().equals("dummy_domain_3"); + verify(metadataDAO, times(1)).insert(MetadataDAO.FALLBACK_DOMAIN_KEY, backendConfiguration.domain); + } + + @Test + public void givenNoFallbackDomainInCache_whenExecuting_thenFetchFromAPIAndCompareWithDatabase() throws HttpException { + // given + BackendConfiguration backendConfiguration = new BackendConfiguration(); + backendConfiguration.domain = "dummy_domain_2"; + + Metadata metadata = new Metadata(MetadataDAO.FALLBACK_DOMAIN_KEY, "dummy_domain_2"); + + when(metadataDAO.get(MetadataDAO.FALLBACK_DOMAIN_KEY)).thenReturn(metadata); + when(loginClient.getBackendConfiguration()).thenReturn(backendConfiguration); + + // when + fallbackDomainFetcher.run(); + + // then + assert Cache.getFallbackDomain().equals("dummy_domain_2"); + } + + @Test(expected=RuntimeException.class) + public void givenNoFallbackDomainInCache_whenExecutingAndApiReturnsDifferentDomainFromDatabase_thenThrowRuntimeException() throws HttpException, RuntimeException { + // given + BackendConfiguration backendConfiguration = new BackendConfiguration(); + backendConfiguration.domain = "dummy_domain_1"; + + Metadata metadata = new Metadata(MetadataDAO.FALLBACK_DOMAIN_KEY, "dummy_domain_2"); + + when(metadataDAO.get(MetadataDAO.FALLBACK_DOMAIN_KEY)).thenReturn(metadata); + when(loginClient.getBackendConfiguration()).thenReturn(backendConfiguration); + + // when + fallbackDomainFetcher.run(); + } + + @Test + public void givenFallbackDomainInCache_whenExecuting_thenReturnAndIgnoreDatabaseAndApiCalls() { + // given + Cache.setFallbackDomain("dummy_domain_0"); + + // when + fallbackDomainFetcher.run(); + + // then + assert Cache.getFallbackDomain().equals("dummy_domain_0"); + } +} diff --git a/src/test/java/com/wire/bots/hold/InitiateResourceV1Test.java b/src/test/java/com/wire/bots/hold/InitiateResourceV1Test.java index 1dbe8ad..33d754c 100644 --- a/src/test/java/com/wire/bots/hold/InitiateResourceV1Test.java +++ b/src/test/java/com/wire/bots/hold/InitiateResourceV1Test.java @@ -22,7 +22,8 @@ public class InitiateResourceV1Test { private static final String TOKEN = "dummy"; private static final DropwizardTestSupport SUPPORT = new DropwizardTestSupport<>( Service.class, "hold.yaml", - ConfigOverride.config("token", TOKEN)); + ConfigOverride.config("token", TOKEN), + ConfigOverride.config("apiHost", "dummy")); private static Client client; @BeforeClass diff --git a/src/test/java/com/wire/bots/hold/utils/CacheTest.java b/src/test/java/com/wire/bots/hold/utils/CacheTest.java new file mode 100644 index 0000000..ddbf2dd --- /dev/null +++ b/src/test/java/com/wire/bots/hold/utils/CacheTest.java @@ -0,0 +1,38 @@ +package com.wire.bots.hold.utils; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +public class CacheTest { + + @Before + public void before() { + // Clears cached domain + Cache.setFallbackDomain(null); + } + + @After + public void after() { + // Clears cached domain + Cache.setFallbackDomain(null); + } + + @Test + public void verifyDefaultDomainIsSetCorrectly() { + String firstDomain = Cache.getFallbackDomain(); + assert firstDomain == null; + + Cache.setFallbackDomain("dummy_domain"); + String secondDomain = Cache.getFallbackDomain(); + + assert secondDomain != null; + assert secondDomain.equals("dummy_domain"); + + Cache.setFallbackDomain("dummy_domain_3"); + String thirdDomain = Cache.getFallbackDomain(); + + assert thirdDomain != null; + assert thirdDomain.equals("dummy_domain_3"); + } +}