A Look at RAG in Spring AI

This article takes a look at RAG (Retrieval Augmented Generation) in Spring AI.

Sequential RAG Flows

Naive RAG

Advisor retrievalAugmentationAdvisor = RetrievalAugmentationAdvisor.builder()
        .documentRetriever(VectorStoreDocumentRetriever.builder()
                .similarityThreshold(0.50)
                .vectorStore(vectorStore)
                .build())
        .queryAugmenter(ContextualQueryAugmenter.builder()
                .allowEmptyContext(true)
                .build())
        .build();

String answer = chatClient.prompt()
        .advisors(retrievalAugmentationAdvisor)
        .user(question)
        .call()
        .content();

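By default (allowEmptyContext is false), if no relevant documents are retrieved, the advisor instructs the model not to answer the user query. Setting allowEmptyContext to true allows an empty context, so the model still answers the question even when nothing is retrieved.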

Advanced RAG

Advisor retrievalAugmentationAdvisor = RetrievalAugmentationAdvisor.builder()
        .queryTransformers(RewriteQueryTransformer.builder()
                .chatClientBuilder(chatClientBuilder.build().mutate())
                .build())
        .documentRetriever(VectorStoreDocumentRetriever.builder()
                .similarityThreshold(0.50)
                .vectorStore(vectorStore)
                .build())
        .build();

String answer = chatClient.prompt()
        .advisors(retrievalAugmentationAdvisor)
        .user(question)
        .call()
        .content();

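In Advanced RAG you can configure queryTransformers to transform the user query before retrieval; the example above rewrites it with RewriteQueryTransformer.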

Modular RAG

Inspired by the paper "Modular RAG: Transforming RAG Systems into LEGO-like Reconfigurable Frameworks", Spring AI implements Modular RAG, which is divided into the following stages: Pre-Retrieval, Retrieval, Post-Retrieval, and Generation.

Pre-Retrieval

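Augments and transforms the user input so that retrieval works more effectively, addressing problems such as poorly formed queries, ambiguous query semantics, or unsupported languages.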

1. QueryAugmenter (query augmentation)

Augments the user query with additional contextual data, providing the LLM with the context it needs to answer the question.

  • ContextualQueryAugmenter augments the query with the content of the provided documents
QueryAugmenter augmenter = ContextualQueryAugmenter.builder()
        .allowEmptyContext(false)
        .build();
Query augmentedQuery = augmenter.augment(query, documents);
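To make the empty-context behavior concrete, here is a minimal, stdlib-only sketch of what a contextual augmenter does. `SimpleDoc` and `SimpleContextualAugmenter` are hypothetical names, and the prompt wording is an assumption, not Spring AI's built-in template. The sketch assumes that with an empty context the augmenter either returns the query unchanged (empty context allowed) or replaces it with an instruction not to answer (empty context not allowed).

```java
import java.util.List;
import java.util.stream.Collectors;

// Hypothetical stand-in for a retrieved document (not Spring AI's Document class).
record SimpleDoc(String id, String text) {}

// Hypothetical, simplified contextual augmenter; the prompt wording is an assumption.
final class SimpleContextualAugmenter {

    private final boolean allowEmptyContext;

    SimpleContextualAugmenter(boolean allowEmptyContext) {
        this.allowEmptyContext = allowEmptyContext;
    }

    String augment(String userQuery, List<SimpleDoc> documents) {
        if (documents.isEmpty()) {
            // Empty context allowed: pass the query through unchanged.
            // Otherwise: replace it with an instruction not to answer.
            return allowEmptyContext
                    ? userQuery
                    : "The user query is outside your knowledge base. "
                            + "Politely tell the user that you cannot answer it.";
        }
        // Join the retrieved texts into one context block, separated by blank lines.
        String context = documents.stream()
                .map(SimpleDoc::text)
                .collect(Collectors.joining("\n\n"));
        return "Context information is below.\n---------------------\n"
                + context
                + "\n---------------------\n"
                + "Answer the query using only the context above.\n"
                + "Query: " + userQuery;
    }
}
```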

2. QueryTransformer (query transformation)

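User input is often incomplete and carries little key information, which makes it hard for the LLM to understand and answer. The user query therefore needs to be rewritten, either through prompt tuning or with an LLM.
When using a QueryTransformer, it is recommended to configure a low temperature (for example 0.0) so that the results are more deterministic and accurate.
There are three implementations: CompressionQueryTransformer, RewriteQueryTransformer, and TranslationQueryTransformer.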

  • CompressionQueryTransformer uses an LLM to compress the conversation history and the follow-up query into a standalone query
Query query = Query.builder()
        .text("And what is its second largest city?")
        .history(new UserMessage("What is the capital of Denmark?"),
                new AssistantMessage("Copenhagen is the capital of Denmark."))
        .build();

QueryTransformer queryTransformer = CompressionQueryTransformer.builder()
        .chatClientBuilder(chatClientBuilder)
        .build();

Query transformedQuery = queryTransformer.transform(query);
  • RewriteQueryTransformer uses an LLM to rewrite the user query so that it works better for retrieval
Query query = new Query("I'm studying machine learning. What is an LLM?");

QueryTransformer queryTransformer = RewriteQueryTransformer.builder()
        .chatClientBuilder(chatClientBuilder)
        .build();

Query transformedQuery = queryTransformer.transform(query);
  • TranslationQueryTransformer uses an LLM to translate the query into a target language
Query query = new Query("Hvad er Danmarks hovedstad?");

QueryTransformer queryTransformer = TranslationQueryTransformer.builder()
        .chatClientBuilder(chatClientBuilder)
        .targetLanguage("english")
        .build();

Query transformedQuery = queryTransformer.transform(query);

3. QueryExpander (query expansion)

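Expands the user query into multiple semantically diverse variants to capture different perspectives, which helps retrieve additional context and increases the chance of finding relevant results.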

  • MultiQueryExpander uses an LLM to expand the query
MultiQueryExpander queryExpander = MultiQueryExpander.builder()
    .chatClientBuilder(chatClientBuilder)
    .numberOfQueries(3)
    .includeOriginal(false) // the original query is included by default; false leaves it out
    .build();
List<Query> queries = queryExpander.expand(new Query("How to run a Spring Boot app?"));
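The expansion step can be pictured with the plain-Java sketch below. The LLM call is replaced by a hypothetical `variantGenerator` function, so `SimpleMultiQueryExpander` only illustrates the `numberOfQueries`/`includeOriginal` idea under that assumption and is not Spring AI's implementation; putting the original query first is this sketch's own choice.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.BiFunction;

// Hypothetical, simplified expander; variantGenerator stands in for the LLM call.
final class SimpleMultiQueryExpander {

    private final BiFunction<String, Integer, List<String>> variantGenerator;
    private final int numberOfQueries;
    private final boolean includeOriginal;

    SimpleMultiQueryExpander(BiFunction<String, Integer, List<String>> variantGenerator,
                             int numberOfQueries, boolean includeOriginal) {
        this.variantGenerator = variantGenerator;
        this.numberOfQueries = numberOfQueries;
        this.includeOriginal = includeOriginal;
    }

    List<String> expand(String query) {
        List<String> result = new ArrayList<>();
        if (includeOriginal) {
            result.add(query); // keep the original query as the first entry
        }
        // Ask for N variants and keep at most N, in case the generator returns more.
        List<String> variants = variantGenerator.apply(query, numberOfQueries);
        result.addAll(variants.subList(0, Math.min(numberOfQueries, variants.size())));
        return result;
    }
}
```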

Retrieval

Responsible for querying data systems such as vector stores and retrieving the Documents most relevant to the user query.

1. DocumentRetriever (retriever)

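Retrieves documents for the queries produced by the QueryExpander from different data sources, such as search engines, vector stores, databases, or knowledge graphs. Its main implementations are VectorStoreDocumentRetriever and WebSearchRetriever.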

  • VectorStoreDocumentRetriever
DocumentRetriever retriever = VectorStoreDocumentRetriever.builder()
    .vectorStore(vectorStore)
    .similarityThreshold(0.73)
    .topK(5)
    .filterExpression(new FilterExpressionBuilder()
        .eq("genre", "fairytale")
        .build())
    .build();
List<Document> documents = retriever.retrieve(new Query("What is the main character of the story?"));
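How `similarityThreshold`, `topK` and `filterExpression` combine can be sketched over an in-memory list of embeddings. This is a hypothetical illustration (`EmbeddedDoc`, `ScoredDoc` and `InMemoryRetriever` are made-up names) that assumes cosine similarity and a single equality filter; real vector stores compute scores and apply filters in their own ways.

```java
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

// Hypothetical in-memory document: an id, a precomputed embedding and metadata.
record EmbeddedDoc(String id, double[] embedding, Map<String, Object> metadata) {}

// Hypothetical retrieval result with its similarity score.
record ScoredDoc(String id, double score) {}

final class InMemoryRetriever {

    // Cosine similarity; assumes equal-length, non-zero vectors.
    static double cosine(double[] a, double[] b) {
        double dot = 0, na = 0, nb = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            na += a[i] * a[i];
            nb += b[i] * b[i];
        }
        return dot / (Math.sqrt(na) * Math.sqrt(nb));
    }

    // Metadata equality filter, then similarity threshold, then the topK best scores.
    static List<ScoredDoc> retrieve(List<EmbeddedDoc> docs, double[] queryEmbedding,
                                    double similarityThreshold, int topK,
                                    String filterKey, Object filterValue) {
        return docs.stream()
                .filter(d -> filterValue.equals(d.metadata().get(filterKey)))
                .map(d -> new ScoredDoc(d.id(), cosine(d.embedding(), queryEmbedding)))
                .filter(s -> s.score() >= similarityThreshold)
                .sorted((x, y) -> Double.compare(y.score(), x.score())) // highest first
                .limit(topK)
                .collect(Collectors.toList());
    }
}
```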

2. DocumentJoiner

Combines the Documents retrieved for multiple queries and from multiple data sources into a single collection of Documents; its implementation is ConcatenationDocumentJoiner.

  • ConcatenationDocumentJoiner
Map<Query, List<List<Document>>> documentsForQuery = ...
DocumentJoiner documentJoiner = new ConcatenationDocumentJoiner();
List<Document> documents = documentJoiner.join(documentsForQuery);
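A rough plain-Java equivalent of concatenation-based joining, assuming documents are identified by an id and that the first occurrence of a duplicate wins. `JoinDoc` and `SimpleConcatenationJoiner` are hypothetical, and the query keys are plain strings here instead of `Query` objects.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical document type (not Spring AI's Document).
record JoinDoc(String id, String text) {}

final class SimpleConcatenationJoiner {

    // Flattens query -> (one list per data source) into a single list,
    // keeping the first occurrence of each document id (assumed dedup rule).
    static List<JoinDoc> join(Map<String, List<List<JoinDoc>>> documentsForQuery) {
        Map<String, JoinDoc> unique = new LinkedHashMap<>(); // preserves insertion order
        for (List<List<JoinDoc>> perSource : documentsForQuery.values()) {
            for (List<JoinDoc> docs : perSource) {
                for (JoinDoc doc : docs) {
                    unique.putIfAbsent(doc.id(), doc);
                }
            }
        }
        return new ArrayList<>(unique.values());
    }
}
```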

Post-Retrieval

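Processes the retrieved Documents to get the best possible output, addressing problems such as the model losing information in the middle of a long context ("lost in the middle") and context length limits.

  1. DocumentRanker: sorts and ranks Documents by their relevance to the user query;
  2. DocumentSelector: removes irrelevant or redundant documents from the retrieved list;
  3. DocumentCompressor: compresses each Document to reduce noise and redundancy in the retrieved information.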
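Since the three post-retrieval steps are only described in words, the sketch below shows one naive way each could behave on plain strings. The scoring (counting distinct query terms a document shares with the query) and every name in it are hypothetical assumptions for illustration; a real ranker would likely use a stronger relevance signal.

```java
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

// Hypothetical post-retrieval steps; not Spring AI interfaces.
final class PostRetrievalSketch {

    static Set<String> terms(String text) {
        return Arrays.stream(text.toLowerCase().split("\\W+"))
                .filter(t -> !t.isEmpty())
                .collect(Collectors.toSet());
    }

    // Relevance score: number of distinct query terms that also appear in the document.
    static long overlap(String query, String doc) {
        Set<String> q = terms(query);
        return terms(doc).stream().filter(q::contains).count();
    }

    // Ranker: most relevant first (stable sort keeps ties in input order).
    static List<String> rank(String query, List<String> docs) {
        return docs.stream()
                .sorted(Comparator.comparingLong((String d) -> overlap(query, d)).reversed())
                .collect(Collectors.toList());
    }

    // Selector: drop irrelevant documents (no overlap) and exact duplicates.
    static List<String> select(String query, List<String> docs) {
        return docs.stream()
                .filter(d -> overlap(query, d) > 0)
                .distinct()
                .collect(Collectors.toList());
    }

    // Compressor: keep only the sentences that mention at least one query term.
    static String compress(String query, String doc) {
        Set<String> q = terms(query);
        return Arrays.stream(doc.split("(?<=[.!?])\\s+"))
                .filter(s -> terms(s).stream().anyMatch(q::contains))
                .collect(Collectors.joining(" "));
    }
}
```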

Generation

Generates the LLM output for the user query.

Source code

org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java

    public static final class Builder {

        private List<QueryTransformer> queryTransformers;

        private QueryExpander queryExpander;

        private DocumentRetriever documentRetriever;

        private DocumentJoiner documentJoiner;

        private QueryAugmenter queryAugmenter;

        private TaskExecutor taskExecutor;

        private Scheduler scheduler;

        private Integer order;

        private Builder() {
        }

        //......
    }   

The Builder of RetrievalAugmentationAdvisor lets you configure the Pre-Retrieval components (queryAugmenter, queryTransformers, queryExpander) and the Retrieval components (documentRetriever, documentJoiner).
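The order in which these Builder components are applied can be modeled with plain functions. In this hypothetical sketch (`RagPipelineSketch` is not a Spring AI class), the query transformers run in sequence, the result is expanded, documents are retrieved for each expanded query, the results are joined with duplicates removed, and the augmenter combines the user query with the joined documents. Passing the original user text (rather than the transformed query) to the augmenter, and retrieving sequentially instead of through the `taskExecutor`, are simplifying assumptions.

```java
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.UnaryOperator;

// Hypothetical model of the component order; strings stand in for Query and Document.
final class RagPipelineSketch {

    private final List<UnaryOperator<String>> queryTransformers;
    private final Function<String, List<String>> queryExpander;
    private final Function<String, List<String>> documentRetriever;
    private final BiFunction<String, List<String>, String> queryAugmenter;

    RagPipelineSketch(List<UnaryOperator<String>> queryTransformers,
                      Function<String, List<String>> queryExpander,
                      Function<String, List<String>> documentRetriever,
                      BiFunction<String, List<String>, String> queryAugmenter) {
        this.queryTransformers = queryTransformers;
        this.queryExpander = queryExpander;
        this.documentRetriever = documentRetriever;
        this.queryAugmenter = queryAugmenter;
    }

    String buildPrompt(String userQuery) {
        // 1. Pre-Retrieval: each transformer receives the previous transformer's output.
        String transformed = userQuery;
        for (UnaryOperator<String> transformer : queryTransformers) {
            transformed = transformer.apply(transformed);
        }
        // 2. Pre-Retrieval: expand into one or more queries.
        List<String> queries = queryExpander.apply(transformed);
        // 3 + 4. Retrieval: retrieve for every query, then join, dropping duplicates
        // while keeping first-seen order.
        Set<String> joined = new LinkedHashSet<>();
        for (String query : queries) {
            joined.addAll(documentRetriever.apply(query));
        }
        // 5. Augment the original user text with the joined documents (assumption).
        return queryAugmenter.apply(userQuery, new ArrayList<>(joined));
    }
}
```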

Examples

ModuleRAGBasicController

@RestController
@RequestMapping("/module-rag")
public class ModuleRAGBasicController {

    private final ChatClient chatClient;
    private final RetrievalAugmentationAdvisor retrievalAugmentationAdvisor;

    public ModuleRAGBasicController(ChatClient.Builder chatClientBuilder, VectorStore vectorStore) {

        this.chatClient = chatClientBuilder.build();
        this.retrievalAugmentationAdvisor = RetrievalAugmentationAdvisor.builder()
                .documentRetriever(VectorStoreDocumentRetriever.builder()
                        .similarityThreshold(0.50)
                        .vectorStore(vectorStore)
                        .build()
                ).build();
    }

    @GetMapping("/rag/basic")
    public String chatWithDocument(@RequestParam("prompt") String prompt) {

        return chatClient.prompt()
                .advisors(retrievalAugmentationAdvisor)
                .user(prompt)
                .call()
                .content();
    }

}

ModuleRAGCompressionController

@RestController
@RequestMapping("/module-rag")
public class ModuleRAGCompressionController {

    private final ChatClient chatClient;

    private final MessageChatMemoryAdvisor chatMemoryAdvisor;

    private final RetrievalAugmentationAdvisor retrievalAugmentationAdvisor;

    public ModuleRAGCompressionController(
            ChatClient.Builder chatClientBuilder,
            ChatMemory chatMemory,
            VectorStore vectorStore) {

        this.chatClient = chatClientBuilder.build();

        this.chatMemoryAdvisor = MessageChatMemoryAdvisor.builder(chatMemory)
                .build();

        var documentRetriever = VectorStoreDocumentRetriever.builder()
                .vectorStore(vectorStore)
                .similarityThreshold(0.50)
                .build();

        var queryTransformer = CompressionQueryTransformer.builder()
                .chatClientBuilder(chatClientBuilder.build().mutate())
                .build();

        this.retrievalAugmentationAdvisor = RetrievalAugmentationAdvisor.builder()
                .documentRetriever(documentRetriever)
                .queryTransformers(queryTransformer)
                .build();
    }

    @PostMapping("/rag/compression/{chatId}")
    public String rag(
            @RequestBody String prompt,
            @PathVariable("chatId") String conversationId
    ) {

        return chatClient.prompt()
                .advisors(chatMemoryAdvisor, retrievalAugmentationAdvisor)
                .advisors(advisors -> advisors.param(
                        AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY, conversationId))
                .user(prompt)
                .call()
                .content();
    }

}

ModuleRAGMemoryController

@RestController
@RequestMapping("/module-rag")
public class ModuleRAGMemoryController {

    private final ChatClient chatClient;

    private final MessageChatMemoryAdvisor chatMemoryAdvisor;

    private final RetrievalAugmentationAdvisor retrievalAugmentationAdvisor;

    public ModuleRAGMemoryController(
            ChatClient.Builder chatClientBuilder,
            ChatMemory chatMemory,
            VectorStore vectorStore
    ) {

        this.chatClient = chatClientBuilder.build();
        this.chatMemoryAdvisor = MessageChatMemoryAdvisor.builder(chatMemory)
                .build();

        this.retrievalAugmentationAdvisor = RetrievalAugmentationAdvisor.builder()
                .documentRetriever(VectorStoreDocumentRetriever.builder()
                        .similarityThreshold(0.50)
                        .vectorStore(vectorStore)
                        .build())
                .build();
    }

    @PostMapping("/rag/memory/{chatId}")
    public String chatWithDocument(
            @RequestBody String prompt,
            @PathVariable("chatId") String conversationId
    ) {

        return chatClient.prompt()
                .advisors(chatMemoryAdvisor, retrievalAugmentationAdvisor)
                .advisors(advisors -> advisors.param(
                        AbstractChatMemoryAdvisor.CHAT_MEMORY_CONVERSATION_ID_KEY, conversationId))
                .user(prompt)
                .call()
                .content();
    }

}

ModuleRAGRewriteController

@RestController
@RequestMapping("/module-rag")
public class ModuleRAGRewriteController {

    private final ChatClient chatClient;

    private final RetrievalAugmentationAdvisor retrievalAugmentationAdvisor;

    public ModuleRAGRewriteController(
            ChatClient.Builder chatClientBuilder,
            VectorStore vectorStore
    ) {

        this.chatClient = chatClientBuilder.build();

        var documentRetriever = VectorStoreDocumentRetriever.builder()
                .vectorStore(vectorStore)
                .similarityThreshold(0.50)
                .build();

        var queryTransformer = RewriteQueryTransformer.builder()
                .chatClientBuilder(chatClientBuilder.build().mutate())
                .targetSearchSystem("vector store")
                .build();

        this.retrievalAugmentationAdvisor = RetrievalAugmentationAdvisor.builder()
                .documentRetriever(documentRetriever)
                .queryTransformers(queryTransformer)
                .build();
    }

    @PostMapping("/rag/rewrite")
    public String rag(@RequestBody String prompt) {

        return chatClient.prompt()
                .advisors(retrievalAugmentationAdvisor)
                .user(prompt)
                .call()
                .content();
    }
}

ModuleRAGTranslationController

@RestController
@RequestMapping("/module-rag")
public class ModuleRAGTranslationController {

    private final ChatClient chatClient;

    private final RetrievalAugmentationAdvisor retrievalAugmentationAdvisor;

    public ModuleRAGTranslationController(ChatClient.Builder chatClientBuilder, VectorStore vectorStore) {
        this.chatClient = chatClientBuilder.build();

        var documentRetriever = VectorStoreDocumentRetriever.builder()
                .vectorStore(vectorStore)
                .similarityThreshold(0.50)
                .build();

        var queryTransformer = TranslationQueryTransformer.builder()
                .chatClientBuilder(chatClientBuilder.build().mutate())
                .targetLanguage("english")
                .build();

        this.retrievalAugmentationAdvisor = RetrievalAugmentationAdvisor.builder()
                .documentRetriever(documentRetriever)
                .queryTransformers(queryTransformer)
                .build();
    }

    @PostMapping("/rag/translation")
    public String rag(@RequestBody String prompt) {

        return chatClient.prompt()
                .advisors(retrievalAugmentationAdvisor)
                .user(prompt)
                .call()
                .content();
    }

}

Summary

Spring AI provides out-of-the-box RAG flows through RetrievalAugmentationAdvisor, and they fall into two broad categories. The first is Sequential RAG Flows (Naive RAG and Advanced RAG). The second is Modular RAG, which Spring AI implements with inspiration from the paper "Modular RAG: Transforming RAG Systems into LEGO-like Reconfigurable Frameworks" and which is divided into four stages: Pre-Retrieval, Retrieval, Post-Retrieval, and Generation. The Builder of RetrievalAugmentationAdvisor lets you configure the Pre-Retrieval components (queryAugmenter, queryTransformers, queryExpander) and the Retrieval components (documentRetriever, documentJoiner).


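© Copyright belongs to the author. For reprints or content collaboration, please contact the author.
Platform statement: the article content (including any images or videos) was uploaded and published by the author and represents only the author's own views. Jianshu is an information publishing platform and provides only information storage services.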
