Skip to content

Commit

Permalink
Add Filter convenience
Browse files Browse the repository at this point in the history
  • Loading branch information
MatKuhr committed Oct 15, 2024
1 parent 0bc695c commit b8414f2
Show file tree
Hide file tree
Showing 12 changed files with 299 additions and 90 deletions.
5 changes: 5 additions & 0 deletions orchestration/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,11 @@
<artifactId>assertj-core</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<scope>test</scope>
</dependency>
</dependencies>

<build>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package com.sap.ai.sdk.orchestration;

import com.sap.ai.sdk.orchestration.client.model.AzureContentSafety;
import lombok.Data;
import lombok.RequiredArgsConstructor;
import lombok.experimental.Accessors;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import static com.sap.ai.sdk.orchestration.client.model.AzureThreshold.fromValue;

@Data
@Accessors(fluent = true)
public class AzureContentFilter implements ContentFilter {
@Nullable
private Setting hate;
@Nullable
private Setting selfHarm;
@Nullable
private Setting sexual;
@Nullable
private Setting violence;

@RequiredArgsConstructor
public enum Setting {
VERY_STRICT(0),
STRICT(2),
MODERATE(4),
LENIENT(6);

private final int value;
}

@Nonnull
com.sap.ai.sdk.orchestration.client.model.FilterConfig toFilterConfigDTO() {
var dto = AzureContentSafety.create();
if (hate != null) {
dto.hate(fromValue(hate.value));
}
if (selfHarm != null) {
dto.selfHarm(fromValue(selfHarm.value));
}
if (sexual != null) {
dto.sexual(fromValue(sexual.value));
}
if (violence != null) {
dto.violence(fromValue(violence.value));
}
return com.sap.ai.sdk.orchestration.client.model.FilterConfig.create()
.type(com.sap.ai.sdk.orchestration.client.model.FilterConfig.TypeEnum.AZURE_CONTENT_SAFETY)
.config(dto);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package com.sap.ai.sdk.orchestration;

public interface ContentFilter {}
Original file line number Diff line number Diff line change
@@ -1,15 +1,9 @@
package com.sap.ai.sdk.orchestration;

import com.sap.ai.sdk.orchestration.client.model.ChatMessage;
import com.sap.ai.sdk.orchestration.client.model.FilterConfig;
import com.sap.ai.sdk.orchestration.client.model.FilteringConfig;
import com.sap.ai.sdk.orchestration.client.model.FilteringModuleConfig;
import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig;
import com.sap.ai.sdk.orchestration.client.model.ModuleConfigs;
import com.sap.ai.sdk.orchestration.client.model.TemplatingModuleConfig;
import io.vavr.control.Option;
import java.util.ArrayList;
import java.util.List;

import javax.annotation.Nonnull;
import lombok.AccessLevel;
import lombok.Data;
Expand All @@ -22,9 +16,9 @@ public class DefaultOrchestrationConfig<T extends OrchestrationConfig<T>>

@Nonnull private Option<LLMModuleConfig> llmConfig = Option.none();
@Nonnull private Option<TemplatingModuleConfig> template = Option.none();
@Nonnull private Option<DpiMaskingConfig> maskingConfig = Option.none();
@Nonnull private Option<FilterConfig> inputContentFilter = Option.none();
@Nonnull private Option<FilterConfig> outputContentFilter = Option.none();
@Nonnull private Option<MaskingConfig> maskingConfig = Option.none();
@Nonnull private Option<ContentFilter> inputContentFilter = Option.none();
@Nonnull private Option<ContentFilter> outputContentFilter = Option.none();

@Nonnull
@Override
Expand All @@ -49,69 +43,22 @@ public T withTemplate(@Nonnull final TemplatingModuleConfig template) {

@Nonnull
@Override
public T withMaskingConfig(@Nonnull final DpiMaskingConfig maskingConfig) {
public T withMaskingConfig(@Nonnull final MaskingConfig maskingConfig) {
this.maskingConfig = Option.some(maskingConfig);
return instance();
}

@Nonnull
@Override
public T withInputContentFilter(@Nonnull final FilterConfig filter) {
public T withInputContentFilter(@Nonnull final ContentFilter filter) {
this.inputContentFilter = Option.some(filter);
return instance();
}

@Nonnull
@Override
public T withOutputContentFilter(@Nonnull final FilterConfig filter) {
public T withOutputContentFilter(@Nonnull final ContentFilter filter) {
this.outputContentFilter = Option.some(filter);
return instance();
}

@Nonnull
static ModuleConfigs toModuleConfigDTO(
@Nonnull final OrchestrationConfig<?> config, @Nonnull final List<ChatMessage> messages) {
LLMModuleConfig llm =
config
.getLlmConfig()
.getOrElseThrow(() -> new IllegalStateException("LLM module config is required."));

/*
* Currently, we have to merge the prompt into the template configuration.
* This works around the limitation that the template config isn't optional.
* This comes at the risk that the prompt unintentionally contains the templating pattern "{{? .. }}".
* In this case, the request will fail, since the templating module will try to resolve the parameter.
* To be fixed with https://github.tools.sap/AI/llm-orchestration/issues/662
*/
Option<TemplatingModuleConfig> maybeTemplate = config.getTemplate();
final var messagesWithPrompt = new ArrayList<>(messages);
maybeTemplate.map(TemplatingModuleConfig::getTemplate).forEach(messagesWithPrompt::addAll);

if (messagesWithPrompt.isEmpty()) {
throw new IllegalStateException(
"A template is required. Pass at least one message or configure the templating module.");
}
TemplatingModuleConfig template = TemplatingModuleConfig.create().template(messagesWithPrompt);
maybeTemplate.map(TemplatingModuleConfig::getDefaults).forEach(template::defaults);

ModuleConfigs dto =
ModuleConfigs.create().llmModuleConfig(llm).templatingModuleConfig(template);

config
.getMaskingConfig()
.map(DpiMaskingConfig::toMaskingModuleDTO)
.forEach(dto::maskingModuleConfig);

var maybeInputFilter = config.getInputContentFilter();
var maybeOutputFilter = config.getOutputContentFilter();

if (maybeInputFilter.isDefined() || maybeOutputFilter.isDefined()) {
var filter = FilteringModuleConfig.create();
maybeInputFilter.map(it -> FilteringConfig.create().filters(it)).forEach(filter::input);
maybeOutputFilter.map(it -> FilteringConfig.create().filters(it)).forEach(filter::output);
dto = dto.filteringModuleConfig(filter);
}

return dto;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@

@Value
@RequiredArgsConstructor(access = AccessLevel.PRIVATE)
public class DpiMaskingConfig {
// TODO: Create an interface to represent the "oneOf" for masking providers?
public class DpiMaskingConfig implements MaskingConfig {
@Nonnull MaskingProviderConfig.MethodEnum maskingMethod;
@Nonnull List<DPIEntities> entities;

Expand All @@ -32,14 +31,13 @@ public static Builder forPseudonymization() {
}

@Nonnull
MaskingModuleConfig toMaskingModuleDTO() {
MaskingProviderConfig toMaskingProviderDTO() {
var entities = this.entities.stream().map(it -> DPIEntityConfig.create().type(it)).toList();
var provider =
return
MaskingProviderConfig.create()
.type(SAP_DATA_PRIVACY_INTEGRATION)
.method(maskingMethod)
.entities(entities);
return MaskingModuleConfig.create().maskingProviders(provider);
}

@RequiredArgsConstructor(access = AccessLevel.PRIVATE)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package com.sap.ai.sdk.orchestration;

public interface MaskingConfig {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package com.sap.ai.sdk.orchestration;

import com.sap.ai.sdk.orchestration.client.model.ChatMessage;
import com.sap.ai.sdk.orchestration.client.model.FilteringConfig;
import com.sap.ai.sdk.orchestration.client.model.FilteringModuleConfig;
import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig;
import com.sap.ai.sdk.orchestration.client.model.MaskingModuleConfig;
import com.sap.ai.sdk.orchestration.client.model.ModuleConfigs;
import com.sap.ai.sdk.orchestration.client.model.TemplatingModuleConfig;
import io.vavr.control.Option;

import javax.annotation.Nonnull;
import java.util.ArrayList;
import java.util.List;

class ModuleConfigFactory {
@Nonnull
static ModuleConfigs toModuleConfigDTO(
@Nonnull final OrchestrationConfig<?> config, @Nonnull final List<ChatMessage> messages) {
LLMModuleConfig llm =
config
.getLlmConfig()
.getOrElseThrow(() -> new IllegalStateException("LLM module config is required"));

/*
* Currently, we have to merge the prompt into the template configuration.
* This works around the limitation that the template config isn't optional.
* This comes at the risk that the prompt unintentionally contains the templating pattern "{{? .. }}".
* In this case, the request will fail, since the templating module will try to resolve the parameter.
* To be fixed with https://github.tools.sap/AI/llm-orchestration/issues/662
*/
Option<TemplatingModuleConfig> maybeTemplate = config.getTemplate();
final var messagesWithPrompt = new ArrayList<>(messages);
maybeTemplate.map(TemplatingModuleConfig::getTemplate).forEach(messagesWithPrompt::addAll);

if (messagesWithPrompt.isEmpty()) {
throw new IllegalStateException(
"A prompt is required. Pass at least one message or configure the templating module.");
}
TemplatingModuleConfig template = TemplatingModuleConfig.create().template(messagesWithPrompt);
maybeTemplate.map(TemplatingModuleConfig::getDefaults).forEach(template::defaults);

ModuleConfigs dto =
ModuleConfigs.create().llmModuleConfig(llm).templatingModuleConfig(template);

config
.getMaskingConfig()
.filter(DpiMaskingConfig.class::isInstance)
.map(DpiMaskingConfig.class::cast)
.map(DpiMaskingConfig::toMaskingProviderDTO)
.map(it -> MaskingModuleConfig.create().maskingProviders(it))
.forEach(dto::maskingModuleConfig);

var maybeInputFilter = config.getInputContentFilter();
var maybeOutputFilter = config.getOutputContentFilter();

if (maybeInputFilter.isDefined() || maybeOutputFilter.isDefined()) {
var filter = FilteringModuleConfig.create();
maybeInputFilter
.filter(AzureContentFilter.class::isInstance)
.map(AzureContentFilter.class::cast)
.map(AzureContentFilter::toFilterConfigDTO)
.map(it -> FilteringConfig.create().filters(it)).forEach(filter::input);
maybeOutputFilter
.filter(AzureContentFilter.class::isInstance)
.map(AzureContentFilter.class::cast)
.map(AzureContentFilter::toFilterConfigDTO)
.map(it -> FilteringConfig.create().filters(it)).forEach(filter::output);
dto = dto.filteringModuleConfig(filter);
}

return dto;
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package com.sap.ai.sdk.orchestration;

import com.sap.ai.sdk.orchestration.client.model.FilterConfig;
import com.sap.ai.sdk.orchestration.client.model.LLMModuleConfig;
import com.sap.ai.sdk.orchestration.client.model.TemplatingModuleConfig;
import io.vavr.control.Option;
Expand All @@ -23,13 +22,13 @@ public interface OrchestrationConfig<T extends OrchestrationConfig<T>> {
Option<TemplatingModuleConfig> getTemplate();

@Nonnull
Option<DpiMaskingConfig> getMaskingConfig();
Option<MaskingConfig> getMaskingConfig();

@Nonnull
Option<FilterConfig> getInputContentFilter();
Option<ContentFilter> getInputContentFilter();

@Nonnull
Option<FilterConfig> getOutputContentFilter();
Option<ContentFilter> getOutputContentFilter();

@Nonnull
T withLlmConfig(@Nonnull final LLMModuleConfig llm);
Expand All @@ -38,13 +37,13 @@ public interface OrchestrationConfig<T extends OrchestrationConfig<T>> {
T withTemplate(@Nonnull final TemplatingModuleConfig template);

@Nonnull
T withMaskingConfig(@Nonnull final DpiMaskingConfig maskingConfig);
T withMaskingConfig(@Nonnull final MaskingConfig maskingConfig);

@Nonnull
T withInputContentFilter(@Nonnull final FilterConfig filter);
T withInputContentFilter(@Nonnull final ContentFilter filter);

@Nonnull
T withOutputContentFilter(@Nonnull final FilterConfig filter);
T withOutputContentFilter(@Nonnull final ContentFilter filter);

/**
* Copy the configuration into the given target configuration. The copy is
Expand All @@ -59,7 +58,7 @@ public interface OrchestrationConfig<T extends OrchestrationConfig<T>> {
*
* @param source The source configuration to copy from.
*/
default T copyFrom(@Nonnull final OrchestrationConfig<?> source) {
default T copyOrchestrationConfigurationFrom(@Nonnull final OrchestrationConfig<?> source) {
getLlmConfig().orElse(source::getLlmConfig).forEach(this::withLlmConfig);
getTemplate().orElse(source::getTemplate).forEach(this::withTemplate);
getMaskingConfig().orElse(source::getMaskingConfig).forEach(this::withMaskingConfig);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public OrchestrationPrompt(@Nonnull final Map<String, String> inputParams) {
@Nonnull
ModuleConfigs toModuleConfigDTO(@Nonnull final OrchestrationConfig<?> defaults) {
// duplicate the prompt config so it isn't modified, to make sure this prompt can be reused
var config = new DefaultOrchestrationConfig<>().copyFrom(this).copyFrom(defaults);
return DefaultOrchestrationConfig.toModuleConfigDTO(config, messages);
var config = new DefaultOrchestrationConfig<>().copyOrchestrationConfigurationFrom(this).copyOrchestrationConfigurationFrom(defaults);
return ModuleConfigFactory.toModuleConfigDTO(config, messages);
}
}
Loading

0 comments on commit b8414f2

Please sign in to comment.