Skip to content

Commit

Permalink
[FEATURE] AI Services: Option to dynamically select system message de…
Browse files Browse the repository at this point in the history
…pending on the user message #956
  • Loading branch information
KaisNeffati committed Apr 23, 2024
1 parent 608f55b commit bc217c6
Show file tree
Hide file tree
Showing 11 changed files with 1,738 additions and 1 deletion.
78 changes: 78 additions & 0 deletions docs/docs/tutorials/5-ai-services.md
Expand Up @@ -129,6 +129,84 @@ Friend friend = AiServices.builder(Friend.class)
```
As you can see, you can provide different system messages based on a chat memory ID (user or conversation).


## @RegisterSystemSpecs

Now, let's explore a more complex scenario. In this case, we'll entrust the AI with the task of selecting the appropriate SystemMessage.

This is accomplished by associating multiple system message specifications, denoted by the `@SystemSpec` annotations, with a single `@RegisterSystemSpecs` annotation.

Here is an example :
```java
interface Professor {

// Chat method that takes a user message and returns an AI-generated response based on the specified system specifications.
// SystemSpec annotations define the context in which the AI should generate responses for physics and math questions.
@RegisterSystemSpecs({
@SystemSpec(name = "physics", description = "Good for answering questions about physics", template = {
"You are a very smart physics professor. You are great at answering questions about physics in a concise and easy to understand manner. When you don't know the answer to a question you admit that you don't know.\n" +
"\n" +
"Here is a question:"
}),
@SystemSpec(name = "math", description = "Good for answering math questions", template = {
"You are a very good mathematician. You are great at answering math questions. You are so good because you are able to break down hard problems into their component parts, answer the component parts, and then put them together to answer the broader question.\n" +
"\n" +
"Here is a question:"
})
})
String chat(String userMessage);
}

// Create a Professor instance using AiServices
Professor professor = AiServices.create(Professor.class, model);

// Invoke the chat method with a specific question
String answer = Professor.chat("What is the speed of light?");
// Answer: The speed of light in a vacuum is approximately 299,792 kilometers per second (km/s)...
```
In this example, the `@RegisterSystemSpecs` annotation is used in conjunction with multiple `@SystemSpec` annotations. Here’s a breakdown of the attributes for each `@SystemSpec`:

- `name`: This attribute assigns a unique identifier to the system message and is required.
- `description`: This required attribute offers a concise explanation of the system message's intent.
- `template`: This mandatory attribute consists of an array of strings that outline the template for creating system messages.
- `delimiter`: This optional attribute determines the string used to concatenate elements within the template, with a default setting of a newline ("\n").

Internally, a request is made to the AI model to determine the most suitable system to address the user's query. If none of the specified system specifications align with the user's question, a default message is generated. This message prompts for the specification of a relevant expert and is sent to the language model along with the user's message.

### System Specification Provider
Registering system specifications can also be defined dynamically with the system spec provider:
```java

// Initialize the list of SystemSpecs
List<SystemSpec> systemSpecs = new ArrayList<>();

// Add a SystemSpec for physics
systemSpecs.add(
SystemSpec.builder()
.name("physics")
.description("Good for answering questions about physics")
.template(...) // Add the template details here
.build()
);

// Add a SystemSpec for math
systemSpecs.add(
SystemSpec.builder()
.name("math")
.description("Good for answering math questions")
.template(...) // Add the template details here
.build()
);

// Configure and build the Professor instance with a chat language model and a system spec provider
Professor professor = AiServices.builder(Professor.class)
.chatLanguageModel(model)
.systemSpecProvider(chatMemoryId -> systemSpecs)
.build();

