Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,36 +20,69 @@

import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonCreator;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonTypeInfo;

import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

/** Helper class to describe a {@link Resource} */
public class ResourceDescriptor {
private static final String FIELD_CLAZZ = "clazz";
private static final String FIELD_INITIAL_ARGUMENTS = "initialArguments";
private static final String FIELD_CLAZZ = "target_clazz";
private static final String FIELD_MODULE = "target_module";
private static final String FIELD_INITIAL_ARGUMENTS = "arguments";

@JsonProperty(FIELD_CLAZZ)
private final String clazz;

// TODO: support nested map/list with non primitive value.
@JsonProperty(FIELD_MODULE)
private final String module;

@JsonProperty(FIELD_INITIAL_ARGUMENTS)
@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS)
private final Map<String, Object> initialArguments;

/**
* Initialize ResourceDescriptor.
*
* <p>Creates a new ResourceDescriptor with the specified class information and initial
* arguments. This constructor supports cross-platform compatibility between Java and Python
* resources.
*
* @param clazz The class identifier for the resource. Its meaning depends on the resource type:
* <ul>
* <li><b>For Java resources:</b> The fully qualified Java class name (e.g.,
* "com.example.YourJavaClass"). The {@code module} parameter should be empty or null.
* <li><b>For Python resources (when declaring from Java):</b> The Python class name
* (simple name, not module path, e.g., "YourPythonClass"). The Python module path
* must be specified in the {@code module} parameter (e.g., "your_module.submodule").
* </ul>
*
* @param module The Python module path for cross-platform compatibility. Defaults to empty
* string for Java resources. Example: "your_module.submodule"
* @param initialArguments Additional arguments for resource initialization. Can be null or
* empty map if no initial arguments are needed.
*/
@JsonCreator
public ResourceDescriptor(
@JsonProperty(FIELD_MODULE) String module,
@JsonProperty(FIELD_CLAZZ) String clazz,
@JsonProperty(FIELD_INITIAL_ARGUMENTS) Map<String, Object> initialArguments) {
this.clazz = clazz;
this.module = module;
this.initialArguments = initialArguments;
}

public ResourceDescriptor(String clazz, Map<String, Object> initialArguments) {
this("", clazz, initialArguments);
}

public String getClazz() {
return clazz;
}

public String getModule() {
return module;
}

public Map<String, Object> getInitialArguments() {
return initialArguments;
}
Expand All @@ -64,6 +97,27 @@ public <T> T getArgument(String argName, T defaultValue) {
return value != null ? value : defaultValue;
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}

if (o == null || getClass() != o.getClass()) {
return false;
}

ResourceDescriptor that = (ResourceDescriptor) o;
return Objects.equals(this.clazz, that.clazz)
&& Objects.equals(this.module, that.module)
&& Objects.equals(this.initialArguments, that.initialArguments);
}

@Override
public int hashCode() {
return Objects.hash(clazz, module, initialArguments);
}

