Skip to content

Commit

Permalink
WIP(add-domain-table-columns) #WPB-11298
Browse files Browse the repository at this point in the history
  • Loading branch information
spoonman01 committed Oct 9, 2024
1 parent 6f622c0 commit c908614
Show file tree
Hide file tree
Showing 18 changed files with 88 additions and 41 deletions.
33 changes: 21 additions & 12 deletions src/main/java/com/wire/bots/hold/DAO/AccessDAO.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,33 +10,42 @@
import java.util.UUID;

public interface AccessDAO {
@SqlUpdate("INSERT INTO Access (userId, clientId, cookie, updated, created, enabled) " +
"VALUES (:userId, :clientId, :cookie, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 1) " +
"ON CONFLICT (userId) DO UPDATE SET cookie = EXCLUDED.cookie, clientId = EXCLUDED.clientId, " +
@SqlUpdate("INSERT INTO Access (userId, userDomain, clientId, cookie, updated, created, enabled) " +
"VALUES (:userId, :userDomain, :clientId, :cookie, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP, 1) " +
"ON CONFLICT (userId, userDomain) DO UPDATE SET cookie = EXCLUDED.cookie, clientId = EXCLUDED.clientId, " +
"updated = EXCLUDED.updated, enabled = EXCLUDED.enabled")
int insert(@Bind("userId") UUID userId,
@Bind("userDomain") String userDomain,
@Bind("clientId") String clientId,
@Bind("cookie") String cookie);

@SqlUpdate("UPDATE Access SET enabled = 0, updated = CURRENT_TIMESTAMP WHERE userId = :userId")
int disable(@Bind("userId") UUID userId);
@SqlUpdate("UPDATE Access SET enabled = 0, updated = CURRENT_TIMESTAMP WHERE userId = :userId " +
"AND (( :userDomain IS NULL AND userDomain IS null ) or ( :userDomain IS NOT NULL AND userDomain = :userDomain ))")
int disable(@Bind("userId") UUID userId,
@Bind("userDomain") String userDomain);

@SqlUpdate("UPDATE Access SET token = :token, cookie = :cookie, updated = CURRENT_TIMESTAMP WHERE userId = :userId")
@SqlUpdate("UPDATE Access SET token = :token, cookie = :cookie, updated = CURRENT_TIMESTAMP WHERE userId = :userId " +
"AND (( :userDomain IS NULL AND userDomain IS null ) OR ( :userDomain IS NOT NULL AND userDomain = :userDomain ))")
int update(@Bind("userId") UUID userId,
@Bind("token") String token,
@Bind("cookie") String cookie);
@Bind("userDomain") String userDomain,
@Bind("token") String token,
@Bind("cookie") String cookie);

@SqlUpdate("UPDATE Access SET last = :last, updated = CURRENT_TIMESTAMP WHERE userId = :userId")
@SqlUpdate("UPDATE Access SET last = :last, updated = CURRENT_TIMESTAMP WHERE userId = :userId AND (( :userDomain IS NULL AND userDomain IS null ) " +
"OR ( :userDomain IS NOT NULL AND userDomain = :userDomain ))")
int updateLast(@Bind("userId") UUID userId,
@Bind("last") UUID last);
@Bind("userDomain") String userDomain,
@Bind("last") UUID last);

@SqlQuery("SELECT * FROM Access WHERE token IS NOT NULL AND enabled = 1 ORDER BY created DESC LIMIT 1")
@RegisterColumnMapper(AccessResultSetMapper.class)
LHAccess getSingle();

@SqlQuery("SELECT * FROM Access WHERE userId = :userId")
@SqlQuery("SELECT * FROM Access WHERE userId = :userId AND (( :userDomain IS NULL AND userDomain is null ) " +
"or ( :userDomain is NOT NULL AND userDomain = :userDomain ))")
@RegisterColumnMapper(AccessResultSetMapper.class)
LHAccess get(@Bind("userId") UUID userId);
LHAccess get(@Bind("userId") UUID userId,
@Bind("userDomain") String userDomain);

@SqlQuery("SELECT * FROM Access WHERE enabled = 1 ORDER BY created DESC")
@RegisterColumnMapper(AccessResultSetMapper.class)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.wire.bots.hold.DAO;

