Skip to content

Commit

Permalink
[FEATURE] Retrieve Documents when using AiService
Browse files Browse the repository at this point in the history
  • Loading branch information
KaisNeffati committed May 6, 2024
1 parent f7ac5f6 commit 7cfac82
Show file tree
Hide file tree
Showing 14 changed files with 295 additions and 124 deletions.
12 changes: 6 additions & 6 deletions docs/docs/tutorials/7-rag.md
Original file line number Diff line number Diff line change
Expand Up @@ -165,16 +165,16 @@ interface Assistant {
```
The `WithSources` class contains the response together with the information used to augment the user message:
* `response` : The response to the user's input.
* `augmentedMessage`: A wrapper for augmentation details.
* `usermessage`: The augmented user message that was sent to the LLM.
* `contents`: The list of documents used to enrich the user message including the `metadata` of each document.
* `retrievedContents`: The list of documents used to enrich the user message including the `metadata` of each document.

Attempting to use the `WithSources` class without specifying a generic type will result in an `IllegalArgumentException`. For instance:
```java
Using the `WithSources` class without specifying a generic type will lead to an `IllegalArgumentException`. For example:

WithSources chat(String userMessage); // Throw an IllegalArgumentException
```java
WithSources chat(String userMessage); // Throws an IllegalArgumentException
```

Additionally, building an `AiService` that uses `WithSources` without a `contentRetriever` (or another retrieval augmentor) also results in an `IllegalArgumentException`.

## RAG APIs
LangChain4j offers a rich set of APIs to make it easy for you to build custom RAG pipelines,
ranging from simple ones to advanced ones.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package dev.langchain4j.data.message;


import dev.langchain4j.rag.query.Metadata;
import lombok.Builder;
import lombok.Getter;

/**
 * Input of a retrieval-augmentation step: the user message to augment together with
 * the {@link Metadata} supplied alongside it.
 * <p>
 * Instances are created through the Lombok-generated builder
 * ({@code AugmentationRequest.builder()...build()}) and read through the
 * Lombok-generated getters ({@code getUserMessage()}, {@code getMetadata()}).
 * Both fields are {@code final}, so an instance cannot be modified after it is built.
 * Neither field is annotated for null-checking, so a field that is not set on the
 * builder is left {@code null}.
 */
@Getter
@Builder
public class AugmentationRequest {
/**
 * The user message to be augmented.
 */
private final UserMessage userMessage;

/**
 * Additional metadata related to the augmentation request, passed on to the
 * augmentor together with the user message.
 */
private final Metadata metadata;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package dev.langchain4j.data.message;

import dev.langchain4j.rag.content.Content;
import lombok.Builder;
import lombok.Getter;

import java.util.List;

/**
 * Output of a retrieval-augmentation step: the augmented user message and the contents
 * that were retrieved while producing it.
 * <p>
 * Instances are created through the Lombok-generated builder
 * ({@code AugmentationResult.builder()...build()}) and read through the
 * Lombok-generated getters ({@code getAugmentedUserMessage()},
 * {@code getRetrievedContents()}). Both fields are {@code final}.
 * Neither field is annotated for null-checking, so a field that is not set on the
 * builder is left {@code null}; callers should be prepared for that.
 */
@Getter
@Builder
public class AugmentationResult {
/**
 * The augmented user message after processing.
 */
private final UserMessage augmentedUserMessage;

/**
 * The list of contents retrieved during augmentation.
 * May be {@code null} when the producer of this result did not supply any list.
 */
private final List<Content> retrievedContents;
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package dev.langchain4j.rag;

import dev.langchain4j.data.message.AugmentedMessage;
import dev.langchain4j.data.message.AugmentationRequest;
import dev.langchain4j.data.message.AugmentationResult;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.rag.content.aggregator.ContentAggregator;
Expand Down Expand Up @@ -96,9 +97,8 @@
* @see DefaultQueryRouter
* @see DefaultContentAggregator
* @see DefaultContentInjector
* @see AugmentedMessage
*/
public class DefaultRetrievalAugmentor implements RetrievalAugmentor<AugmentedMessage> {
public class DefaultRetrievalAugmentor implements RetrievalAugmentor {

private static final Logger log = LoggerFactory.getLogger(DefaultRetrievalAugmentor.class);

Expand All @@ -122,7 +122,18 @@ public DefaultRetrievalAugmentor(QueryTransformer queryTransformer,
}

@Override
public AugmentedMessage augment(UserMessage userMessage, Metadata metadata) {
public UserMessage augment(UserMessage userMessage, Metadata metadata) {
    // Wrap both arguments in a request object and delegate to the request-based
    // overload, then unwrap only the augmented message from its result.
    AugmentationRequest request = AugmentationRequest.builder()
            .metadata(metadata)
            .userMessage(userMessage)
            .build();
    AugmentationResult result = augment(request);
    return result.getAugmentedUserMessage();
}

@Override
public AugmentationResult augment(AugmentationRequest augmentationRequest) {
UserMessage userMessage = augmentationRequest.getUserMessage();
Metadata metadata = augmentationRequest.getMetadata();

Query originalQuery = Query.from(userMessage.text(), metadata);

Expand Down Expand Up @@ -150,9 +161,9 @@ public AugmentedMessage augment(UserMessage userMessage, Metadata metadata) {
UserMessage augmentedUserMessage = contentInjector.inject(contents, userMessage);
log(augmentedUserMessage);

return AugmentedMessage.builder()
.userMessage(augmentedUserMessage)
.contents(contents)
return AugmentationResult.builder()
.augmentedUserMessage(augmentedUserMessage)
.retrievedContents(contents)
.build();
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package dev.langchain4j.rag;

import dev.langchain4j.Experimental;
import dev.langchain4j.data.message.AugmentationRequest;
import dev.langchain4j.data.message.AugmentationResult;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.rag.query.Metadata;

Expand All @@ -14,14 +16,29 @@
* @see DefaultRetrievalAugmentor
*/
@Experimental
public interface RetrievalAugmentor<T> {
public interface RetrievalAugmentor {

/**
* Augments the provided {@link UserMessage} with retrieved content.
*
* @param userMessage The {@link UserMessage} to be augmented.
* @param metadata The {@link Metadata} that may be useful or necessary for retrieval and augmentation.
* @return The augmented {@link UserMessage}.
* @deprecated This method is deprecated. Use {@link #augment(AugmentationRequest)} instead.
*/
T augment(UserMessage userMessage, Metadata metadata);
@Deprecated
UserMessage augment(UserMessage userMessage, Metadata metadata);

/**
 * Augments the user message carried by the given {@link AugmentationRequest}.
 * <p>
 * This default implementation unpacks the request, delegates to
 * {@link #augment(UserMessage, Metadata)} and wraps the returned message in an
 * {@link AugmentationResult}. Only the augmented user message is set on the result;
 * no retrieved contents are supplied here.
 *
 * @param augmentationRequest The {@link AugmentationRequest} containing the user message and metadata.
 * @return The {@link AugmentationResult} containing the augmented user message.
 */
default AugmentationResult augment(AugmentationRequest augmentationRequest) {
    UserMessage originalMessage = augmentationRequest.getUserMessage();
    Metadata requestMetadata = augmentationRequest.getMetadata();
    UserMessage augmentedUserMessage = augment(originalMessage, requestMetadata);
    return AugmentationResult.builder().augmentedUserMessage(augmentedUserMessage).build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ void should_augment_user_message(Executor executor) {

ContentInjector contentInjector = spy(new TestContentInjector());

DefaultRetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder()
RetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder()
.queryTransformer(queryTransformer)
.queryRouter(queryRouter)
.contentAggregator(contentAggregator)
Expand All @@ -67,7 +67,7 @@ void should_augment_user_message(Executor executor) {
Metadata metadata = Metadata.from(userMessage, null, null);

// when
UserMessage augmented = retrievalAugmentor.augment(userMessage, metadata).getUserMessage();
UserMessage augmented = retrievalAugmentor.augment(userMessage, metadata);

// then
assertThat(augmented.singleText()).isEqualTo(
Expand Down Expand Up @@ -126,7 +126,7 @@ void should_not_augment_when_router_does_not_return_retrievers(Executor executor
List<ContentRetriever> retrievers = emptyList();
QueryRouter queryRouter = spy(new TestQueryRouter(retrievers));

DefaultRetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder()
RetrievalAugmentor retrievalAugmentor = DefaultRetrievalAugmentor.builder()
.queryRouter(queryRouter)
.executor(executor)
.build();
Expand All @@ -136,7 +136,7 @@ void should_not_augment_when_router_does_not_return_retrievers(Executor executor
Metadata metadata = Metadata.from(userMessage, null, null);

// when
UserMessage augmentedUserMessage = retrievalAugmentor.augment(userMessage, metadata).getUserMessage();
UserMessage augmentedUserMessage = retrievalAugmentor.augment(userMessage, metadata);

// then
assertThat(augmentedUserMessage).isEqualTo(userMessage);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

import java.net.Proxy;
import java.time.Duration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
package dev.langchain4j.chain;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.AugmentedMessage;
import dev.langchain4j.data.message.AugmentationRequest;
import dev.langchain4j.data.message.AugmentationResult;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.rag.*;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.DefaultRetrievalAugmentor;
import dev.langchain4j.rag.RetrievalAugmentor;
import dev.langchain4j.rag.content.injector.DefaultContentInjector;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.query.Metadata;
import dev.langchain4j.retriever.Retriever;
import dev.langchain4j.service.AiServices;
Expand All @@ -31,7 +33,7 @@ public class ConversationalRetrievalChain implements Chain<String, String> {

private final ChatLanguageModel chatLanguageModel;
private final ChatMemory chatMemory;
private final RetrievalAugmentor<AugmentedMessage> retrievalAugmentor;
private final RetrievalAugmentor retrievalAugmentor;

public ConversationalRetrievalChain(ChatLanguageModel chatLanguageModel,
ChatMemory chatMemory,
Expand Down Expand Up @@ -78,7 +80,12 @@ public String execute(String query) {

UserMessage userMessage = UserMessage.from(query);
Metadata metadata = Metadata.from(userMessage, chatMemory.id(), chatMemory.messages());
userMessage = retrievalAugmentor.augment(userMessage, metadata).getUserMessage();
AugmentationRequest augmentationRequest = AugmentationRequest.builder()
.userMessage(userMessage)
.metadata(metadata)
.build();
AugmentationResult augmentationResult = retrievalAugmentor.augment(augmentationRequest);
userMessage = augmentationResult.getAugmentedUserMessage();
chatMemory.add(userMessage);

AiMessage aiMessage = chatLanguageModel.generate(chatMemory.messages()).content();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import dev.langchain4j.agent.tool.ToolExecutor;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AugmentedMessage;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.model.chat.ChatLanguageModel;
Expand Down Expand Up @@ -32,7 +31,7 @@ public class AiServiceContext {
public List<ToolSpecification> toolSpecifications;
public Map<String, ToolExecutor> toolExecutors;

public RetrievalAugmentor<AugmentedMessage> retrievalAugmentor;
public RetrievalAugmentor retrievalAugmentor;

public Function<Object, Optional<String>> systemMessageProvider = DEFAULT_MESSAGE_PROVIDER;

Expand Down
25 changes: 25 additions & 0 deletions langchain4j/src/main/java/dev/langchain4j/service/AiServices.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@
import dev.langchain4j.retriever.Retriever;
import dev.langchain4j.spi.services.AiServicesFactory;

import java.lang.reflect.AnnotatedType;
import java.lang.reflect.Method;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
Expand All @@ -30,6 +33,7 @@

import static dev.langchain4j.agent.tool.ToolSpecifications.toolSpecificationFrom;
import static dev.langchain4j.exception.IllegalConfigurationException.illegalConfiguration;
import static dev.langchain4j.internal.Exceptions.illegalArgument;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
import static java.util.stream.Collectors.toList;
Expand Down Expand Up @@ -404,6 +408,27 @@ protected void performBasicValidation() {
}
}

/**
 * Checks a service method that is declared to return {@code WithSources}.
 * <p>
 * Methods with any other return type are accepted without further checks. For a
 * {@code WithSources} method, the declared return type must carry a type argument
 * (i.e. be a {@link ParameterizedType}), and a retrieval augmentor must be present
 * in the service context.
 *
 * @param method The method to be validated.
 * @throws IllegalArgumentException if {@code WithSources} is declared without a generic type,
 *                                  or if no retrieval augmentor is configured.
 */
protected void withSourcesValidation(Method method) {
    boolean returnsWithSources = method.getReturnType() == WithSources.class;
    if (!returnsWithSources) {
        return;
    }

    // The raw class is always WithSources here; the type as written in the signature
    // tells whether a type argument was given.
    Type declaredReturnType = method.getAnnotatedReturnType().getType();
    boolean hasTypeArgument = declaredReturnType instanceof ParameterizedType;
    if (!hasTypeArgument) {
        throw illegalArgument("Method '%s' must return WithSources with a defined generic class.", method.getName());
    }

    if (context.retrievalAugmentor == null) {
        throw illegalArgument("Method '%s' requires a retrieval augmentor to use WithSources.", method.getName());
    }
}

public static List<ChatMessage> removeToolMessages(List<ChatMessage> messages) {
return messages.stream()
.filter(it -> !(it instanceof ToolExecutionResultMessage))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ public T build() {
throw illegalConfiguration("The @Moderate annotation is present, but the moderationModel is not set up. " +
"Please ensure a valid moderationModel is configured before using the @Moderate annotation.");
}
withSourcesValidation(method);
}

Object proxyInstance = Proxy.newProxyInstance(
Expand All @@ -88,18 +89,32 @@ public Object invoke(Object proxy, Method method, Object[] args) throws Exceptio

Optional<SystemMessage> systemMessage = prepareSystemMessage(memoryId, method, args);
UserMessage userMessage = prepareUserMessage(method, args);
AugmentedMessage augmentedMessage = null;
AugmentationResult augmentationResult = null;
if (context.retrievalAugmentor != null) {
List<ChatMessage> chatMemory = context.hasChatMemory()
? context.chatMemory(memoryId).messages()
: null;
Metadata metadata = Metadata.from(userMessage, memoryId, chatMemory);
augmentedMessage = context.retrievalAugmentor.augment(userMessage, metadata);
userMessage = augmentedMessage.getUserMessage();
AugmentationRequest augmentationRequest = AugmentationRequest.builder()
.userMessage(userMessage)
.metadata(metadata)
.build();
augmentationResult = context.retrievalAugmentor.augment(augmentationRequest);
userMessage = augmentationResult.getAugmentedUserMessage();
}

// TODO give user ability to provide custom OutputParser
Class<?> returnType = method.getReturnType();
boolean isWithSources = false;
if (returnType == WithSources.class) {
AnnotatedType annotatedReturnType = method.getAnnotatedReturnType();
ParameterizedType type = (ParameterizedType) annotatedReturnType.getType();
Type[] typeArguments = type.getActualTypeArguments();
for (Type typeArg : typeArguments) {
returnType = Class.forName(typeArg.getTypeName());
}
isWithSources = true;
}
String outputFormatInstructions = outputFormatInstructions(returnType);
userMessage = UserMessage.from(userMessage.text() + outputFormatInstructions);

Expand Down Expand Up @@ -175,26 +190,13 @@ public Object invoke(Object proxy, Method method, Object[] args) throws Exceptio
}

response = Response.from(response.content(), tokenUsageAccumulator, response.finishReason());
if (returnType != WithSources.class) {
return parse(response, returnType);
}
AnnotatedType annotatedReturnType = method.getAnnotatedReturnType();

Type withSourcesAnnotatedType = annotatedReturnType.getType();
if (withSourcesAnnotatedType instanceof ParameterizedType) {
ParameterizedType type = (ParameterizedType) withSourcesAnnotatedType;
Type[] typeArguments = type.getActualTypeArguments();
for (Type typeArg : typeArguments) {
returnType = Class.forName(typeArg.getTypeName());
}
} else {
throw illegalArgument("WithSources needs to have a generic class defined for the following method : %s", method.getName());
}
Object parsedResponse = parse(response, returnType);
return WithSources.builder()
return isWithSources ? WithSources.builder()
.response(parsedResponse)
.augmentedMessage(augmentedMessage)
.build();
.retrievedContents(Optional.ofNullable(augmentationResult)
.map(AugmentationResult::getRetrievedContents)
.orElse(Collections.emptyList()))
.build() : parsedResponse;
}

private Future<Moderation> triggerModerationIfNeeded(Method method, List<ChatMessage> messages) {
Expand Down

0 comments on commit 7cfac82

Please sign in to comment.