diff --git a/src/main/java/com/wire/bots/hold/DAO/AccessDAO.java b/src/main/java/com/wire/bots/hold/DAO/AccessDAO.java index e772614..3b86863 100644 --- a/src/main/java/com/wire/bots/hold/DAO/AccessDAO.java +++ b/src/main/java/com/wire/bots/hold/DAO/AccessDAO.java @@ -10,33 +10,42 @@ import java.util.UUID; public interface AccessDAO { - @SqlUpdate("INSERT INTO Access (userId, clientId, cookie, updated, created, enabled) " + - "VALUES (:userId, :clientId, :cookie, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 1) " + - "ON CONFLICT (userId) DO UPDATE SET cookie = EXCLUDED.cookie, clientId = EXCLUDED.clientId, " + + @SqlUpdate("INSERT INTO Access (userId, userDomain, clientId, cookie, updated, created, enabled) " + + "VALUES (:userId, :userDomain, :clientId, :cookie, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 1) " + + "ON CONFLICT (userId, userDomain) DO UPDATE SET cookie = EXCLUDED.cookie, clientId = EXCLUDED.clientId, " + "updated = EXCLUDED.updated, enabled = EXCLUDED.enabled") int insert(@Bind("userId") UUID userId, + @Bind("userDomain") String userDomain, @Bind("clientId") String clientId, @Bind("cookie") String cookie); - @SqlUpdate("UPDATE Access SET enabled = 0, updated = CURRENT_TIMESTAMP WHERE userId = :userId") - int disable(@Bind("userId") UUID userId); + @SqlUpdate("UPDATE Access SET enabled = 0, updated = CURRENT_TIMESTAMP WHERE userId = :userId " + + "AND (( :userDomain IS NULL AND userDomain IS null ) or ( :userDomain IS NOT NULL AND userDomain = :userDomain ))") + int disable(@Bind("userId") UUID userId, + @Bind("userDomain") String userDomain); - @SqlUpdate("UPDATE Access SET token = :token, cookie = :cookie, updated = CURRENT_TIMESTAMP WHERE userId = :userId") + @SqlUpdate("UPDATE Access SET token = :token, cookie = :cookie, updated = CURRENT_TIMESTAMP WHERE userId = :userId " + + "AND (( :userDomain IS NULL AND userDomain IS null ) OR ( :userDomain IS NOT NULL AND userDomain = :userDomain ))") int update(@Bind("userId") UUID userId, - @Bind("token") String token, - @Bind("cookie") String cookie); + @Bind("userDomain") String userDomain, + @Bind("token") String token, + @Bind("cookie") String cookie); - @SqlUpdate("UPDATE Access SET last = :last, updated = CURRENT_TIMESTAMP WHERE userId = :userId") + @SqlUpdate("UPDATE Access SET last = :last, updated = CURRENT_TIMESTAMP WHERE userId = :userId AND (( :userDomain IS NULL AND userDomain IS null ) " + + "OR ( :userDomain IS NOT NULL AND userDomain = :userDomain ))") int updateLast(@Bind("userId") UUID userId, - @Bind("last") UUID last); + @Bind("userDomain") String userDomain, + @Bind("last") UUID last); @SqlQuery("SELECT * FROM Access WHERE token IS NOT NULL AND enabled = 1 ORDER BY created DESC LIMIT 1") @RegisterColumnMapper(AccessResultSetMapper.class) LHAccess getSingle(); - @SqlQuery("SELECT * FROM Access WHERE userId = :userId") + @SqlQuery("SELECT * FROM Access WHERE userId = :userId AND (( :userDomain IS NULL AND userDomain is null ) " + + "or ( :userDomain is NOT NULL AND userDomain = :userDomain ))") @RegisterColumnMapper(AccessResultSetMapper.class) - LHAccess get(@Bind("userId") UUID userId); + LHAccess get(@Bind("userId") UUID userId, + @Bind("userDomain") String userDomain); @SqlQuery("SELECT * FROM Access WHERE enabled = 1 ORDER BY created DESC") @RegisterColumnMapper(AccessResultSetMapper.class) diff --git a/src/main/java/com/wire/bots/hold/DAO/AccessResultSetMapper.java b/src/main/java/com/wire/bots/hold/DAO/AccessResultSetMapper.java index d3f6bc7..eb1286b 100644 --- a/src/main/java/com/wire/bots/hold/DAO/AccessResultSetMapper.java +++ b/src/main/java/com/wire/bots/hold/DAO/AccessResultSetMapper.java @@ -1,6 +1,7 @@ package com.wire.bots.hold.DAO; import com.wire.bots.hold.model.database.LHAccess; +import com.wire.xenon.backend.models.QualifiedId; import org.jdbi.v3.core.mapper.ColumnMapper; import org.jdbi.v3.core.statement.StatementContext; @@ -13,7 +14,9 @@ public class AccessResultSetMapper implements ColumnMapper { public LHAccess map(ResultSet rs, int columnNumber, StatementContext ctx) throws SQLException { LHAccess LHAccess = new LHAccess(); LHAccess.last = (UUID) rs.getObject("last"); - LHAccess.userId = (UUID) rs.getObject("userId"); + UUID userId = (UUID) rs.getObject("userId"); + String userDomain = rs.getString("userDomain"); + LHAccess.userId = new QualifiedId(userId, userDomain); LHAccess.clientId = rs.getString("clientId"); LHAccess.token = rs.getString("token"); LHAccess.cookie = rs.getString("cookie"); diff --git a/src/main/java/com/wire/bots/hold/DAO/EventsDAO.java b/src/main/java/com/wire/bots/hold/DAO/EventsDAO.java index d3c6511..e79a1cb 100644 --- a/src/main/java/com/wire/bots/hold/DAO/EventsDAO.java +++ b/src/main/java/com/wire/bots/hold/DAO/EventsDAO.java @@ -27,13 +27,25 @@ int insert(@Bind("eventId") UUID eventId, @RegisterColumnMapper(EventsResultSetMapper.class) Event get(@Bind("eventId") UUID eventId); - @SqlQuery("SELECT * FROM Events WHERE conversationId = :conversationId ORDER BY time DESC") + @SqlQuery("SELECT * FROM Events WHERE conversationId = :conversationId AND (conversationDomain is NULL OR conversationDomain = :conversationDomain) ORDER BY time DESC") @RegisterColumnMapper(EventsResultSetMapper.class) - List listAll(@Bind("conversationId") UUID conversationId); + List listAllDefaultDomain(@Bind("conversationId") UUID conversationId, + @Bind("conversationDomain") String conversationDomain); - @SqlQuery("SELECT * FROM Events WHERE conversationId = :conversationId ORDER BY time ASC") + @SqlQuery("SELECT * FROM Events WHERE conversationId = :conversationId AND conversationDomain = :conversationDomain ORDER BY time DESC") @RegisterColumnMapper(EventsResultSetMapper.class) - List listAllAsc(@Bind("conversationId") UUID conversationId); + List listAll(@Bind("conversationId") UUID conversationId, + @Bind("conversationDomain") String conversationDomain); + + @SqlQuery("SELECT * FROM Events WHERE conversationId = :conversationId AND (conversationDomain is NULL OR conversationDomain = :conversationDomain) ORDER BY time ASC") + @RegisterColumnMapper(EventsResultSetMapper.class) + List listAllDefaultDomainAsc(@Bind("conversationId") UUID conversationId, + @Bind("conversationDomain") String conversationDomain); + + @SqlQuery("SELECT * FROM Events WHERE conversationId = :conversationId AND conversationDomain = :conversationDomain ORDER BY time ASC") + @RegisterColumnMapper(EventsResultSetMapper.class) + List listAllAsc(@Bind("conversationId") UUID conversationId, + @Bind("conversationDomain") String conversationDomain); @SqlQuery("SELECT DISTINCT conversationId, MAX(time) AS time " + "FROM Events " + diff --git a/src/main/java/com/wire/bots/hold/Service.java b/src/main/java/com/wire/bots/hold/Service.java index aa470cd..876ab78 100644 --- a/src/main/java/com/wire/bots/hold/Service.java +++ b/src/main/java/com/wire/bots/hold/Service.java @@ -25,6 +25,7 @@ import com.wire.bots.hold.DAO.MetadataDAO; import com.wire.bots.hold.filters.ServiceAuthenticationFilter; import com.wire.bots.hold.healthchecks.SanityCheck; +import com.wire.bots.hold.monitoring.ApiVersionResource; import com.wire.bots.hold.monitoring.RequestMdcFactoryFilter; import com.wire.bots.hold.monitoring.StatusResource; import com.wire.bots.hold.resource.v0.audit.*; @@ -126,6 +127,7 @@ public void run(Config config, Environment environment) throws ExecutionExceptio // Monitoring resources addResource(new StatusResource()); + addResource(new ApiVersionResource()); addResource(new RequestMdcFactoryFilter()); // Used by Wire Server diff --git a/src/main/java/com/wire/bots/hold/healthchecks/SanityCheck.java b/src/main/java/com/wire/bots/hold/healthchecks/SanityCheck.java index d6e6047..c23f530 100644 --- a/src/main/java/com/wire/bots/hold/healthchecks/SanityCheck.java +++ b/src/main/java/com/wire/bots/hold/healthchecks/SanityCheck.java @@ -5,7 +5,6 @@ import com.wire.bots.hold.model.database.LHAccess; import com.wire.bots.hold.utils.Cache; import com.wire.helium.API; -import com.wire.xenon.backend.models.QualifiedId; import com.wire.xenon.tools.Logger; import javax.ws.rs.client.Client; @@ -37,10 +36,11 @@ protected Result check() { while (!accessList.isEmpty()) { Logger.info("SanityCheck: checking %d devices, created: %s", accessList.size(), created); for (LHAccess access : accessList) { - // TODO(WPB-11287): Use user domain if exists, otherwise default - // TODO: String domain = (access.domain != null) ? access.domain : Cache.DEFAULT_DOMAIN; + if (access.userId.domain == null) { + access.userId.domain = Cache.getFallbackDomain(); + } boolean hasDevice = api.hasDevice( - new QualifiedId(access.userId, Cache.getFallbackDomain()), // TODO(WPB-11287): Change null to default domain + access.userId, access.clientId ); diff --git a/src/main/java/com/wire/bots/hold/model/database/Event.java b/src/main/java/com/wire/bots/hold/model/database/Event.java index a7c49eb..ed5e4bd 100644 --- a/src/main/java/com/wire/bots/hold/model/database/Event.java +++ b/src/main/java/com/wire/bots/hold/model/database/Event.java @@ -5,7 +5,9 @@ public class Event { public UUID eventId; public UUID conversationId; + public String conversationDomain; // Keeping values split instead of using QualifiedId because of HTML templating public UUID userId; + public String userDomain; // Keeping values split instead of using QualifiedId because of HTML templating public String type; public String payload; public String time; diff --git a/src/main/java/com/wire/bots/hold/model/database/LHAccess.java b/src/main/java/com/wire/bots/hold/model/database/LHAccess.java index f6d79c6..568b953 100644 --- a/src/main/java/com/wire/bots/hold/model/database/LHAccess.java +++ b/src/main/java/com/wire/bots/hold/model/database/LHAccess.java @@ -1,10 +1,12 @@ package com.wire.bots.hold.model.database; +import com.wire.xenon.backend.models.QualifiedId; + import java.util.UUID; public class LHAccess { public UUID last; - public UUID userId; + public QualifiedId userId; public String clientId; public String token; public String cookie; diff --git a/src/main/java/com/wire/bots/hold/monitoring/StatusResource.java b/src/main/java/com/wire/bots/hold/monitoring/StatusResource.java index c44e87d..42bd0dd 100644 --- a/src/main/java/com/wire/bots/hold/monitoring/StatusResource.java +++ b/src/main/java/com/wire/bots/hold/monitoring/StatusResource.java @@ -28,7 +28,7 @@ import javax.ws.rs.core.Response; @Api -@Path("/status") +@Path("/{parameter: status|v1/status}") @Produces(MediaType.TEXT_PLAIN) public class StatusResource { @GET diff --git a/src/main/java/com/wire/bots/hold/resource/v0/audit/ConversationResource.java b/src/main/java/com/wire/bots/hold/resource/v0/audit/ConversationResource.java index 7637839..0131b97 100644 --- a/src/main/java/com/wire/bots/hold/resource/v0/audit/ConversationResource.java +++ b/src/main/java/com/wire/bots/hold/resource/v0/audit/ConversationResource.java @@ -65,9 +65,8 @@ public ConversationResource(Jdbi jdbi, Client httpClient) { public Response list(@ApiParam @PathParam("conversationId") UUID conversationId, @ApiParam @QueryParam("html") boolean isHtml) { try { - List events = eventsDAO.listAllAsc(conversationId); - - // TODO(WPB-11287) Verify default domain + //TODO Get DEFAULT_DOMAIN, then fetch events with domain = null and domain = DEFAULT_DOMAIN + List events = eventsDAO.listAllDefaultDomainAsc(conversationId); testAPI(); diff --git a/src/main/java/com/wire/bots/hold/resource/v0/audit/DevicesResource.java b/src/main/java/com/wire/bots/hold/resource/v0/audit/DevicesResource.java index a3659b9..8d6024b 100644 --- a/src/main/java/com/wire/bots/hold/resource/v0/audit/DevicesResource.java +++ b/src/main/java/com/wire/bots/hold/resource/v0/audit/DevicesResource.java @@ -6,8 +6,9 @@ import com.wire.bots.hold.DAO.AccessDAO; import com.wire.bots.hold.filters.ServiceAuthorization; import com.wire.bots.hold.model.database.LHAccess; +import com.wire.bots.hold.utils.CryptoDatabaseFactory; +import com.wire.xenon.backend.models.QualifiedId; import com.wire.xenon.crypto.Crypto; -import com.wire.xenon.factories.CryptoFactory; import com.wire.xenon.tools.Logger; import io.swagger.annotations.Api; import io.swagger.annotations.ApiOperation; @@ -32,10 +33,10 @@ @Produces(MediaType.TEXT_HTML) public class DevicesResource { private final static MustacheFactory mf = new DefaultMustacheFactory(); - private final CryptoFactory cryptoFactory; + private final CryptoDatabaseFactory cryptoFactory; private final AccessDAO accessDAO; - public DevicesResource(AccessDAO accessDAO, CryptoFactory cryptoFactory) { + public DevicesResource(AccessDAO accessDAO, CryptoDatabaseFactory cryptoFactory) { this.cryptoFactory = cryptoFactory; this.accessDAO = accessDAO; } @@ -95,7 +96,7 @@ private String execute(Object model) throws IOException { static class Legal { UUID last; - UUID userId; + QualifiedId userId; String clientId; String fingerprint; String updated; diff --git a/src/main/java/com/wire/bots/hold/resource/v0/audit/EventsResource.java b/src/main/java/com/wire/bots/hold/resource/v0/audit/EventsResource.java index 60236f3..077a2fb 100644 --- a/src/main/java/com/wire/bots/hold/resource/v0/audit/EventsResource.java +++ b/src/main/java/com/wire/bots/hold/resource/v0/audit/EventsResource.java @@ -40,8 +40,9 @@ public EventsResource(EventsDAO eventsDAO) { @ApiResponse(code = 200, message = "Wire events")}) public Response list(@ApiParam @PathParam("conversationId") UUID conversationId) { try { + //TODO Get DEFAULT_DOMAIN, then fetch events with domain = null and domain = DEFAULT_DOMAIN Model model = new Model(); - model.events = eventsDAO.listAll(conversationId); + model.events = eventsDAO.listAllDefaultDomain(conversationId); String html = execute(model); return Response. diff --git a/src/main/java/com/wire/bots/hold/resource/v0/backend/ConfirmResourceV0.java b/src/main/java/com/wire/bots/hold/resource/v0/backend/ConfirmResourceV0.java index 0f91e83..02c49d2 100644 --- a/src/main/java/com/wire/bots/hold/resource/v0/backend/ConfirmResourceV0.java +++ b/src/main/java/com/wire/bots/hold/resource/v0/backend/ConfirmResourceV0.java @@ -35,7 +35,7 @@ public ConfirmResourceV0(DeviceManagementService deviceManagementService) { public Response confirm(@ApiParam @Valid @NotNull ConfirmPayloadV0 payload) { try { deviceManagementService.confirmDevice( - new QualifiedId(payload.userId, null), //TODO Probably a good place to put the DEFAULT_DOMAIN + new QualifiedId(payload.userId, null), payload.teamId, payload.clientId, payload.refreshToken diff --git a/src/main/java/com/wire/bots/hold/resource/v0/backend/InitiateResourceV0.java b/src/main/java/com/wire/bots/hold/resource/v0/backend/InitiateResourceV0.java index 38dfcb9..a68f9e5 100644 --- a/src/main/java/com/wire/bots/hold/resource/v0/backend/InitiateResourceV0.java +++ b/src/main/java/com/wire/bots/hold/resource/v0/backend/InitiateResourceV0.java @@ -38,7 +38,7 @@ public Response initiate(@ApiParam @Valid @NotNull InitPayloadV0 init) { try { final InitializedDeviceDTO initializedDeviceDTO = deviceManagementService.initiateLegalHoldDevice( - new QualifiedId(init.userId, null), //TODO Probably a good place to put the DEFAULT_DOMAIN + new QualifiedId(init.userId, null), init.teamId ); diff --git a/src/main/java/com/wire/bots/hold/resource/v0/backend/RemoveResourceV0.java b/src/main/java/com/wire/bots/hold/resource/v0/backend/RemoveResourceV0.java index 7068233..10a07cb 100644 --- a/src/main/java/com/wire/bots/hold/resource/v0/backend/RemoveResourceV0.java +++ b/src/main/java/com/wire/bots/hold/resource/v0/backend/RemoveResourceV0.java @@ -34,7 +34,7 @@ public RemoveResourceV0(DeviceManagementService deviceManagementService) { public Response remove(@ApiParam @Valid InitPayloadV0 payload) { try { deviceManagementService.removeDevice( - new QualifiedId(payload.userId, null), //TODO Probably a good place to put the DEFAULT_DOMAIN + new QualifiedId(payload.userId, null), payload.teamId ); diff --git a/src/main/java/com/wire/bots/hold/service/DeviceManagementService.java b/src/main/java/com/wire/bots/hold/service/DeviceManagementService.java index 8086e2c..1f655f7 100644 --- a/src/main/java/com/wire/bots/hold/service/DeviceManagementService.java +++ b/src/main/java/com/wire/bots/hold/service/DeviceManagementService.java @@ -67,6 +67,7 @@ public InitializedDeviceDTO initiateLegalHoldDevice(QualifiedId userId, UUID tea */ public void confirmDevice(QualifiedId userId, UUID teamId, String clientId, String refreshToken) { int insert = accessDAO.insert(userId.id, + userId.domain, clientId, refreshToken); @@ -95,7 +96,7 @@ public void removeDevice(QualifiedId userId, UUID teamId) throws IOException, Cr try (Crypto crypto = cf.create(userId)) { crypto.purge(); - int removeAccess = accessDAO.disable(userId.id); + int removeAccess = accessDAO.disable(userId.id, userId.domain); Logger.info( "RemoveResource: team: %s, user: %s, removed: %s", diff --git a/src/main/java/com/wire/bots/hold/utils/HoldClientRepo.java b/src/main/java/com/wire/bots/hold/utils/HoldClientRepo.java index db95d1c..1aef735 100644 --- a/src/main/java/com/wire/bots/hold/utils/HoldClientRepo.java +++ b/src/main/java/com/wire/bots/hold/utils/HoldClientRepo.java @@ -7,7 +7,6 @@ import com.wire.xenon.WireClient; import com.wire.xenon.backend.models.QualifiedId; import com.wire.xenon.crypto.Crypto; -import com.wire.xenon.factories.CryptoFactory; import org.jdbi.v3.core.Jdbi; import javax.ws.rs.client.Client; @@ -16,10 +15,10 @@ public class HoldClientRepo { private final Jdbi jdbi; - private final CryptoFactory cf; + private final CryptoDatabaseFactory cf; private final Client httpClient; - public HoldClientRepo(Jdbi jdbi, CryptoFactory cf, Client httpClient) { + public HoldClientRepo(Jdbi jdbi, CryptoDatabaseFactory cf, Client httpClient) { this.jdbi = jdbi; this.cf = cf; this.httpClient = httpClient; diff --git a/src/main/resources/db/migration/V108__add_domain_column_in_access_events.sql b/src/main/resources/db/migration/V108__add_domain_column_in_access_events.sql new file mode 100644 index 0000000..db3da8e --- /dev/null +++ b/src/main/resources/db/migration/V108__add_domain_column_in_access_events.sql @@ -0,0 +1,16 @@ +ALTER TABLE Access +ADD COLUMN userDomain VARCHAR(255) DEFAULT null; + +ALTER TABLE Access +DROP CONSTRAINT IF EXISTS access_pkey; + +ALTER TABLE Access ADD CONSTRAINT access_user_id_user_domain_key UNIQUE (userId, userDomain); + +ALTER TABLE Events +ADD COLUMN userDomain VARCHAR(255) DEFAULT null; + +ALTER TABLE Events +ADD COLUMN conversationDomain VARCHAR(255) DEFAULT null; + +DROP INDEX conversation_id_idx; +CREATE INDEX events_conversation_id_conversation_domain_idx ON Events (conversationId, conversationDomain); diff --git a/src/test/java/com/wire/bots/hold/DatabaseTest.java b/src/test/java/com/wire/bots/hold/DatabaseTest.java index 3bb7f88..5f203ba 100644 --- a/src/test/java/com/wire/bots/hold/DatabaseTest.java +++ b/src/test/java/com/wire/bots/hold/DatabaseTest.java @@ -74,7 +74,7 @@ public void eventsTextMessageTest() throws JsonProcessingException { assert textMessage.getMessageId().equals(message.getMessageId()); - List events = eventsDAO.listAll(convId.id); + List events = eventsDAO.listAllDefaultDomain(convId.id); assert events.size() == 2; } @@ -97,14 +97,14 @@ public void assetsTest() { @Test public void accessTests() { - final UUID userId = UUID.randomUUID(); + final QualifiedId userId = new QualifiedId(UUID.randomUUID(), UUID.randomUUID().toString()); final String clientId = UUID.randomUUID().toString(); final String cookie = "cookie"; final UUID last = UUID.randomUUID(); final String cookie2 = "cookie2"; final String token = "token"; - final int insert = accessDAO.insert(userId, clientId, cookie); + final int insert = accessDAO.insert(userId.id, userId.domain, clientId, cookie); accessDAO.updateLast(userId, last); accessDAO.update(userId, token, cookie2);