大模型的开发应用(十三):基于RAG的法律助手项目(上):总体流程简易实现

1 项目介绍

本项目是制作一款专注于劳动纠纷的法律助手,使用RAG技术实现。

1.1 方案选型

制作一个基于大模型的专家系统,优先使用 RAG 来实现,因为相比于微调,RAG 有以下几点优势:

在这里插入图片描述

当然,有的时候 RAG 会和微调一起使用,当面对的专业领域非常的小众和冷门,原始模型根本无法理解用户问题,那么需要先微调,让模型对本专业的问题有一定的理解能力。简单来讲,就是先让模型理解你想问什么,再看知识库里有没有。

本项目的基模型,我们选择 Qwen1.5-1.8B-Chat,问几个问题看它对劳动法是不是一点都不懂:

from modelscope import AutoModelForCausalLM, AutoTokenizer

model_name = "/data/coding/models/Qwen/Qwen1.5-1.8B-Chat"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

prompt = "我在目前这家公司工作三年了,现在公司想开除我,请问会如何赔偿我?"
messages = [
    {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
    {"role": "user", "content": prompt}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=512
)
generated_ids = [
    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]

response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

print(response)

输出:

在许多情况下,公司可能会根据您的职位、工作经验、绩效表现以及相关的法律法规和公司政策来决定是否解雇您。以下是一些可能的赔偿方案:

1. 薪资补偿:如果公司在解除劳动合同前已向您支付了合理的工资或薪酬,并且没有明确表示要解雇您,那么通常您将获得一份工资补偿。具体金额取决于您的职位等级、工作年限、月度/年度基本工资标准等因素。这种补偿通常是基于您在过去三年的工作经历和实际贡献的。

2. 绩效奖金:如果您是公司的关键员工或者有显著的工作成果,公司可能会在解除合同之前向您提供一定的绩效奖金或其他形式的经济激励。这可能包括现金奖励、股权激励计划(如股票期权)、福利金等。

3. 停职期间工资:对于短期离职或临时离职的员工,公司通常会提供一定期限的停职期间工资。这种工资通常包括基本工资和岗位补贴,但通常不包括年终奖或其他具有激励性的补偿。

4. 其他福利:除了上述补偿外,公司还可能会提供一些其他福利,如健康保险、退休金计划、培训机会、休假制度等。这些福利可以帮助您维持生活质量和职业发展。

5. 解除通知:公司一般会在与您达成协议后,为您的个人文件(如简历、员工手册、劳动合同)制作一份解除劳动关系的通知书,告知您因个人原因无法继续履行工作职责,并要求您配合公司完成交接工作,如有需要可以领取相应的离职手续材料。

6. 法律援助:如果公司认为您的离职行为违反了相关的劳动法规或公司的规章制度,他们可能会提起诉讼请求法院裁决您承担违约责任并停止支付相关补偿或费用。

请注意,以上只是可能的赔偿方案,具体的赔偿数额应根据法律程序和公司的具体规定进行计算。建议您在收到解除劳动合同通知书时,尽快联系人力资源部门或者法律专业人士,以获取更准确的法律建议和处理方式。同时,也要了解自己的权利和义务,尽可能保留好相关证据,以便在后续过程中维护自己的权益。

可以看到,模型还是能理解你问的东西是什么,这种情况下就不需要微调了。只不过回复看起来并不那么专业,我们希望专业的回复,应该能直接给出赔偿方案,以及相关的法律依据,这个时候就比较适合用 RAG 来实现。

当然,真实工作中,模型不可能选那么小的,一般都在 32B 以上。另外,这两年的大模型能力非常强,它们在预训练的时候,训练数据集也必然包含《劳动法》和《劳动合同法》,不需要我们自己插入这方面的知识库。我们这边只是为了演示RAG如何构建知识库,重要的是流程,而不是效果。

1.2 知识文档

目前国内的劳动纠纷涉及的法律只有两部,分别是《劳动法》和《劳动合同法》。

我国的《劳动法》是在1994年颁布实施,经过两次修正,《劳动合同法》是在2007年颁布,经过一次修正,这两款法律我们都使用最新的版本。

2 文档解析

法律条款的文档处理起来还是比较简单的,一是因为法律条款中全部都是文字,没有表格、图片、流程图等复杂的数据格式,二是因为每个条款都是单独起一段,像下面一样,格式非常整齐:

在这里插入图片描述

我们把《劳动法》和《劳动合同法》中的每一条都拎出来,然后做成 JSON 文件,形式如下:

file1.json

在这里插入图片描述

file2.json

在这里插入图片描述

之所以要把每一条单独拎出来,是因为我们想让每条称为一个单独的知识点。可以用 AI 模型生成一段代码,把PDF文件中的内容处理成上面的形式。

3 知识库构建

3.1 构建知识节点

import json
from pathlib import Path
from typing import List
from llama_index.core.schema import TextNode

def load_and_create_nodes(data_dir: str) -> List[TextNode]:
    """加载JSON法律文件并直接转换为TextNode节点"""
    json_files = list(Path(data_dir).glob("*.json"))
    assert json_files, f"未找到JSON文件于 {data_dir}"
    
    nodes = []
    total_entries = 0
    
    for json_file in json_files:
        with open(json_file, 'r', encoding='utf-8') as f:
            try:
                data = json.load(f)
                # 验证数据结构
                if not isinstance(data, list):
                    raise ValueError(f"文件 {json_file.name} 根元素应为列表")
                
                for item in data:
                    if not isinstance(item, dict):
                        raise ValueError(f"文件 {json_file.name} 包含非字典元素")
                    
                    for k, v in item.items():
                        if not isinstance(v, str):
                            raise ValueError(f"文件 {json_file.name} 中键 '{k}' 的值不是字符串")
                    
                    # 处理字典中的键值对 (每个item只有一个键值对)
                    for full_title, content in item.items():
                        # 生成稳定ID (文件 + 标题)
                        node_id = f"{json_file.name}::{full_title}"
                        
                        # 解析法律名称和条款号
                        parts = full_title.split(" ", 1)
                        law_name = parts[0] if len(parts) > 0 else "未知法律"
                        article = parts[1] if len(parts) > 1 else "未知条款"
                        
                        # 创建TextNode节点
                        node = TextNode(
                            text=content,
                            id_=node_id,
                            metadata={
                                "law_name": law_name,
                                "article": article,
                                "full_title": full_title,
                                "source_file": json_file.name,
                                "content_type": "legal_article"
                            }
                        )
                        nodes.append(node)
                        total_entries += 1
            
            except Exception as e:
                raise RuntimeError(f"处理文件 {json_file} 失败: {str(e)}")
    
    print(f"成功转换 {total_entries} 个法律条款为文本节点")
    if nodes:
        print(f"id示例:{nodes[0].id_}")
        print(f"文本示例:{nodes[0].text}")
        print(f"元数据示例:{nodes[0].metadata}")
    
    return nodes

if __name__ == "__main__":
    nodes = load_and_create_nodes("/data/coding/data")

输出:

成功转换 205 个法律条款为文本节点
id示例:file1.json::中华人民共和国劳动法 第一条
文本示例:为了保护劳动者的合法权益,调整劳动关系,建立和维护适应社会主义市场经济的劳动制度,促进经济发展和社会进步,根据宪法,制定本法。
元数据示例:{'law_name': '中华人民共和国劳动法', 'article': '第一条', 'full_title': '中华人民共和国劳动法 第一条', 'source_file': 'file1.json', 'content_type': 'legal_article'}

3.2 嵌入向量初始化

初始化嵌入向量的代码比较简单:

class Config:
    EMBED_MODEL_PATH = "/data/coding/models/sungw111/text2vec-base-chinese-sentence"
    DATA_DIR = "/data/coding/data"
    VECTOR_DB_DIR = "/data/coding/chroma_db"
    PERSIST_DIR = "/data/coding/storage"
    
    COLLECTION_NAME = "chinese_labor_laws"
    TOP_K = 3


from llama_index.embeddings.huggingface import HuggingFaceEmbedding
def init_embedding_model():
    """初始化模型并验证"""
    # Embedding模型
    embed_model = HuggingFaceEmbedding(
        model_name=Config.EMBED_MODEL_PATH,
        # 在一些比较老版本的 llama-index-embeddings-huggingface 中,需要加下面的参数,当前版本(0.5.4)不需要
        # encode_kwargs = {
        #     'normalize_embeddings': True,
        #     'device': 'cuda' if hasattr(Settings, 'device') else 'cpu'
        # }
    )
    Settings.embed_model = embed_model
 
    # 验证模型
    test_embedding = embed_model.get_text_embedding("测试文本")
    print(f"Embedding维度验证:{len(test_embedding)}")
    
    return embed_model

3.2 向量存储

import chromadb
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.core import VectorStoreIndex, StorageContext, Settings


class Config:
    EMBED_MODEL_PATH = "/data/coding/models/sungw111/text2vec-base-chinese-sentence"
    DATA_DIR = "/data/coding/data"
    VECTOR_DB_DIR = "/data/coding/chroma_db"
    PERSIST_DIR = "/data/coding/storage"
    
    COLLECTION_NAME = "chinese_labor_laws"
    TOP_K = 3


def init_vector_store(nodes: List[TextNode]) -> VectorStoreIndex:
    chroma_client = chromadb.PersistentClient(path=Config.VECTOR_DB_DIR)

    # 创建或者获取集合(首次运行是创建,第二次运行则是获取)
    chroma_collection = chroma_client.get_or_create_collection(
        name=Config.COLLECTION_NAME,
        metadata={"hnsw:space": "cosine"}
    )

    # 判断是否需要新建索引
    if chroma_collection.count() == 0 and nodes is not None:
        print(f"创建新索引({len(nodes)}个节点)...")

        # 创建存储上下文
        storage_context = StorageContext.from_defaults(
            # 将 ChromaDB 的集合(collection)封装为 LlamaIndex 可识别的向量存储接口,以支持索引构建与查询。
            # 后续通过 VectorStoreIndex 构建索引时,会使用该 ChromaVectorStore 实例来添加或搜索向量。
            vector_store=ChromaVectorStore(chroma_collection=chroma_collection) 
        )
        # 创建 StorageContext 对象的作用是为 LlamaIndex 提供一个统一的数据存储管理上下文,
        # 用于协调向量存储(vector store)、文档存储(docstore)和索引之间的数据流动与持久化操作。
        
        # 将文本节点存入文档存储(元数据+文本内容)
        storage_context.docstore.add_documents(nodes)  
        
        # 创建索引,将节点向量化并创建可搜索的索引结构
        index = VectorStoreIndex(
            nodes,
            storage_context=storage_context,
            show_progress=True
        )
        # 在创建 VectorStoreIndex 对象时需要传入该 StorageContext 对象,以确保索引知道如何访问向量和文档。

        # 双重持久化保障,将存储上下文和索引对象保存到 Config.PERSIST_DIR 目录(双重保证)
        storage_context.persist(persist_dir=Config.PERSIST_DIR)
        index.storage_context.persist(persist_dir=Config.PERSIST_DIR) 
    else:
        print("加载已有索引...")

        # 加载存储上下文,从持久化目录加载已有状态
        storage_context = StorageContext.from_defaults(
            persist_dir=Config.PERSIST_DIR,
            vector_store=ChromaVectorStore(chroma_collection=chroma_collection)
        )

        # 构建索引对象,基于已有向量存储重建内存索引结构
        index = VectorStoreIndex.from_vector_store(
            storage_context.vector_store,
            storage_context=storage_context,
            embed_model=Settings.embed_model
        )

    # 安全验证
    print("\n存储验证结果:")
    doc_count = len(storage_context.docstore.docs)
    print(f"DocStore记录数:{doc_count}")
    
    if doc_count > 0:
        sample_key = next(iter(storage_context.docstore.docs.keys()))
        print(f"示例节点ID:{sample_key}")
    else:
        print("警告:文档存储为空,请检查节点添加逻辑!")
    
    
    return index
    

if __name__ == "__main__":
    nodes = load_and_create_nodes("/data/coding/data")

    # 初始化模型
    embed_model = init_embedding_model()
    print()

    import time 
    print("\n初始化向量存储...")
    start_time = time.time()
    index = init_vector_store(nodes)
    print(f"索引加载耗时:{time.time()-start_time:.2f}s")

上面的代码看不懂不要紧,我们只需要用的时候会改配置套代码就行。

输出:

成功转换 205 个法律条款为文本节点
id示例:file1.json::中华人民共和国劳动法 第一条
文本示例:为了保护劳动者的合法权益,调整劳动关系,建立和维护适应社会主义市场经济的劳动制度,促进经济发展和社会进步,根据宪法,制定本法。
元数据示例:{'law_name': '中华人民共和国劳动法', 'article': '第一条', 'full_title': '中华人民共和国劳动法 第一条', 'source_file': 'file1.json', 'content_type': 'legal_article'}

Embedding维度验证:768

初始化向量存储...
创建新索引(205个节点)...
Generating embeddings: 100%|████████████████████████████████████████████████████| 205/205 [00:01<00:00, 111.22it/s]

存储验证结果:
DocStore记录数:205
示例节点ID:file1.json::中华人民共和国劳动法 第一条
索引加载耗时:2.54s

程序运行结束后,在当前目录下可以看大两个新的文件夹,即向量数据库目录存储上下文与索引目录,如下图所示:
在这里插入图片描述

storage 目录下有4个json文件,其中 docstore.json 存储了文档节点的元数据和文本内容(包括节点ID、文本内容、元数据),index_store.json 存储了LlamaIndex 索引的元数据信息(如索引结构、索引ID),graph_store.json 存储了图索引数据(如关系图谱的边、节点关系),image__vector_store.json 则保存了图像嵌入向量,后三个了解即可。

这里最重要的是 docstore.json,内容如下:
在这里插入图片描述

这里用了utf-8编码,我们看不见,可以写一个脚本打印一下:

import json

def load_and_print_json(file_path):
    # 读取JSON文件
    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)
        # print(data)
        
    # 遍历数据并打印中文内容
    for key, value in data['docstore/data'].items():
        print(key)
        print(value)
        print('-'*80)
        
        # 获取中文标题
        full_title = value['__data__']['metadata']['full_title']
        # 获取文本内容
        text = value['__data__']['text']
        print(f"标题:{full_title}")
        print(f"内容:{text}\n")
        break

    # # 将JSON数据格式化并打印
    # formatted_json = json.dumps(data, ensure_ascii=False, indent=4)
    # print(formatted_json)

