Skip to content

Commit 1aebc9a

Browse files
committed
added tests
Signed-off-by: Jiaping Zeng <[email protected]>
1 parent 673b239 commit 1aebc9a

File tree

7 files changed

+589
-6
lines changed

7 files changed

+589
-6
lines changed

common/src/main/java/org/opensearch/ml/common/transport/agent/MLAgentUpdateRequest.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,16 @@ public MLAgentUpdateRequest(StreamInput in) throws IOException {
4949
@Override
5050
public ActionRequestValidationException validate() {
5151
ActionRequestValidationException exception = null;
52-
if (mlAgent == null) {
53-
exception = addValidationError("ML agent can't be null", exception);
52+
if (agentId == null || mlAgent == null) {
53+
exception = addValidationError("Agent ID and ML Agent cannot be null", exception);
5454
}
55-
5655
return exception;
5756
}
5857

5958
@Override
6059
public void writeTo(StreamOutput out) throws IOException {
6160
super.writeTo(out);
61+
out.writeString(this.agentId);
6262
this.mlAgent.writeTo(out);
6363
}
6464

@@ -73,7 +73,7 @@ public static MLAgentUpdateRequest fromActionRequest(ActionRequest actionRequest
7373
return new MLAgentUpdateRequest(input);
7474
}
7575
} catch (IOException e) {
76-
throw new UncheckedIOException("failed to parse ActionRequest into MLAgentUpdateRequest", e);
76+
throw new UncheckedIOException("Failed to parse ActionRequest into MLAgentUpdateRequest", e);
7777
}
7878
}
7979
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.transport.agent;
7+
8+
import static org.junit.Assert.assertEquals;
9+
import static org.junit.Assert.assertNotNull;
10+
11+
import org.junit.Test;
12+
13+
public class MLAgentUpdateActionTests {
14+
15+
@Test
16+
public void testInstance() {
17+
assertNotNull(MLAgentUpdateAction.INSTANCE);
18+
assertEquals("cluster:admin/opensearch/ml/agents/update", MLAgentUpdateAction.NAME);
19+
}
20+
21+
}
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.transport.agent;
7+
8+
import static org.junit.Assert.*;
9+
10+
import java.io.IOException;
11+
import java.io.UncheckedIOException;
12+
import java.util.Collections;
13+
14+
import org.junit.Before;
15+
import org.junit.Test;
16+
import org.opensearch.action.ActionRequest;
17+
import org.opensearch.action.ActionRequestValidationException;
18+
import org.opensearch.common.io.stream.BytesStreamOutput;
19+
import org.opensearch.core.common.io.stream.StreamOutput;
20+
import org.opensearch.ml.common.agent.MLAgent;
21+
import org.opensearch.ml.common.agent.MLToolSpec;
22+
23+
public class MLAgentUpdateRequestTests {
24+
25+
String agentId;
26+
MLAgent mlAgent;
27+
28+
@Before
29+
public void setUp() {
30+
agentId = "test_agent_id";
31+
mlAgent = MLAgent
32+
.builder()
33+
.name("test_agent")
34+
.appType("test_app")
35+
.type("flow")
36+
.tools(Collections.singletonList(MLToolSpec.builder().type("ListIndexTool").build()))
37+
.build();
38+
}
39+
40+
@Test
41+
public void constructor_Agent() {
42+
MLAgentUpdateRequest mlAgentUpdateRequest = new MLAgentUpdateRequest(agentId, mlAgent);
43+
assertEquals(agentId, mlAgentUpdateRequest.getAgentId());
44+
assertEquals(mlAgent, mlAgentUpdateRequest.getMlAgent());
45+
46+
ActionRequestValidationException validationException = mlAgentUpdateRequest.validate();
47+
assertNull(validationException);
48+
}
49+
50+
@Test
51+
public void constructor_NullId() {
52+
MLAgentUpdateRequest mlAgentUpdateRequest = new MLAgentUpdateRequest(null, mlAgent);
53+
assertNull(mlAgentUpdateRequest.getAgentId());
54+
55+
ActionRequestValidationException validationException = mlAgentUpdateRequest.validate();
56+
assertNotNull(validationException);
57+
assertTrue(validationException.toString().contains("Agent ID and ML Agent cannot be null"));
58+
}
59+
60+
@Test
61+
public void constructor_NullAgent() {
62+
MLAgentUpdateRequest mlAgentUpdateRequest = new MLAgentUpdateRequest(agentId, null);
63+
assertNull(mlAgentUpdateRequest.getMlAgent());
64+
65+
ActionRequestValidationException validationException = mlAgentUpdateRequest.validate();
66+
assertNotNull(validationException);
67+
assertTrue(validationException.toString().contains("Agent ID and ML Agent cannot be null"));
68+
}
69+
70+
@Test
71+
public void writeTo_Success() throws IOException {
72+
MLAgentUpdateRequest mlAgentUpdateRequest = new MLAgentUpdateRequest(agentId, mlAgent);
73+
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
74+
mlAgentUpdateRequest.writeTo(bytesStreamOutput);
75+
MLAgentUpdateRequest parsedRequest = new MLAgentUpdateRequest(bytesStreamOutput.bytes().streamInput());
76+
assertEquals(agentId, parsedRequest.getAgentId());
77+
assertEquals(mlAgent, parsedRequest.getMlAgent());
78+
}
79+
80+
@Test
81+
public void fromActionRequest_Success() {
82+
MLAgentUpdateRequest mlAgentUpdateRequest = new MLAgentUpdateRequest(agentId, mlAgent);
83+
ActionRequest actionRequest = new ActionRequest() {
84+
@Override
85+
public ActionRequestValidationException validate() {
86+
return null;
87+
}
88+
89+
@Override
90+
public void writeTo(StreamOutput out) throws IOException {
91+
mlAgentUpdateRequest.writeTo(out);
92+
}
93+
};
94+
MLAgentUpdateRequest parsedRequest = MLAgentUpdateRequest.fromActionRequest(actionRequest);
95+
assertNotSame(mlAgentUpdateRequest, parsedRequest);
96+
assertEquals(mlAgentUpdateRequest.getAgentId(), parsedRequest.getAgentId());
97+
assertEquals(mlAgentUpdateRequest.getMlAgent(), parsedRequest.getMlAgent());
98+
}
99+
100+
@Test
101+
public void fromActionRequest_Success_MLAgentUpdateRequest() {
102+
MLAgentUpdateRequest mlAgentUpdateRequest = new MLAgentUpdateRequest(agentId, mlAgent);
103+
MLAgentUpdateRequest parsedRequest = MLAgentUpdateRequest.fromActionRequest(mlAgentUpdateRequest);
104+
assertSame(mlAgentUpdateRequest, parsedRequest);
105+
}
106+
107+
@Test(expected = UncheckedIOException.class)
108+
public void fromActionRequest_IOException() {
109+
ActionRequest actionRequest = new ActionRequest() {
110+
@Override
111+
public ActionRequestValidationException validate() {
112+
return null;
113+
}
114+
115+
@Override
116+
public void writeTo(StreamOutput out) throws IOException {
117+
throw new IOException();
118+
}
119+
};
120+
MLAgentUpdateRequest.fromActionRequest(actionRequest);
121+
}
122+
}