import com.wire.bots.hold.model.database.LHAccess;
import com.wire.xenon.backend.models.QualifiedId;
import org.jdbi.v3.core.mapper.ColumnMapper;
import org.jdbi.v3.core.statement.StatementContext;

Expand All @@ -13,7 +14,9 @@ public class AccessResultSetMapper implements ColumnMapper<LHAccess> {
public LHAccess map(ResultSet rs, int columnNumber, StatementContext ctx) throws SQLException {
LHAccess LHAccess = new LHAccess();
LHAccess.last = (UUID) rs.getObject("last");
LHAccess.userId = (UUID) rs.getObject("userId");
UUID userId = (UUID) rs.getObject("userId");
String userDomain = rs.getString("userDomain");
LHAccess.userId = new QualifiedId(userId, userDomain);
LHAccess.clientId = rs.getString("clientId");
LHAccess.token = rs.getString("token");
LHAccess.cookie = rs.getString("cookie");
Expand Down
20 changes: 16 additions & 4 deletions src/main/java/com/wire/bots/hold/DAO/EventsDAO.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,25 @@ int insert(@Bind("eventId") UUID eventId,
@RegisterColumnMapper(EventsResultSetMapper.class)
Event get(@Bind("eventId") UUID eventId);

@SqlQuery("SELECT * FROM Events WHERE conversationId = :conversationId ORDER BY time DESC")
@SqlQuery("SELECT * FROM Events WHERE conversationId = :conversationId AND (conversationDomain is NULL OR conversationDomain = :conversationDomain) ORDER BY time DESC")
@RegisterColumnMapper(EventsResultSetMapper.class)
List<Event> listAll(@Bind("conversationId") UUID conversationId);
List<Event> listAllDefaultDomain(@Bind("conversationId") UUID conversationId,
@Bind("conversationDomain") String conversationDomain);

@SqlQuery("SELECT * FROM Events WHERE conversationId = :conversationId ORDER BY time ASC")
@SqlQuery("SELECT * FROM Events WHERE conversationId = :conversationId AND conversationDomain = :conversationDomain ORDER BY time DESC")
@RegisterColumnMapper(EventsResultSetMapper.class)
List<Event> listAllAsc(@Bind("conversationId") UUID conversationId);
List<Event> listAll(@Bind("conversationId") UUID conversationId,
@Bind("conversationDomain") String conversationDomain);

@SqlQuery("SELECT * FROM Events WHERE conversationId = :conversationId AND (conversationDomain is NULL OR conversationDomain = :conversationDomain) ORDER BY time ASC")
@RegisterColumnMapper(EventsResultSetMapper.class)
List<Event> listAllDefaultDomainAsc(@Bind("conversationId") UUID conversationId,
@Bind("conversationDomain") String conversationDomain);

@SqlQuery("SELECT * FROM Events WHERE conversationId = :conversationId AND conversationDomain = :conversationDomain ORDER BY time ASC")
@RegisterColumnMapper(EventsResultSetMapper.class)
List<Event> listAllAsc(@Bind("conversationId") UUID conversationId,
@Bind("conversationDomain") String conversationDomain);

@SqlQuery("SELECT DISTINCT conversationId, MAX(time) AS time " +
"FROM Events " +
Expand Down
2 changes: 2 additions & 0 deletions src/main/java/com/wire/bots/hold/Service.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import com.wire.bots.hold.DAO.MetadataDAO;
import com.wire.bots.hold.filters.ServiceAuthenticationFilter;
import com.wire.bots.hold.healthchecks.SanityCheck;
import com.wire.bots.hold.monitoring.ApiVersionResource;
import com.wire.bots.hold.monitoring.RequestMdcFactoryFilter;
import com.wire.bots.hold.monitoring.StatusResource;
import com.wire.bots.hold.resource.v0.audit.*;
Expand Down Expand Up @@ -126,6 +127,7 @@ public void run(Config config, Environment environment) throws ExecutionExceptio

// Monitoring resources
addResource(new StatusResource());
addResource(new ApiVersionResource());
addResource(new RequestMdcFactoryFilter());

// Used by Wire Server
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import com.wire.bots.hold.model.database.LHAccess;
import com.wire.bots.hold.utils.Cache;
import com.wire.helium.API;
import com.wire.xenon.backend.models.QualifiedId;
import com.wire.xenon.tools.Logger;

import javax.ws.rs.client.Client;
Expand Down Expand Up @@ -37,10 +36,11 @@ protected Result check() {
while (!accessList.isEmpty()) {
Logger.info("SanityCheck: checking %d devices, created: %s", accessList.size(), created);
for (LHAccess access : accessList) {
// TODO(WPB-11287): Use user domain if exists, otherwise default
// TODO: String domain = (access.domain != null) ? access.domain : Cache.DEFAULT_DOMAIN;
if (access.userId.domain == null) {
access.userId.domain = Cache.getFallbackDomain();
}
boolean hasDevice = api.hasDevice(
new QualifiedId(access.userId, Cache.getFallbackDomain()), // TODO(WPB-11287): Change null to default domain
access.userId,
access.clientId
);

Expand Down
2 changes: 2 additions & 0 deletions src/main/java/com/wire/bots/hold/model/database/Event.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
public class Event {
public UUID eventId;
public UUID conversationId;
public String conversationDomain; // Keeping values split instead of using QualifiedId because of HTML templating
public UUID userId;
public String userDomain; // Keeping values split instead of using QualifiedId because of HTML templating
public String type;
public String payload;
public String time;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package com.wire.bots.hold.model.database;

import com.wire.xenon.backend.models.QualifiedId;

import java.util.UUID;

public class LHAccess {
public UUID last;
public UUID userId;
public QualifiedId userId;
public String clientId;
public String token;
public String cookie;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import javax.ws.rs.core.Response;

@Api
@Path("/status")
@Path("/{parameter: status|v1/status}")
@Produces(MediaType.TEXT_PLAIN)
public class StatusResource {
@GET
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,8 @@ public ConversationResource(Jdbi jdbi, Client httpClient) {
public Response list(@ApiParam @PathParam("conversationId") UUID conversationId,
@ApiParam @QueryParam("html") boolean isHtml) {
try {
List<Event> events = eventsDAO.listAllAsc(conversationId);

// TODO(WPB-11287) Verify default domain
//TODO Get DEFAULT_DOMAIN, then fetch events with domain = null and domain = DEFAULT_DOMAIN
List<Event> events = eventsDAO.listAllDefaultDomainAsc(conversationId);

testAPI();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
import com.wire.bots.hold.DAO.AccessDAO;
import com.wire.bots.hold.filters.ServiceAuthorization;
import com.wire.bots.hold.model.database.LHAccess;
import com.wire.bots.hold.utils.CryptoDatabaseFactory;
import com.wire.xenon.backend.models.QualifiedId;
import com.wire.xenon.crypto.Crypto;
import com.wire.xenon.factories.CryptoFactory;
import com.wire.xenon.tools.Logger;
import io.swagger.annotations.Api;
import io.swagger.annotations.ApiOperation;
Expand All @@ -32,10 +33,10 @@
@Produces(MediaType.TEXT_HTML)
public class DevicesResource {
private final static MustacheFactory mf = new DefaultMustacheFactory();
private final CryptoFactory cryptoFactory;
private final CryptoDatabaseFactory cryptoFactory;
private final AccessDAO accessDAO;

public DevicesResource(AccessDAO accessDAO, CryptoFactory cryptoFactory) {
public DevicesResource(AccessDAO accessDAO, CryptoDatabaseFactory cryptoFactory) {
this.cryptoFactory = cryptoFactory;
this.accessDAO = accessDAO;
}
Expand Down Expand Up @@ -95,7 +96,7 @@ private String execute(Object model) throws IOException {

static class Legal {
UUID last;
UUID userId;
QualifiedId userId;
String clientId;
String fingerprint;
String updated;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@ public EventsResource(EventsDAO eventsDAO) {
@ApiResponse(code = 200, message = "Wire events")})
public Response list(@ApiParam @PathParam("conversationId") UUID conversationId) {
try {
//TODO Get DEFAULT_DOMAIN, then fetch events with domain = null and domain = DEFAULT_DOMAIN
Model model = new Model();
model.events = eventsDAO.listAll(conversationId);
model.events = eventsDAO.listAllDefaultDomain(conversationId);
String html = execute(model);

return Response.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public ConfirmResourceV0(DeviceManagementService deviceManagementService) {
public Response confirm(@ApiParam @Valid @NotNull ConfirmPayloadV0 payload) {
try {
deviceManagementService.confirmDevice(
new QualifiedId(payload.userId, null), //TODO Probably a good place to put the DEFAULT_DOMAIN
new QualifiedId(payload.userId, null),
payload.teamId,
payload.clientId,
payload.refreshToken
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public Response initiate(@ApiParam @Valid @NotNull InitPayloadV0 init) {
try {
final InitializedDeviceDTO initializedDeviceDTO =
deviceManagementService.initiateLegalHoldDevice(
new QualifiedId(init.userId, null), //TODO Probably a good place to put the DEFAULT_DOMAIN
new QualifiedId(init.userId, null),
init.teamId
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public RemoveResourceV0(DeviceManagementService deviceManagementService) {
public Response remove(@ApiParam @Valid InitPayloadV0 payload) {
try {
deviceManagementService.removeDevice(
new QualifiedId(payload.userId, null), //TODO Probably a good place to put the DEFAULT_DOMAIN
new QualifiedId(payload.userId, null),
payload.teamId
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ public InitializedDeviceDTO initiateLegalHoldDevice(QualifiedId userId, UUID tea
*/
public void confirmDevice(QualifiedId userId, UUID teamId, String clientId, String refreshToken) {
int insert = accessDAO.insert(userId.id,
userId.domain,
clientId,
refreshToken);

Expand Down Expand Up @@ -95,7 +96,7 @@ public void removeDevice(QualifiedId userId, UUID teamId) throws IOException, Cr
try (Crypto crypto = cf.create(userId)) {
crypto.purge();

int removeAccess = accessDAO.disable(userId.id);
int removeAccess = accessDAO.disable(userId.id, userId.domain);

Logger.info(
"RemoveResource: team: %s, user: %s, removed: %s",
Expand Down
5 changes: 2 additions & 3 deletions src/main/java/com/wire/bots/hold/utils/HoldClientRepo.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import com.wire.xenon.WireClient;
import com.wire.xenon.backend.models.QualifiedId;
import com.wire.xenon.crypto.Crypto;
import com.wire.xenon.factories.CryptoFactory;
import org.jdbi.v3.core.Jdbi;

import javax.ws.rs.client.Client;
Expand All @@ -16,10 +15,10 @@
public class HoldClientRepo {

private final Jdbi jdbi;
private final CryptoFactory cf;
private final CryptoDatabaseFactory cf;
private final Client httpClient;

public HoldClientRepo(Jdbi jdbi, CryptoFactory cf, Client httpClient) {
public HoldClientRepo(Jdbi jdbi, CryptoDatabaseFactory cf, Client httpClient) {
this.jdbi = jdbi;
this.cf = cf;
this.httpClient = httpClient;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
ALTER TABLE Access
ADD COLUMN userDomain VARCHAR(255) DEFAULT null;

ALTER TABLE Access
DROP CONSTRAINT IF EXISTS access_pkey;

ALTER TABLE Access ADD CONSTRAINT access_user_id_user_domain_key UNIQUE (userId, userDomain);

ALTER TABLE Events
ADD COLUMN userDomain VARCHAR(255) DEFAULT null;

ALTER TABLE Events
ADD COLUMN conversationDomain VARCHAR(255) DEFAULT null;

DROP INDEX conversation_id_idx;
CREATE INDEX events_conversation_id_conversation_domain_idx ON Events (conversationId, conversationDomain);
6 changes: 3 additions & 3 deletions src/test/java/com/wire/bots/hold/DatabaseTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public void eventsTextMessageTest() throws JsonProcessingException {

assert textMessage.getMessageId().equals(message.getMessageId());

List<Event> events = eventsDAO.listAll(convId.id);
List<Event> events = eventsDAO.listAllDefaultDomain(convId.id);
assert events.size() == 2;
}

Expand All @@ -97,14 +97,14 @@ public void assetsTest() {

@Test
public void accessTests() {
final UUID userId = UUID.randomUUID();
final QualifiedId userId = new QualifiedId(UUID.randomUUID(), UUID.randomUUID().toString());
final String clientId = UUID.randomUUID().toString();
final String cookie = "cookie";
final UUID last = UUID.randomUUID();
final String cookie2 = "cookie2";
final String token = "token";

final int insert = accessDAO.insert(userId, clientId, cookie);
final int insert = accessDAO.insert(userId.id, userId.domain, clientId, cookie);
accessDAO.updateLast(userId, last);
accessDAO.update(userId, token, cookie2);

Expand Down

0 comments on commit c908614

Please sign in to comment.