# 使用函数
file_path = '/data/coding/storage/docstore.json'  
load_and_print_json(file_path)

输出:

ile1.json::中华人民共和国劳动法 第一条
{'__data__': {'id_': 'file1.json::中华人民共和国劳动法 第一条', 'embedding': None, 'metadata': {'law_name': '中华人民共和国劳动法', 'article': '第一条', 'full_title': '中华人民共和国劳动法 第一条', 'source_file': 'file1.json', 'content_type': 'legal_article'}, 'excluded_embed_metadata_keys': [], 'excluded_llm_metadata_keys': [], 'relationships': {}, 'metadata_template': '{key}: {value}', 'metadata_separator': '\n', 'text': '为了保护劳动者的合法权益,调整劳动关系,建立和维护适应社会主义市场经济的劳动制度,促进经济发展和社会进步,根据宪法,制定本法。', 'mimetype': 'text/plain', 'start_char_idx': None, 'end_char_idx': None, 'metadata_seperator': '\n', 'text_template': '{metadata_str}\n\n{content}', 'class_name': 'TextNode'}, '__type__': '1'}
--------------------------------------------------------------------------------
标题:中华人民共和国劳动法 第一条
内容:为了保护劳动者的合法权益,调整劳动关系,建立和维护适应社会主义市场经济的劳动制度,促进经济发展和社会进步,根据宪法,制定本法。