```
As you can see, you can provide different system specs based on a chat memory ID (user or conversation).

## @UserMessage

Now, let's assume the model we use does not support system messages,
Expand Down
Expand Up @@ -8,6 +8,8 @@

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import static dev.langchain4j.internal.Exceptions.runtime;
import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;
Expand All @@ -23,16 +25,25 @@ public class ChatModelMock implements ChatLanguageModel {

private final String staticResponse;
private final RuntimeException exception;
private final Map<String, String> staticResponseMap;
private final List<List<ChatMessage>> requests = synchronizedList(new ArrayList<>());

public ChatModelMock(String staticResponse) {
this.staticResponse = ensureNotBlank(staticResponse, "staticResponse");
this.exception = null;
this.staticResponseMap = null;
}

public ChatModelMock(RuntimeException exception) {
this.staticResponse = null;
this.exception = ensureNotNull(exception, "exception");
this.staticResponseMap = null;
}

public ChatModelMock(Map<String, String> staticResponseMap) {
this.staticResponse = null;
this.exception = null;
this.staticResponseMap = ensureNotNull(staticResponseMap, "staticResponseMap");
}

@Override
Expand All @@ -43,6 +54,12 @@ public Response<AiMessage> generate(List<ChatMessage> messages) {
throw exception;
}

if (staticResponseMap != null) {
String key = messages.stream().map(ChatMessage::text)
.collect(Collectors.joining("\n"));
return Response.from(AiMessage.from(staticResponseMap.get(key)));
}

return Response.from(AiMessage.from(staticResponse));
}

Expand All @@ -68,6 +85,10 @@ public static ChatModelMock thatAlwaysResponds(String response) {
return new ChatModelMock(response);
}

public static ChatModelMock thatAlwaysRespondsMap(Map<String, String> responseMap) {
return new ChatModelMock(responseMap);
}

public static ChatModelMock thatAlwaysThrowsException() {
return thatAlwaysThrowsExceptionWithMessage("Something went wrong, but this is an expected exception");
}
Expand Down
61 changes: 61 additions & 0 deletions langchain4j/src/main/java/dev/langchain4j/model/SystemSpec.java
@@ -0,0 +1,61 @@
package dev.langchain4j.model;

import lombok.Builder;
import lombok.Getter;

import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;

/**
* The {@code SystemSpec} class encapsulates the specifications of a system message.
* It holds essential details that define the system's operational blueprint, including its name,
* description, and a structured message template. This class is designed to be immutable,
* ensuring that once an instance is created, its state cannot be altered.
*
* <p>Attributes:</p>
* <ul>
* <li>{@code name} - Represents the name of the system message, essential for identification.</li>
* <li>{@code description} - Provides a brief description of what the system message is designed to do.</li>
* <li>{@code template} - An array of strings that outlines the template for system messages.</li>
* <li>{@code delimiter} - A string used to separate elements in the template, with a default value of newline ("\n").</li>
* </ul>
*
* <p>Instances of this class are created using a builder pattern, allowing for a flexible and clear construction process.</p>
*
* <p>Usage example:</p>
* <pre>
* SystemSpec spec = SystemSpec.builder()
* .name("ExampleSystem")
* .description("This system handles data processing.")
* .template(new String[]{"Initiate", "Process", "Terminate"})
* .delimiter(",")
* .build();
* </pre>
*
*/
@Getter
public class SystemSpec {
private final String name;
private final String description;
private final String[] template;
private final String delimiter;


/**
* Constructs an instance of {@code SystemSpec} using the provided parameters.
* This constructor enforces non-null values for name, description, and template through the {@code ensureNotNull} method.
* If the {@code delimiter} is not provided, it defaults to a newline ("\n").
*
* @param name the name of the system, must not be null.
* @param description a brief description of the system, must not be null.
* @param template an array representing the structured message or operation template, must not be null.
* @param delimiter the delimiter used to separate template elements, defaults to newline if null.
* @throws NullPointerException if name, description, or template is null.
*/
@Builder
SystemSpec(String name, String description, String[] template, String delimiter) {
this.name = ensureNotNull(name, "name");
this.description = ensureNotNull(description, "description");
this.template = ensureNotNull(template, "template");
this.delimiter = delimiter != null ? delimiter : "\n";
}
}
Expand Up @@ -4,6 +4,7 @@
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.model.SystemSpec;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.moderation.ModerationModel;
Expand All @@ -17,6 +18,7 @@
public class AiServiceContext {

private static final Function<Object, Optional<String>> DEFAULT_MESSAGE_PROVIDER = x -> Optional.empty();
private static final Function<Object, Optional<List<SystemSpec>>> DEFAULT_SPEC_PROVIDER = x -> Optional.empty();

public final Class<?> aiServiceClass;

Expand All @@ -35,6 +37,12 @@ public class AiServiceContext {

public Function<Object, Optional<String>> systemMessageProvider = DEFAULT_MESSAGE_PROVIDER;

public boolean isSystemMessageProviderEnabled = false;

public Function<Object, Optional<List<SystemSpec>>> systemSpecProvider = DEFAULT_SPEC_PROVIDER;

public boolean isSystemSpecProviderEnabled = false;

public AiServiceContext(Class<?> aiServiceClass) {
this.aiServiceClass = aiServiceClass;
}
Expand Down
22 changes: 22 additions & 0 deletions langchain4j/src/main/java/dev/langchain4j/service/AiServices.java
Expand Up @@ -9,6 +9,7 @@
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.model.SystemSpec;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.input.structured.StructuredPrompt;
Expand Down Expand Up @@ -222,6 +223,27 @@ public AiServices<T> streamingChatLanguageModel(StreamingChatLanguageModel strea
*/
public AiServices<T> systemMessageProvider(Function<Object, String> systemMessageProvider) {
context.systemMessageProvider = systemMessageProvider.andThen(Optional::ofNullable);
context.isSystemMessageProviderEnabled = true;
return this;
}

/**
* Configures the system specification provider function within the AI services context.
* This provider selects a system specification from a provided list, based on the context of the user message.
* <br>
* When both {@code @RegisterSystemSpec} and the system spec provider are configured,
* {@code @RegisterSystemSpec} takes precedence.
*
* @param systemSpecProvider the system specification provider function
* that takes an object as input and returns a list
* of system specifications. The function should
* handle null input and return an optional list
* of system specifications.
* @return builder
*/
public AiServices<T> systemSpecProvider(Function<Object, List<SystemSpec>> systemSpecProvider) {
context.systemSpecProvider = systemSpecProvider.andThen(Optional::ofNullable);
context.isSystemSpecProviderEnabled = true;
return this;
}

Expand Down
Expand Up @@ -14,19 +14,22 @@
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.rag.query.Metadata;
import org.jetbrains.annotations.NotNull;

import java.io.InputStream;
import java.lang.reflect.*;
import java.util.*;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.stream.Collectors;

import static dev.langchain4j.exception.IllegalConfigurationException.illegalConfiguration;
import static dev.langchain4j.internal.Exceptions.illegalArgument;
import static dev.langchain4j.internal.Exceptions.runtime;
import static dev.langchain4j.service.ServiceOutputParser.outputFormatInstructions;
import static dev.langchain4j.service.ServiceOutputParser.parse;
import static dev.langchain4j.service.SystemSpecService.fetchSystemSpecUsingAI;

class DefaultAiServices<T> extends AiServices<T> {

Expand Down Expand Up @@ -86,8 +89,14 @@ public Object invoke(Object proxy, Method method, Object[] args) throws Exceptio

Object memoryId = findMemoryId(method, args).orElse(DEFAULT);

Optional<SystemMessage> systemMessage = prepareSystemMessage(memoryId, method, args);

UserMessage userMessage = prepareUserMessage(method, args);
Optional<SystemMessage> systemMessage = Optional.empty();
if (method.isAnnotationPresent(dev.langchain4j.service.SystemMessage.class) || context.isSystemMessageProviderEnabled) {
systemMessage = prepareSystemMessage(memoryId, method, args);
} else if (method.isAnnotationPresent(RegisterSystemSpecs.class) || context.isSystemSpecProviderEnabled){
systemMessage = prepareSystemSpecs(memoryId, method, args, userMessage);
}

if (context.retrievalAugmentor != null) {
List<ChatMessage> chatMemory = context.hasChatMemory()
Expand Down Expand Up @@ -197,6 +206,13 @@ private Optional<SystemMessage> prepareSystemMessage(Object memoryId, Method met
.toSystemMessage());
}

private Optional<SystemMessage> prepareSystemSpecs(Object memoryId, Method method, Object[] args, UserMessage userMessage) {
return findDynamicSystemMessageTemplate(memoryId, method, userMessage)
.map(systemMessageTemplate -> PromptTemplate.from(systemMessageTemplate)
.apply(findTemplateVariables(systemMessageTemplate, method, args))
.toSystemMessage());
}

private Optional<String> findSystemMessageTemplate(Object memoryId, Method method) {
dev.langchain4j.service.SystemMessage annotation = method.getAnnotation(dev.langchain4j.service.SystemMessage.class);
if (annotation != null) {
Expand All @@ -206,6 +222,34 @@ private Optional<String> findSystemMessageTemplate(Object memoryId, Method metho
return context.systemMessageProvider.apply(memoryId);
}

private Optional<String> findDynamicSystemMessageTemplate(Object memoryId, Method method, UserMessage userMessage) {
List<dev.langchain4j.model.SystemSpec> systemSpecs;
RegisterSystemSpecs annotation = method.getAnnotation(RegisterSystemSpecs.class);

if (annotation != null) {
SystemSpec[] systemSpecAnnotations = annotation.value();
systemSpecs = Arrays.stream(systemSpecAnnotations)
.map(systemSpec -> dev.langchain4j.model.SystemSpec.builder()
.name(systemSpec.name())
.description(systemSpec.description())
.template(systemSpec.template())
.delimiter(systemSpec.delimiter())
.build()).collect(Collectors.toList());
} else {
systemSpecs = context.systemSpecProvider
.apply(memoryId)
.orElse(Collections.emptyList());
}

return Optional.of(
getSpecsTemplate(
method,
context,
systemSpecs,
userMessage)
);
}

private static Map<String, Object> findTemplateVariables(String template, Method method, Object[] args) {
Parameter[] parameters = method.getParameters();

Expand Down Expand Up @@ -337,6 +381,23 @@ private static String getTemplate(Method method, String type, String resource, S
return messageTemplate;
}

@NotNull
private static String getSpecsTemplate(Method method, AiServiceContext context, List<dev.langchain4j.model.SystemSpec> systemSpecs, UserMessage userMessage) {
dev.langchain4j.model.SystemSpec systemSpecResult = fetchSystemSpecUsingAI(
userMessage.singleText(),
systemSpecs,
context.chatModel
);

return getTemplate(
method,
"System",
"",
systemSpecResult.getTemplate(),
systemSpecResult.getDelimiter()
);
}

private static String getResourceText(Class<?> clazz, String name) {
return getText(clazz.getResourceAsStream(name));
}
Expand Down

0 comments on commit bc217c6

Please sign in to comment.