
[New Feature] ES VectorStore #1483

Closed
IamWWT opened this issue May 2, 2024 · 1 comment · Fixed by #1500

IamWWT (Contributor) commented May 2, 2024

Features verified so far (see the end-to-end sketch at the bottom of this issue):

1) Creating a knowledge space (English names only; Chinese names are not supported).
2) Uploading documents for embedding.
3) Deleting each uploaded document individually.
4) Searching in conversations.

The modified files are as follows:

1) .env — add the following:

VECTOR_STORE_TYPE=ElasticSearch
ElasticSearch_URL=127.0.0.1
ElasticSearch_PORT=9200
ElasticSearch_USERNAME=elastic
ElasticSearch_PASSWORD=i=+iLw9y0Jduq86XTi6W
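
As a quick sanity check for these settings (illustrative only, not part of the patch), you can connect with the same Elasticsearch Python client the new store uses, reading the same variables added to .env above:

import os

from elasticsearch import Elasticsearch

# Read the same variables the patch adds to .env.
url = os.getenv("ElasticSearch_URL", "127.0.0.1")
port = os.getenv("ElasticSearch_PORT", "9200")
user = os.getenv("ElasticSearch_USERNAME")
password = os.getenv("ElasticSearch_PASSWORD")

es = Elasticsearch(f"http://{url}:{port}", basic_auth=(user, password))
print(es.info())  # raises if the host, port, or credentials are wrong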

2) dbgpt/_private/config.py — add the following:

    self.ElasticSearch_URL = os.getenv("ElasticSearch_URL", "127.0.0.1")
    self.ElasticSearch_PORT = os.getenv("ElasticSearch_PORT", "9200")
    self.ElasticSearch_USERNAME = os.getenv("ElasticSearch_USERNAME", None)
    self.ElasticSearch_PASSWORD = os.getenv("ElasticSearch_PASSWORD", None)

3) dbgpt/app/knowledge/service.py — modify def delete_document() as follows:

    def delete_document(self, space_name: str, doc_name: str):
        """delete document
        Args:
            - space_name: knowledge space name
            - doc_name: document name
        """ 
        document_query = KnowledgeDocumentEntity(doc_name=doc_name, space=space_name)
        documents = knowledge_document_dao.get_documents(document_query) 
        if len(documents) != 1:
            raise Exception(f"there are no or more than one document called {doc_name}")
        vector_ids = documents[0].vector_ids 
        if vector_ids is not None:
            ## wwt add
            embedding_factory = CFG.SYSTEM_APP.get_component("embedding_factory", EmbeddingFactory)
            embedding_fn = embedding_factory.create(model_name=EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL])
            ### wwt modified
            if CFG.VECTOR_STORE_TYPE == "Milvus":
                config = VectorStoreConfig(name=space_name,            
                                        embedding_fn=embedding_fn,  ## wwt add
                                        max_chunks_once_load=CFG.KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD,   ## wwt add
                                        user=CFG.MILVUS_USERNAME,
                                        password=CFG.MILVUS_PASSWORD,
                                        )
            elif CFG.VECTOR_STORE_TYPE == "ElasticSearch":
                logger.info(f"wwt add 正在删除ES类型的文档。")
                config = VectorStoreConfig(name=space_name, embedding_fn=embedding_fn,  ## wwt add
                                        max_chunks_once_load=CFG.KNOWLEDGE_MAX_CHUNKS_ONCE_LOAD,   ## wwt add
                                        user=CFG.ElasticSearch_USERNAME,
                                        password=CFG.ElasticSearch_PASSWORD,
                                        )
            elif CFG.VECTOR_STORE_TYPE == "Chroma":
                config = VectorStoreConfig(name=space_name)
            else:
                config = VectorStoreConfig(name=space_name) 
            vector_store_connector = VectorStoreConnector(
                vector_store_type=CFG.VECTOR_STORE_TYPE,
                vector_store_config=config,
            ) 
            # delete vector by ids 
            vector_store_connector.delete_by_ids(vector_ids)
        # delete chunks 
        document_chunk_dao.raw_delete(documents[0].id)
        # delete document
        return knowledge_document_dao.raw_delete(document_query)
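
Note that vector_ids is persisted on the document entity as a single comma-separated string, which is why delete_by_ids() in the new store (file 5 below) splits on "," before deleting. A minimal illustration of that round trip (the ids here are made up):

# The entity stores chunk ids joined with ",", e.g. after a load:
vector_ids = ",".join(["chunk-1", "chunk-2", "chunk-3"])
# ...and deletion hands the raw string to the store, which splits it:
vector_store_connector.delete_by_ids(vector_ids)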

4) dbgpt/storage/vector_store/__init__.py — add the following:

def _import_elastic() -> Any:
    from dbgpt.storage.vector_store.elastic_store import ElasticStore

    return ElasticStore

def __getattr__(name: str) -> Any:
    if name == "Chroma":
        return _import_chroma()
    elif name == "Milvus":
        return _import_milvus()
    elif name == "Weaviate":
        return _import_weaviate()
    elif name == "PGVector":
        return _import_pgvector()
    elif name == "ElasticSearch":
        return _import_elastic()
    else:
        raise AttributeError(f"Could not find: {name}")


__all__ = ["Chroma", "Milvus", "Weaviate", "PGVector", "ElasticSearch"]
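
With the lazy __getattr__ above, the module only imports a backend's dependencies when the corresponding attribute is first accessed. A minimal check (illustrative only):

# Triggers _import_elastic() and returns the ElasticStore class.
from dbgpt.storage.vector_store import ElasticSearch

print(ElasticSearch.__name__)  # -> "ElasticStore"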

5) dbgpt/storage/vector_store/ — add a new file elastic_store.py as follows:

"""Elasticsearch vector store for 全文索引---- for 全文检索."""
from __future__ import annotations

import json
import logging
import os
from typing import Any, Iterable, List, Optional

from dbgpt._private.pydantic import Field
from dbgpt.core import Chunk, Embeddings
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
from dbgpt.storage.vector_store.base import (
    _COMMON_PARAMETERS,
    VectorStoreBase,
    VectorStoreConfig,
)
from dbgpt.storage.vector_store.filters import FilterOperator, MetadataFilters
from dbgpt.util import string_utils
from dbgpt.util.i18n_utils import _

logger = logging.getLogger(__name__)

try:
    import jieba
    import jieba.analyse
    from elasticsearch import Elasticsearch
    from langchain.schema import Document
    from langchain.vectorstores.elasticsearch import ElasticsearchStore
except ImportError:
    raise ValueError(
        "Could not import the required Python packages. "
        "Please install them with `pip install elasticsearch langchain jieba`."
    )


@register_resource(
    _("ElasticSearch Vector Store"),
    "elasticsearch_vector_store",
    category=ResourceCategory.VECTOR_STORE,
    parameters=[
        *_COMMON_PARAMETERS,
        Parameter.build_from(
            _("Uri"),
            "uri",
            str,
            description=_(
                "The uri of elasticsearch store, if not set, will use the default " "uri."
            ),
            optional=True,
            default="localhost",
        ),
        Parameter.build_from(
            _("Port"),
            "port",
            str,
            description=_(
                "The port of elasticsearch store, if not set, will use the default " "port."
            ),
            optional=True,
            default="9200",
        ),
        Parameter.build_from(
            _("Alias"),
            "alias",
            str,
            description=_(
                "The alias of elasticsearch store, if not set, will use the default " "alias."
            ),
            optional=True,
            default="default",
        ),
        Parameter.build_from(
            _("Index Name"),
            "index_name",
            str,
            description=_(
                "The index name of elasticsearch store, if not set, will use the "
                "default index name."
            ),
            optional=True,
            default="index_name_test",
        ),
        Parameter.build_from(
            _("Text Field"),
            "text_field",
            str,
            description=_(
                "The text field of elasticsearch store, if not set, will use the "
                "default text field."
            ),
            optional=True,
            default="content",
        ),
        Parameter.build_from(
            _("Embedding Field"),
            "embedding_field",
            str,
            description=_(
                "The embedding field of elasticsearch store, if not set, will use the "
                "default embedding field."
            ),
            optional=True,
            default="vector",
        ),
    ],
    description=_("Elasticsearch vector store."),
)
class ElasticsearchVectorConfig(VectorStoreConfig):
    """Elasticsearch vector store config."""

    class Config:
        """Config for BaseModel."""

        arbitrary_types_allowed = True

    uri: str = Field(
        default="localhost",
        description="The uri of elasticsearch store, if not set, will use the default uri.",
    )
    port: str = Field(
        default="9200",
        description="The port of elasticsearch store, if not set, will use the default port.",
    )

    alias: str = Field(
        default="default",
        description="The alias of elasticsearch store, if not set, will use the default "
        "alias.",
    )
    index_name: str = Field(
        default="index_name_test",
        description="The index name of elasticsearch store, if not set, will use the "
        "default index name.",
    )
    text_field: str = Field(
        default="content",
        description="The text field of elasticsearch store, if not set, will use the default "
        "text field.",
    )
    embedding_field: str = Field(
        default="vector",
        description="The embedding field of elasticsearch store, if not set, will use the "
        "default embedding field.",
    )
    metadata_field: str = Field(
        default="metadata",
        description="The metadata field of elasticsearch store, if not set, will use the "
        "default metadata field.",
    )
    secure: str = Field(
        default="",
        description="The secure of elasticsearch store, if not set, will use the default "
        "secure.",
    )


class ElasticStore(VectorStoreBase):
    """Elasticsearch vector store."""

    def __init__(self, vector_store_config: ElasticsearchVectorConfig) -> None:
        """Create a ElasticsearchStore instance.

        Args:
            vector_store_config (ElasticsearchVectorConfig): ElasticsearchStore config. 
        """

        connect_kwargs = {}
        elasticsearch_vector_config = vector_store_config.dict()
        self.uri = elasticsearch_vector_config.get("uri") or os.getenv(
            "ElasticSearch_URL", "localhost"
        )
        self.port = elasticsearch_vector_config.get("post") or os.getenv(
            "ElasticSearch_PORT", "9200"
        )
        self.username = elasticsearch_vector_config.get("username") or os.getenv("ElasticSearch_USERNAME")
        self.password = elasticsearch_vector_config.get("password") or os.getenv(
            "ElasticSearch_PASSWORD"
        ) 

        self.collection_name = (
            elasticsearch_vector_config.get("name") or vector_store_config.name
        )
        # ES index names must be ASCII: hex-encode all-Chinese space names.
        if string_utils.is_all_chinese(self.collection_name):
            bytes_str = self.collection_name.encode("utf-8")
            hex_str = bytes_str.hex()
            self.collection_name = hex_str
        if vector_store_config.embedding_fn is None:
            # Perform runtime checks on self.embedding to
            # ensure it has been correctly set and loaded
            raise ValueError("embedding_fn is required for ElasticSearchStore")
        self.index_name = self.collection_name.lower()
        self.embedding: Embeddings = vector_store_config.embedding_fn
        self.fields: List = [] 

        if (self.username is None) != (self.password is None):
            raise ValueError(
                "Both username and password must be set to use authentication for "
                "ElasticSearch"
            )

        if self.username:
            connect_kwargs["username"] = self.username
            connect_kwargs["password"] = self.password
 
        # Index settings for a single-node cluster
        self.index_settings = {
            "settings": {
                "number_of_shards": 1,
                "number_of_replicas": 0,  # no replicas on a single node
            }
        }

        # Raw Elasticsearch Python client (connection only)
        try:
            if self.username and self.password:
                self.es_client_python = Elasticsearch(
                    f"http://{self.uri}:{self.port}",
                    basic_auth=(self.username, self.password),
                )
                # Do not create the index here, otherwise it raises an error:
                # if not self.vector_name_exists():
                #     self.es_client_python.indices.create(
                #         index=self.index_name, body=self.index_settings
                #     )
            else:
                logger.warning("Elasticsearch username and password are not configured")
                self.es_client_python = Elasticsearch(f"http://{self.uri}:{self.port}")
        except ConnectionError:
            logger.error("Failed to connect to Elasticsearch!")
        except Exception as e:
            logger.error(f"Error while connecting the Elasticsearch Python client: {e}")

        # LangChain Elasticsearch connection (this one also creates the index)
        try:
            if self.username and self.password:
                self.db_init = ElasticsearchStore(
                    es_url=f"http://{self.uri}:{self.port}",
                    index_name=self.index_name,
                    query_field="context",
                    vector_query_field="dense_vector",
                    embedding=self.embedding,
                    es_user=self.username,
                    es_password=self.password,
                )
            else:
                logger.warning("Elasticsearch username and password are not configured")
                self.db_init = ElasticsearchStore(
                    es_url=f"http://{self.uri}:{self.port}",
                    index_name=self.index_name,
                    query_field="context",
                    vector_query_field="dense_vector",
                    embedding=self.embedding,
                )
        except ConnectionError:
            logger.error("Failed to connect to Elasticsearch!")
        except Exception as e:
            logger.error(f"Error while creating the LangChain Elasticsearch store: {e}")
        

    def load_document(self, chunks: List[Chunk]) -> List[str]:
        """Add text data into Elasticsearch.

        Writes the given chunks into the ES index and returns their ids.
        """
        logger.info("ElasticStore load document")
        try:
            # Connect and write the documents in one call.
            texts = [chunk.content for chunk in chunks]
            metadatas = [chunk.metadata for chunk in chunks]
            ids = [chunk.chunk_id for chunk in chunks]
            if self.username and self.password:
                logger.info(
                    f"ElasticsearchStore.from_texts: metadatas[0]={metadatas[0]}, "
                    f"len={len(metadatas)}"
                )
                self.db = ElasticsearchStore.from_texts(
                    texts=texts,
                    embedding=self.embedding,
                    metadatas=metadatas,
                    ids=ids,
                    es_url=f"http://{self.uri}:{self.port}",
                    index_name=self.index_name,
                    distance_strategy="COSINE",  # Defaults to COSINE. Can be one of COSINE, EUCLIDEAN_DISTANCE, or DOT_PRODUCT.
                    query_field="context",  ## Name of the field to store the texts in.
                    vector_query_field="dense_vector", # Optional. Name of the field to store the embedding vectors in. 
                    es_user=self.username,
                    es_password=self.password,
                ) 
                logger.info(f"wwt add Embedding success.......")
            else:
                # Same as above, but without authentication.
                self.db = ElasticsearchStore.from_texts(
                    texts=texts,
                    embedding=self.embedding,
                    metadatas=metadatas,
                    ids=ids,
                    es_url=f"http://{self.uri}:{self.port}",
                    index_name=self.index_name,
                    distance_strategy="COSINE",
                    query_field="context",
                    vector_query_field="dense_vector",
                )
            return ids
        except ConnectionError as ce:
            logger.error(f"Failed to connect to Elasticsearch: {ce}")
        except Exception as e:
            logger.error(f"load_document error: {e}")


    def delete_by_ids(self, ids):
        """Delete vectors by ids.

        Args:
            ids: a comma-separated string of chunk ids.
        """
        ids = ids.split(",")
        logger.info(f"begin delete elasticsearch, {len(ids)} ids")
        try:
            self.db_init.delete(ids=ids)
            self.es_client_python.indices.refresh(index=self.index_name)
        except Exception as e:
            logger.error(f"delete_by_ids error: {e}")
            

    def similar_search(
        self,
        text: str,
        topk: int,
        score_threshold: float,
        filters: Optional[MetadataFilters] = None,
    ) -> List[Chunk]:
        """Perform a search on a query string and return results.

        TODO: the tokenizer used for query analysis could be made configurable.
        """
        query = text
        logger.info(f"similar_search query: {query}")
        # Extract keywords with jieba's TextRank; fall back to the raw query.
        query_list = jieba.analyse.textrank(query, topK=20, withWeight=False)
        if len(query_list) == 0:
            query_list = [query]
        body = {
            "query": {
                "match": {
                    "context": " ".join(query_list)
                }
            }
        }
        search_results = self.es_client_python.search(
            index=self.index_name, body=body, size=topk
        )
        search_results = search_results["hits"]["hits"]

        # Return early if there are no hits.
        if not search_results:
            return []

        info_docs = []
        byte_count = 0
        # Max prompt size in bytes per vector store; truncate beyond this.
        # TODO: move to .env.
        VS_TYPE_PROMPT_TOTAL_BYTE_SIZE = 3000

        for result in search_results:
            doc_id = result["_id"]
            source = result["_source"]
            context = source["context"]
            metadata = source["metadata"]
            score = result["_score"]

            # Truncate the context if it would exceed the total byte budget.
            if (byte_count + len(context)) > VS_TYPE_PROMPT_TOTAL_BYTE_SIZE:
                context = context[: VS_TYPE_PROMPT_TOTAL_BYTE_SIZE - byte_count]

            doc_with_score = [Document(page_content=context, metadata=metadata), score, doc_id]
            info_docs.append(doc_with_score)

            byte_count += len(context)

            # Stop once the byte budget is exhausted.
            if byte_count >= VS_TYPE_PROMPT_TOTAL_BYTE_SIZE:
                break

        logger.info(f"Elasticsearch returned {len(info_docs)} results")
        # Dump the results to a file for debugging.
        with open("es_search_results.txt", "w", encoding="utf-8") as result_file:
            result_file.write(f"query: {query}\n")
            result_file.write(f"Elasticsearch returned {len(info_docs)} results:\n")
            for item in info_docs:
                doc = item[0]
                result_file.write(doc.page_content + "\n")
                result_file.write("*" * 20 + "\n")

        return [
            Chunk(
                metadata=json.loads(doc.metadata.get("metadata", "")),
                content=doc.page_content,
            )
            for doc, score, id in info_docs
        ]

    def similar_search_with_scores(
        self, text, topk, score_threshold, filters: Optional[MetadataFilters] = None
    ) -> List[Chunk]:
        """Perform a search on a query string and return results with scores.

        Args:
            text (str): The query text.
            topk (int): The number of similar documents to return.
            score_threshold (float): Optional, a floating point value between 0 and 1.
            filters (Optional[MetadataFilters]): Optional, metadata filters.
        Returns:
            List[Chunk]: Result docs with scores.
        """

        query = text
        logger.info(f"similar_search_with_scores query: {query}")
        query_list = jieba.analyse.textrank(query, topK=20, withWeight=False)
        if len(query_list) == 0:
            query_list = [query]
        body = {
            "query": {
                "match": {
                    "context": " ".join(query_list)
                }
            }
        }
        search_results = self.es_client_python.search(
            index=self.index_name, body=body, size=topk
        )
        search_results = search_results["hits"]["hits"]
        # Return early if there are no hits.
        if not search_results:
            return []

        info_docs = []
        byte_count = 0
        # Max prompt size in bytes per vector store; truncate beyond this.
        # TODO: move to .env.
        VS_TYPE_PROMPT_TOTAL_BYTE_SIZE = 3000

        for result in search_results:
            doc_id = result["_id"]
            source = result["_source"]
            context = source["context"]  # the text content
            metadata = source["metadata"]  # source path metadata
            # Normalize the score: ES scores here are on a 100-point scale.
            score = result["_score"] / 100

            # Truncate the context if it would exceed the total byte budget.
            if (byte_count + len(context)) > VS_TYPE_PROMPT_TOTAL_BYTE_SIZE:
                context = context[: VS_TYPE_PROMPT_TOTAL_BYTE_SIZE - byte_count]

            doc_with_score = [Document(page_content=context, metadata=metadata), score, doc_id]
            info_docs.append(doc_with_score)

            byte_count += len(context)

            # Stop once the byte budget is exhausted.
            if byte_count >= VS_TYPE_PROMPT_TOTAL_BYTE_SIZE:
                break

        logger.info(f"Elasticsearch returned {len(info_docs)} results")
        # Dump the results to a file for debugging.
        with open("es_search_results.txt", "w", encoding="utf-8") as result_file:
            result_file.write(f"query: {query}\n")
            result_file.write(f"Elasticsearch returned {len(info_docs)} results:\n")
            for item in info_docs:
                doc = item[0]
                result_file.write(doc.page_content + "\n")
                result_file.write("*" * 50 + "\n")
         
        if any(score < 0.0 or score > 1.0 for _, score, _ in info_docs):
            logger.warning(
                f"similarity scores should be between 0 and 1, got {info_docs}"
            )

        logger.info(f"wwt add score_threshold: {score_threshold}")
        if score_threshold is not None:
            docs_and_scores = [
                Chunk(
                    metadata=doc.metadata,
                    content=doc.page_content,
                    score=score,
                    chunk_id=id,
                )
                for doc, score, id in info_docs
                if score >= score_threshold
            ]
            if len(docs_and_scores) == 0:
                logger.warning(
                    "No relevant docs were retrieved using the relevance score"
                    f" threshold {score_threshold}"
                )
        return docs_and_scores
 

    def vector_name_exists(self):
        """Whether the vector store index exists."""
        return self.es_client_python.indices.exists(index=self.index_name)
    

    def delete_vector_name(self, vector_name: str):
        """Delete the index, removing all vectors for this knowledge space."""
        if self.es_client_python.indices.exists(index=self.index_name):
            self.es_client_python.indices.delete(index=self.index_name)
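
For reference, here is a minimal end-to-end sketch of the verified flow (load, search, delete) against this new store. It is illustrative only: embedding_fn stands in for whatever Embeddings implementation your deployment provides (e.g. obtained from the embedding_factory component, as in delete_document() above), and the chunk content is made up:

from dbgpt.core import Chunk
from dbgpt.storage.vector_store.elastic_store import (
    ElasticsearchVectorConfig,
    ElasticStore,
)

config = ElasticsearchVectorConfig(
    name="my_space",  # English-only, per the verified limitations above
    uri="127.0.0.1",
    port="9200",
    embedding_fn=embedding_fn,  # placeholder: any dbgpt Embeddings instance
)
store = ElasticStore(config)

# Load one chunk, search it, then delete it by id.
ids = store.load_document(
    [Chunk(content="hello elasticsearch", metadata={"source": "demo.txt"})]
)
results = store.similar_search_with_scores("hello", topk=3, score_threshold=0.0)
store.delete_by_ids(",".join(ids))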
 
@Aries-ckt (Collaborator) commented:
Hi @IamWWT, amazing feature! Can you make a pull request for ElasticSearchStore?

@Aries-ckt changed the title from "[New Feature] ES向量库链接已经验证通过" ("ES vector store connection verified") to "[New Feature] ES VectorStore" on May 4, 2024
@csunny closed this as completed May 15, 2024