4 查询

4.1 初始化大模型

初始化大模型的代码如下:

import time
from llama_index.core import PromptTemplate
from llama_index.llms.huggingface import HuggingFaceLLM

Config.LLM_MODEL_PATH = "/data/coding/models/Qwen/Qwen1.5-1.8B-Chat"

def init_llm_model():
    # 初始化大语言模型
    llm = HuggingFaceLLM(
        model_name=Config.LLM_MODEL_PATH,
        tokenizer_name=Config.LLM_MODEL_PATH,
        model_kwargs={"trust_remote_code": True},
        tokenizer_kwargs={"trust_remote_code": True},
        generate_kwargs={"temperature": 0.3}        # 要让回答偏向于知识库,要让模型减少随机性,因此把temperature设置低一些,不要高于0.3
    )

    Settings.llm = llm
    
    return llm

4.2 模型响应

模型响应主要就是创建一个查询引擎,然后进行查询,我们里把前面的步骤都串起来,这样对整个过程能有更好的理解:


def main():
    embed_model, llm = init_embedding_model(), init_llm_model()

    # 仅当需要更新数据时执行
    if not Path(Config.VECTOR_DB_DIR).exists():
        print("\n初始化数据...")
        nodes = load_and_create_nodes(Config.DATA_DIR)
    else:
        nodes = None  # 已有数据时不加载

    # 初始化向量存储
    print("\n初始化向量存储...")
    start_time = time.time()
    index = init_vector_store(nodes)
    print(f"索引加载耗时:{time.time()-start_time:.2f}s")

    # 创建查询引擎
    query_engine = index.as_query_engine(
        similarity_top_k=Config.TOP_K,
        # text_qa_template=response_template,
        verbose=True
    )

    # 示例查询
    while True:
        question = input("\n请输入劳动法相关问题(输入q退出): ")
        if question.lower() == 'q':
            break
        
        # 执行查询
        response = query_engine.query(question)
        
        # 显示结果
        print(f"\n智能助手回答:\n{response.response}")
        print("\n支持依据:")
        for idx, node in enumerate(response.source_nodes, 1):
            meta = node.metadata
            print(f"\n[{idx}] {meta['full_title']}")
            print(f"  来源文件:{meta['source_file']}")
            print(f"  法律名称:{meta['law_name']}")
            print(f"  条款内容:{node.text[:100]}...")
            print(f"  相关度得分:{node.score:.4f}")