public static class Builder {
private final String clazz;
private final Map<String, Object> initialArguments;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,10 @@ public static ResourceDescriptor chatModelConnection() {
@ChatModelSetup
public static ResourceDescriptor chatModel() {
return ResourceDescriptor.Builder.newBuilder(PythonChatModelSetup.class.getName())
.addInitialArgument("connection", "chatModelConnection")
.addInitialArgument(
"module", "flink_agents.integrations.chat_models.ollama_chat_model")
.addInitialArgument("clazz", "OllamaChatModelSetup")
.addInitialArgument("connection", "chatModelConnection")
.addInitialArgument("model", OLLAMA_MODEL)
.addInitialArgument(
"tools",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.agents.plan.resource.python;

import org.apache.flink.agents.api.chat.messages.ChatMessage;
import org.apache.flink.agents.api.chat.messages.MessageRole;
import org.apache.flink.agents.api.prompt.Prompt;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
* PythonPrompt is a subclass of Prompt that provides a method to parse a Python prompt from a
* serialized map.
*/
public class PythonPrompt extends Prompt {
public PythonPrompt(String template) {
super(template);
}

public PythonPrompt(List<ChatMessage> template) {
super(template);
}

public static PythonPrompt fromSerializedMap(Map<String, Object> serialized) {
if (serialized == null || !serialized.containsKey("template")) {
throw new IllegalArgumentException("Map must contain 'template' key");
}

Object templateObj = serialized.get("template");
if (templateObj instanceof String) {
return new PythonPrompt((String) templateObj);
} else if (templateObj instanceof List) {
List<?> templateList = (List<?>) templateObj;
if (templateList.isEmpty()) {
throw new IllegalArgumentException("Template list cannot be empty");
}

List<ChatMessage> messages = new ArrayList<>();
for (Object item : templateList) {
if (!(item instanceof Map)) {
throw new IllegalArgumentException("Each template item must be a Map");
}

Map<String, Object> messageMap = (Map<String, Object>) item;
ChatMessage chatMessage = parseChatMessage(messageMap);
messages.add(chatMessage);
}

return new PythonPrompt(messages);
}
throw new IllegalArgumentException(
"Python prompt parsing failed. Template is not a string or list.");
}

/** Parse a single ChatMessage from a Map representation. */
@SuppressWarnings("unchecked")
private static ChatMessage parseChatMessage(Map<String, Object> messageMap) {
String roleValue = messageMap.get("role").toString();
MessageRole role = MessageRole.fromValue(roleValue);

Object contentObj = messageMap.get("content");
String content = contentObj != null ? contentObj.toString() : "";

List<Map<String, Object>> toolCalls =
(List<Map<String, Object>>) messageMap.get("tool_calls");

Map<String, Object> extraArgs = (Map<String, Object>) messageMap.get("extra_args");

return new ChatMessage(role, content, toolCalls, extraArgs);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.agents.plan.resource.python;

import org.apache.flink.agents.api.tools.Tool;
import org.apache.flink.agents.api.tools.ToolMetadata;
import org.apache.flink.agents.api.tools.ToolParameters;
import org.apache.flink.agents.api.tools.ToolResponse;
import org.apache.flink.agents.api.tools.ToolType;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessingException;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper;

import java.util.Map;

/**
* PythonTool is a subclass of Tool that that provides a method to parse a Python tool metadata from
* a serialized map.
*/
public class PythonTool extends Tool {
protected PythonTool(ToolMetadata metadata) {
super(metadata);
}

@SuppressWarnings("unchecked")
public static PythonTool fromSerializedMap(Map<String, Object> serialized)
throws JsonProcessingException {
if (serialized == null) {
throw new IllegalArgumentException("Serialized map cannot be null");
}

if (!serialized.containsKey("metadata")) {
throw new IllegalArgumentException("Map must contain 'metadata' key");
}

Object metadataObj = serialized.get("metadata");
if (!(metadataObj instanceof Map)) {
throw new IllegalArgumentException("'metadata' must be a Map");
}

Map<String, Object> metadata = (Map<String, Object>) metadataObj;

if (!metadata.containsKey("name")) {
throw new IllegalArgumentException("Metadata must contain 'name' key");
}

if (!metadata.containsKey("description")) {
throw new IllegalArgumentException("Metadata must contain 'description' key");
}

if (!metadata.containsKey("args_schema")) {
throw new IllegalArgumentException("Metadata must contain 'args_schema' key");
}

String name = (String) metadata.get("name");
String description = (String) metadata.get("description");

if (name == null) {
throw new IllegalArgumentException("'name' cannot be null");
}

if (description == null) {
throw new IllegalArgumentException("'description' cannot be null");
}

ObjectMapper mapper = new ObjectMapper();
String inputSchema = mapper.writeValueAsString(metadata.get("args_schema"));
return new PythonTool(new ToolMetadata(name, description, inputSchema));
}

@Override
public ToolType getToolType() {
return ToolType.REMOTE_FUNCTION;
}

@Override
public ToolResponse call(ToolParameters parameters) {
throw new UnsupportedOperationException("PythonTool does not support call method.");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,13 @@ public JavaResourceProvider(String name, ResourceType type, ResourceDescriptor d
@Override
public Resource provide(BiFunction<String, ResourceType, Resource> getResource)
throws Exception {
Class<?> clazz = Class.forName(descriptor.getClazz());
String clazzName;
if (descriptor.getModule() == null || descriptor.getModule().isEmpty()) {
clazzName = descriptor.getClazz();
} else {
clazzName = descriptor.getInitialArguments().remove("java_clazz").toString();
}
Class<?> clazz = Class.forName(clazzName);
Constructor<?> constructor =
clazz.getConstructor(ResourceDescriptor.class, BiFunction.class);
return (Resource) constructor.newInstance(descriptor, getResource);
Expand Down
Loading