plugin/src/main/java/org/opensearch/ml/action/agents/UpdateAgentTransportAction.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@
4242
import org.opensearch.transport.TransportService;
4343
import org.opensearch.transport.client.Client;
4444

45+
import com.google.common.annotations.VisibleForTesting;
46+
4547
import lombok.extern.log4j.Log4j2;
4648

4749
@Log4j2
@@ -83,7 +85,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Update
8385
return;
8486
}
8587

86-
boolean isSuperAdmin = RestActionUtils.isSuperAdminUser(clusterService, client);
88+
boolean isSuperAdmin = isSuperAdminUserWrapper(clusterService, client);
8789

8890
FetchSourceContext fetchSourceContext = new FetchSourceContext(true, Strings.EMPTY_ARRAY, Strings.EMPTY_ARRAY);
8991
GetDataObjectRequest getDataObjectRequest = GetDataObjectRequest
@@ -168,4 +170,9 @@ private void updateAgent(String agentId, MLAgent agent, ActionListener<UpdateRes
168170
}
169171
});
170172
}
173+
174+
@VisibleForTesting
175+
boolean isSuperAdminUserWrapper(ClusterService clusterService, Client client) {
176+
return RestActionUtils.isSuperAdminUser(clusterService, client);
177+
}
171178
}

plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateAgentAction.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.opensearch.rest.action.RestToXContentListener;
2626
import org.opensearch.transport.client.node.NodeClient;
2727

28+
import com.google.common.annotations.VisibleForTesting;
2829
import com.google.common.collect.ImmutableList;
2930

3031
public class RestMLUpdateAgentAction extends BaseRestHandler {
@@ -62,7 +63,8 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
6263
* @param request RestRequest
6364
* @return MLAgentUpdateRequest
6465
*/
65-
private MLAgentUpdateRequest getRequest(RestRequest request) throws IOException {
66+
@VisibleForTesting
67+
MLAgentUpdateRequest getRequest(RestRequest request) throws IOException {
6668
if (!mlFeatureEnabledSetting.isAgentFrameworkEnabled()) {
6769
throw new IllegalStateException(AGENT_FRAMEWORK_DISABLED_ERR_MSG);
6870
}

0 commit comments

Comments
 (0)