if __name__ == "__main__":
    main()

输出:

Embedding维度验证:768

初始化向量存储...
加载已有索引...
Loading llama_index.core.storage.kvstore.simple_kvstore from /data/coding/storage/docstore.json.
Loading llama_index.core.storage.kvstore.simple_kvstore from /data/coding/storage/index_store.json.

存储验证结果:
DocStore记录数:205
示例节点ID:file1.json::中华人民共和国劳动法 第一条
索引加载耗时:0.63s

请输入劳动法相关问题(输入q退出): 劳动合同试用期最长可以多久?

智能助手回答:
试用期最长不得超过六个月。

支持依据:

[1] 中华人民共和国劳动法 第二十一条
  来源文件:file1.json
  法律名称:中华人民共和国劳动法
  条款内容:劳动合同可以约定试用期。试用期最长不得超过六个月。...
  相关度得分:0.9302

[2] 中华人民共和国劳动合同法 第十九条
  来源文件:file2.json
  法律名称:中华人民共和国劳动合同法
  条款内容:劳动合同期限三个月以上不满一年的,试用期不得超过一个月;劳动合同期限一年以上不满三年的,试用期不得超过二个月;三年以上固定期限和无固定期限的劳动合同,试用期不得超过六个月。
同一用人单位与同一劳动者只...
  相关度得分:0.9180

[3] 中华人民共和国劳动法 第二十条
  来源文件:file1.json
  法律名称:中华人民共和国劳动法
  条款内容:劳动合同的期限分为有固定期限、无固定期限和以完成一定的工作为期限。
劳动者在同一用人单位连续工作满十年以上,当事人双方同意续延劳动合同的,如果劳动者提出订立无固定期限的劳动合同,应当订立无固定期限的劳...
  相关度得分:0.8955

请输入劳动法相关问题(输入q退出): 

这里有点乱,下面的截图看的比较清楚:
在这里插入图片描述

可以看到,模型给出的回复相当靠谱,我们还把相似度前三的节点内容给打印出来了,作为回答的依据。但是,我们可以看到,第三条依据似乎和我们的输入并不相关,这是单纯采用余弦相似度会产生的问题。

4.2 本文程序存在的问题

我们再试一下方案选型时使用的问题:我在目前这家公司工作三年了,现在公司想开除我,请问会如何赔偿我?
结果如下:

在这里插入图片描述
在这里插入图片描述

可以看到,三条依据中,有两条并不是我们想要的(红框框出),第一条讲的是 “劳动者解除劳动合同” 的情形,第三条讲的是被开除之后工会能给出什么样的帮助。这里第一条不相关,我估计是嵌入模型分不清 “劳动者解除劳动合同” 和 “用人单位解除劳动合同” 的区别,把“员工辞职”和“单位开除员工”认为是两件相似的事情,分不清主动和被动。第三条讲的倒是被开除的事情,但没讲赔偿,也就是说,相关但答非所问。

另外,模型并没有能输出赔偿方案,我们从输出结果可以看到,模型没有答完,有可能是没回复完就停了,也有可能是输出完了,但没打印完,我估计是 llama_index.llms.huggingface 的相关包存在 Bug。

再来看一个例子:
在这里插入图片描述
在这里插入图片描述

这里只有第二个依据是我们需要的,依据[1] 是讲公司在什么情况下可以开除员工的,依据[3] 讲的是被开除时,工会能提供什么帮助。我们的程序并没有把《劳动法》中的原文打印完整,下面是完整原文:

  1. “中华人民共和国劳动法 第二十六条”: "有下列情形之一的,用人单位可以解除劳动合同,但是应当提前三十日以书面形式通知劳动者本人:
  • (一)劳动者患病或者非因工负伤,医疗期满后,不能从事原工作也不能从事由用人单位另行安排的工作的;
  • (二)劳动者不能胜任工作,经过培训或者调整工作岗位,仍不能胜任工作的;
  • (三)劳动合同订立时所依据的客观情况发生重大变化,致使原劳动合同无法履行,经当事人协商不能就变更劳动合同达成协议的。",
  1. “中华人民共和国劳动法 第二十四条”: “经劳动合同当事人协商一致,劳动合同可以解除。”

  2. “中华人民共和国劳动法 第三十条”: “用人单位解除劳动合同,工会认为不适当的,有权提出意见。如果用人单位违反法律、法规或者劳动合同,工会有权要求重新处理;劳动者申请仲裁或者提起诉讼的,工会应当依法给予支持和帮助。”

明明检索的结果不正确,大模型依然按照这个来回答了。

既然给出的依据和我们的问题无关,那为何算出的相关度得分会这么高?主要原因是问题和知识节点的内容都是中文,并且都是语义通顺的句子,只要满足这两点,相似度(相关度得分)都能达到0.7以上,接近0.9都有可能。我们来问一个技术问题看看:

在这里插入图片描述

总结
通过前面的例子,可以看到,我们当前的系统存在以下几个问题:

    1. 嵌入模型分不清主动和被动的区别;
    1. 检索出的结果相关,但答非所问;
    1. 与问题不相关的节点,相似度(相关度得分)却很高;
    1. 模型的回复不完整,没回复完就停了,也有可能是回复完了,但打印的不完整。

这些问题我们会在下一篇文章中解决,本文重点是介绍流程。

完整代码

将本文中的代码整个整理,整理后的完整代码如下:

import time
import json
from typing import List
from pathlib import Path

from llama_index.core import PromptTemplate
from llama_index.core.schema import TextNode
from llama_index.core import VectorStoreIndex, StorageContext, Settings

import chromadb
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.llms.huggingface import HuggingFaceLLM
from llama_index.embeddings.huggingface import HuggingFaceEmbedding


class Config:
    EMBED_MODEL_PATH = "/data/coding/models/sungw111/text2vec-base-chinese-sentence"
    LLM_MODEL_PATH = "/data/coding/models/Qwen/Qwen1.5-1.8B-Chat"
    DATA_DIR = "/data/coding/data"
    VECTOR_DB_DIR = "/data/coding/chroma_db"
    PERSIST_DIR = "/data/coding/storage"
    
    COLLECTION_NAME = "chinese_labor_laws"
    TOP_K = 3


def load_and_create_nodes(data_dir: str) -> List[TextNode]:
    """加载JSON法律文件并直接转换为TextNode节点"""
    json_files = list(Path(data_dir).glob("*.json"))
    assert json_files, f"未找到JSON文件于 {data_dir}"
    
    nodes = []
    total_entries = 0
    
    for json_file in json_files:
        with open(json_file, 'r', encoding='utf-8') as f:
            try:
                data = json.load(f)
                # 验证数据结构
                if not isinstance(data, list):
                    raise ValueError(f"文件 {json_file.name} 根元素应为列表")
                
                for item in data:
                    if not isinstance(item, dict):
                        raise ValueError(f"文件 {json_file.name} 包含非字典元素")
                    
                    for k, v in item.items():
                        if not isinstance(v, str):
                            raise ValueError(f"文件 {json_file.name} 中键 '{k}' 的值不是字符串")
                    
                    # 处理字典中的键值对 (每个item只有一个键值对)
                    for full_title, content in item.items():
                        # 生成稳定ID (文件 + 标题)
                        node_id = f"{json_file.name}::{full_title}"
                        
                        # 解析法律名称和条款号
                        parts = full_title.split(" ", 1)
                        law_name = parts[0] if len(parts) > 0 else "未知法律"
                        article = parts[1] if len(parts) > 1 else "未知条款"
                        
                        # 创建TextNode节点
                        node = TextNode(
                            text=content,
                            id_=node_id,
                            metadata={
                                "law_name": law_name,
                                "article": article,
                                "full_title": full_title,
                                "source_file": json_file.name,
                                "content_type": "legal_article"
                            }
                        )
                        nodes.append(node)
                        total_entries += 1
            
            except Exception as e:
                raise RuntimeError(f"处理文件 {json_file} 失败: {str(e)}")
    
    print(f"成功转换 {total_entries} 个法律条款为文本节点")
    if nodes:
        print(f"id示例:{nodes[0].id_}")
        print(f"文本示例:{nodes[0].text}")
        print(f"元数据示例:{nodes[0].metadata}")
    
    return nodes


def init_vector_store(nodes: List[TextNode]) -> VectorStoreIndex:
    chroma_client = chromadb.PersistentClient(path=Config.VECTOR_DB_DIR)

    # 创建或者获取集合(首次运行是创建,第二次运行则是获取)
    chroma_collection = chroma_client.get_or_create_collection(
        name=Config.COLLECTION_NAME,
        metadata={"hnsw:space": "cosine"}
    )

    # 判断是否需要新建索引
    if chroma_collection.count() == 0 and nodes is not None:
        print(f"创建新索引({len(nodes)}个节点)...")

        # 创建存储上下文
        storage_context = StorageContext.from_defaults(
            # 将 ChromaDB 的集合(collection)封装为 LlamaIndex 可识别的向量存储接口,以支持索引构建与查询。
            # 后续通过 VectorStoreIndex 构建索引时,会使用该 ChromaVectorStore 实例来添加或搜索向量。
            vector_store=ChromaVectorStore(chroma_collection=chroma_collection) 
        )
        # 创建 StorageContext 对象的作用是为 LlamaIndex 提供一个统一的数据存储管理上下文,
        # 用于协调向量存储(vector store)、文档存储(docstore)和索引之间的数据流动与持久化操作。
        
        # 将文本节点存入文档存储(元数据+文本内容)
        storage_context.docstore.add_documents(nodes)  
        
        # 创建索引,将节点向量化并创建可搜索的索引结构
        index = VectorStoreIndex(
            nodes,
            storage_context=storage_context,
            show_progress=True
        )
        # 在创建 VectorStoreIndex 对象时需要传入该 StorageContext 对象,以确保索引知道如何访问向量和文档。

        # 双重持久化保障,将存储上下文和索引对象保存到 Config.PERSIST_DIR 目录(双重保证)
        storage_context.persist(persist_dir=Config.PERSIST_DIR)
        index.storage_context.persist(persist_dir=Config.PERSIST_DIR) 
    else:
        print("加载已有索引...")

        # 加载存储上下文,从持久化目录加载已有状态
        storage_context = StorageContext.from_defaults(
            persist_dir=Config.PERSIST_DIR,
            vector_store=ChromaVectorStore(chroma_collection=chroma_collection)
        )

        # 构建索引对象,基于已有向量存储重建内存索引结构
        index = VectorStoreIndex.from_vector_store(
            storage_context.vector_store,
            storage_context=storage_context,
            embed_model=Settings.embed_model
        )

    # 安全验证
    print("\n存储验证结果:")
    doc_count = len(storage_context.docstore.docs)
    print(f"DocStore记录数:{doc_count}")
    
    if doc_count > 0:
        sample_key = next(iter(storage_context.docstore.docs.keys()))
        print(f"示例节点ID:{sample_key}")
    else:
        print("警告:文档存储为空,请检查节点添加逻辑!")
    
    
    return index


def init_embedding_model():
    # 初始化Embedding模型
    embed_model = HuggingFaceEmbedding(
        model_name=Config.EMBED_MODEL_PATH,
        # 在一些比较老版本的 llama-index-embeddings-huggingface 中,需要加下面的参数,当前版本(0.5.4)不需要
        # encode_kwargs = {
        #     'normalize_embeddings': True,
        #     'device': 'cuda' if hasattr(Settings, 'device') else 'cpu'
        # }
    )
    Settings.embed_model = embed_model
 
    # 验证模型
    test_embedding = embed_model.get_text_embedding("测试文本")
    print(f"Embedding维度验证:{len(test_embedding)}")
    
    return embed_model


def init_llm_model():
    # 初始化大语言模型
    llm = HuggingFaceLLM(
        model_name=Config.LLM_MODEL_PATH,
        tokenizer_name=Config.LLM_MODEL_PATH,
        model_kwargs={"trust_remote_code": True},
        tokenizer_kwargs={"trust_remote_code": True},
        generate_kwargs={"temperature": 0.3}        # 要让回答偏向于知识库,要让模型减少随机性,因此把temperature设置低一些,不要高于0.3
    )

    Settings.llm = llm
    
    return llm


def main():
    embed_model, llm = init_embedding_model(), init_llm_model()

    # 仅当需要更新数据时执行
    if not Path(Config.VECTOR_DB_DIR).exists():
        print("\n初始化数据...")
        nodes = load_and_create_nodes(Config.DATA_DIR)
    else:
        nodes = None  # 已有数据时不加载

    # 初始化向量存储
    print("\n初始化向量存储...")
    start_time = time.time()
    index = init_vector_store(nodes)
    print(f"索引加载耗时:{time.time()-start_time:.2f}s")

    # 创建查询引擎
    query_engine = index.as_query_engine(
        similarity_top_k=Config.TOP_K,
        # text_qa_template=response_template,
        verbose=True
    )

    # 示例查询
    while True:
        question = input("\n请输入劳动法相关问题(输入q退出): ")
        if question.lower() == 'q':
            break
        
        # 执行查询
        response = query_engine.query(question)
        
        # 显示结果
        print(f"\n智能助手回答:\n{response.response}")
        print("\n支持依据:")
        for idx, node in enumerate(response.source_nodes, 1):
            meta = node.metadata
            print(f"\n[{idx}] {meta['full_title']}")
            print(f"  来源文件:{meta['source_file']}")
            print(f"  法律名称:{meta['law_name']}")
            print(f"  条款内容:{node.text[:100]}...")
            print(f"  相关度得分:{node.score:.4f}")


if __name__ == "__main__":
    main()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值