大模型技术日新月异的今天,文本分类作为自然语言处理领域的基础且关键的应用,正迎来新的变革。从过滤垃圾邮件、识别产品类别,到理解聊天机器人中的用户意图,文本分类无处不在。传统方法依赖于大量标注数据训练定制化的机器学习模型,而大模型的出现,使得零样本或少样本分类成为可能,极大地缩短了服务部署时间。然而,这种方法的准确性往往低于定制模型,且高度依赖于精巧的提示工程。本文将深入探讨一种名为检索增强分类 (Retrieval Augmented Classification, RAC) 的方法,旨在弥合定制模型与通用大模型之间的差距,同时减少提示工程的工作量,提升文本分类的效率与精度。

大模型 vs. 定制机器学习模型:文本分类的优劣对比

在深入探讨 检索增强分类 之前,我们首先对比一下大模型与定制机器学习模型在文本分类任务中的优缺点。

大模型作为通用分类器:

  • 优点:
    • 强大的泛化能力: 基于海量预训练语料库和强大的推理能力,大模型 具备卓越的泛化能力。
    • 多任务处理: 单个通用 大模型 可以处理多个分类任务,无需为每个任务单独部署模型,大大简化了部署流程。
    • 持续提升: 随着 大模型 技术的不断进步,只需采用更新、更强大的模型,即可轻松提升分类准确性。
    • 易于部署: 众多 大模型 以托管服务的形式提供,降低了部署的技术门槛和所需精力。
    • 小样本优势: 在标注数据有限或获取成本高昂的情况下,大模型 的性能通常优于定制机器学习模型。
    • 多语言支持: 大模型 能够处理多种语言的文本分类任务。
    • 按需付费: 对于预测量较低或不稳定的场景,按 token 计费的模式可以降低成本。
    • 动态类别定义: 无需重新训练模型,只需修改提示,即可动态更改类别定义。
  • 缺点:
    • 幻觉问题: 大模型 容易产生幻觉,输出不准确或不存在的信息。
    • 速度较慢: 相比于小型定制模型,大模型 的处理速度通常较慢。
    • 提示工程: 需要投入精力进行提示工程,以优化分类效果。
    • 吞吐量限制: 高吞吐量的应用在使用 大模型 即服务时,可能很快遇到配额限制。
    • 类别数量限制: 当潜在类别数量过多时,由于上下文长度限制,大模型 的效果会下降。定义所有类别会占用大量可用上下文。
    • 大数据劣势: 在拥有大量标注数据的情况下,大模型 的准确性通常低于定制模型。

定制机器学习模型:

  • 优点:
    • 高效快速: 定制模型通常更加高效快速。
    • 灵活可控: 在架构选择、训练和部署方法方面具有更高的灵活性。
    • 可解释性: 能够为模型添加可解释性和不确定性估计功能。
    • 大数据优势: 在拥有大量标注数据的情况下,准确性更高。
    • 完全掌控: 用户可以完全掌控模型和部署基础设施。
  • 缺点:
    • 频繁重训练: 需要频繁重训练,以适应新的数据或分布变化。
    • 数据依赖性: 可能需要大量的标注数据。
    • 泛化能力有限: 泛化能力有限,对领域外的词汇或表达方式较为敏感。
    • MLOps 需求: 需要具备 MLOps 知识才能进行部署。

检索增强分类 (RAC):融合 RAG 与少样本学习

为了兼顾 大模型 的优点,同时缓解其缺点,我们借鉴 RAG (Retrieval Augmented Generation, 检索增强生成) 的思想,并结合 少样本提示 (Few-shot prompting) 技术,提出了 检索增强分类 (RAC) 方法。

RAG: 在提出问题之前,利用外部知识增强 大模型 的上下文,降低幻觉发生的可能性,并提高回答质量。

少样本提示: 在分类任务中,向 大模型 展示输入和预期输出的示例,帮助其理解任务。

RAC 的核心思想是:动态地从知识库中检索与待分类文本最相似的示例,并将这些示例作为少样本提示注入到 大模型 的输入上下文中。同时,利用检索到的近邻示例限定潜在类别的范围。这样既减少了 大模型 产生幻觉的可能性,又能在处理具有大量潜在类别的分类问题时,有效节省上下文长度。

举例来说,假设我们需要将新闻文章分类到不同的主题类别中(例如:政治、体育、科技、财经等)。传统的少样本提示可能需要在 prompt 中列出所有可能的类别,这会占用大量的 token。而 RAC 方法则会先检索与待分类文章相似的几篇新闻文章,然后只将这些文章所属的类别作为 prompt 中的候选类别。例如,如果检索到的文章都属于“政治”和“财经”类别,那么 RAC 就会只让 大模型 从这两个类别中进行选择,从而有效地缩小了搜索空间,提高了分类效率和准确性。

RAC 的具体实现步骤

1. 构建标签数据知识库:

首先,我们需要构建一个包含标注文本/类别对的知识库。该知识库将作为 大模型 的外部知识来源。我们可以使用 ChromaDB 这种向量数据库来存储和索引这些文本。

from typing import List
from uuid import uuid4
from langchain_core.documents import Document
from chromadb import PersistentClient
from langchain_chroma import Chroma
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
import torch
from tqdm import tqdm
from chromadb.config import Settings
#from retrieval_augmented_classification.logger import logger # 假设你有一个 logger

class DatasetVectorStore:
    """ChromaDB vector store for PublicationModel objects with SentenceTransformers embeddings."""
    def __init__(
        self,
        db_name: str = "retrieval_augmented_classification",  # Using db_name as collection name in Chroma
        collection_name: str = "classification_dataset",
        persist_directory: str = "chroma_db",  # Directory to persist ChromaDB
    ):
        self.db_name = db_name
        self.collection_name = collection_name
        self.persist_directory = persist_directory
        # Determine if CUDA is available
        device = "cuda" if torch.cuda.is_available() else "cpu"
        #logger.info(f"Using device: {device}")
        print(f"Using device: {device}") # use print instead of logger if you don't have one
        self.embeddings = HuggingFaceBgeEmbeddings(
            model_name="BAAI/bge-small-en-v1.5",
            model_kwargs={"device": device},
            encode_kwargs={
                "device": device,
                "batch_size": 100,
            },  # Adjust batch_size as needed
        )
        # Initialize Chroma vector store
        self.client = PersistentClient(
            path=self.persist_directory, settings=Settings(anonymized_telemetry=False)
        )
        self.vector_store = Chroma(
            client=self.client,
            collection_name=self.collection_name,
            embedding_function=self.embeddings,
            persist_directory=self.persist_directory,
        )
    def add_documents(self, documents: List) -> None:
        """
        Add multiple documents to the vector store.
        Args:
            documents: List of dictionaries containing document data.  Each dict needs a "text" key.
        """
        local_documents = []
        ids = []
        for doc_data in documents:
            if not doc_data.get("id"):
                doc_data["id"] = str(uuid4())
            local_documents.append(
                Document(
                    page_content=doc_data["text"],
                    metadata={k: v for k, v in doc_data.items() if k != "text"},
                )
            )
            ids.append(doc_data["id"])
        batch_size = 100  # Adjust batch size as needed
        for i in tqdm(range(0, len(documents), batch_size)):
            batch_docs = local_documents[i : i + batch_size]
            batch_ids = ids[i : i + batch_size]
            # Chroma's add_documents doesn't directly support pre-defined IDs. Upsert instead.
            self._upsert_batch(batch_docs, batch_ids)
    def _upsert_batch(self, batch_docs: List[Document], batch_ids: List[str]):
        """Upsert a batch of documents into Chroma.  If the ID exists, it updates; otherwise, it creates."""
        texts = [doc.page_content for doc in batch_docs]
        metadatas = [doc.metadata for doc in batch_ids]
        self.vector_store.add_texts(texts=texts, metadatas=metadatas, ids=batch_ids)

    def search(self, query: str, k: int = 5) -> List[Document]:
        """Search documents by semantic similarity."""
        results = self.vector_store.similarity_search(query, k=k)
        return results

上述代码使用 HuggingFace 的 BGE-small-en-v1.5 模型来嵌入文本,然后将嵌入向量存储在 ChromaDB 中。你可以根据实际需求选择其他嵌入模型,例如 OpenAI、Gemini 或 Nebius 提供的模型。

2. 检索 K 近邻:

当收到一条新的待分类文本时,我们首先使用相同的嵌入模型将其转换为向量,然后在知识库中搜索与其语义最相似的 K 个近邻示例。

def search(self, query: str, k: int = 5) -> List[Document]:
    """Search documents by semantic similarity."""
    results = self.vector_store.similarity_search(query, k=k)
    return results

3. 构建检索增强分类器:

接下来,我们构建 检索增强分类器 (RAC),它将检索到的 K 近邻示例作为少样本提示,并利用这些示例的类别信息来约束 大模型 的预测范围。

from typing import Optional
from pydantic import BaseModel, Field
from collections import Counter

#from retrieval_augmented_classification.vector_store import DatasetVectorStore # 假设你已经实现了 DatasetVectorStore
#from tenacity import retry, stop_after_attempt, wait_exponential # 假设你已经安装了 tenacity
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

class PredictedCategories(BaseModel):
    """
    Pydantic model for the predicted categories from the LLM.
    """
    reasoning: str = Field(description="Explain your reasoning")
    predicted_category: str = Field(description="Category")

class RAC:
    """
    A hybrid classifier combining K-Nearest Neighbors retrieval with an LLM for multi-class prediction.
    Finds top K neighbors, uses top few-shot for context, and uses all neighbor categories
    as potential prediction candidates for the LLM.
    """
    def __init__(
        self,
        vector_store: DatasetVectorStore,
        llm_client,
        knn_k_search: int = 30,
        knn_k_few_shot: int = 5,
    ):
        """
        Initializes the classifier.
        Args:
            vector_store: An instance of DatasetVectorStore with a search method.
            llm_client: An instance of the LLM client capable of structured output.
            knn_k_search: The number of nearest neighbors to retrieve from the vector store.
            knn_k_few_shot: The number of top neighbors to use as few-shot examples for the LLM.
                           Must be less than or equal to knn_k_search.
        """
        self.vector_store = vector_store
        self.llm_client = llm_client
        self.knn_k_search = knn_k_search
        self.knn_k_few_shot = knn_k_few_shot

    #@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=5)) # uncomment if you installed tenacity
    def predict(self, document_text: str) -> Optional[str]:
        """
        Predicts the relevant categories for a given document text using KNN retrieval and an LLM.
        Args:
            document_text: The text content of the document to classify.
        Returns:
            The predicted category
        """
        neighbors = self.vector_store.search(document_text, k=self.knn_k_search)
        all_neighbor_categories = set()
        valid_neighbors = []  # Store neighbors that have metadata and categories

        for neighbor in neighbors:
            if (
                hasattr(neighbor, "metadata")
                and isinstance(neighbor.metadata, dict)
                and "category" in neighbor.metadata
            ):
                all_neighbor_categories.add(neighbor.metadata["category"])
                valid_neighbors.append(neighbor)
            else:
                pass  # Suppress warnings for cleaner demo output

        if not valid_neighbors:
            return None

        category_counts = Counter(all_neighbor_categories)
        ranked_categories = [
            category for category, count in category_counts.most_common()
        ]

        if not ranked_categories:
            return None

        few_shot_neighbors = valid_neighbors[: self.knn_k_few_shot]

        messages = []

        system_prompt = f"""You are an expert multi-class classifier. Your task is to analyze the provided document text and assign the most relevant category from the list of allowed categories.You MUST only return categories that are present in the following list: {ranked_categories}.If none of the allowed categories are relevant, return an empty list.Return the categories by likelihood (more confident to least confident).Output your prediction as a JSON object matching the Pydantic schema: {PredictedCategories.model_json_schema()}."""

        messages.append(SystemMessage(content=system_prompt))

        for i, neighbor in enumerate(few_shot_neighbors):
            messages.append(
                HumanMessage(content=f"Document: {neighbor.page_content}")
            )
            expected_output_json = PredictedCategories(
                reasoning="Your reasoning here",
                predicted_category=neighbor.metadata["category"]
            ).model_dump_json()

            # Simulate the structure often used with tool calling/structured output
            ai_message_with_tool = AIMessage(
                content=expected_output_json,
            )
            messages.append(ai_message_with_tool)

        # Final user message: The document text to classify
        messages.append(HumanMessage(content=f"Document: {document_text}"))

        # Configure the client for structured output with the Pydantic schema
        structured_client = self.llm_client.with_structured_output(PredictedCategories)

        llm_response: PredictedCategories = structured_client.invoke(messages)
        predicted_category = llm_response.predicted_category

        return predicted_category if predicted_category in ranked_categories else None

在这个代码中,我们首先定义了 大模型 输出的结构,然后使用检索到的 K 近邻示例构建一个消息历史,模拟 大模型 给出了正确分类的场景。最后,我们将待分类文本作为最终的用户消息,并调用 大模型 进行预测。系统提示 (System Prompt) 中明确告知 大模型 只能从 K 近邻示例的类别中选择,有效地约束了预测范围。

4. 示例预测:

# 假设你已经初始化了 store 和 llm_client
#_rac = RAC(
#    vector_store=store,
#    llm_client=llm_client,
#    knn_k_search=50,
#    knn_k_few_shot=10,
#)

#print(
#    f"Initialized rac with knn_k_search={_rac.knn_k_search}, knn_k_few_shot={_rac.knn_k_few_shot}."
#)

#text = """Ivanoe Bonomi [iˈvaːnoe boˈnɔːmi] (18 October 1873 – 20 April 1951) was an Italian politician and statesman before and after World War II. Bonomi was born in Mantua. He was elected to the Italian Chamber of Deputies in ..."""
#category = _rac.predict(text)
#print(text)
#print(category)

#text = """Michel Rocard, né le 23 août 1930 à Courbevoie et mort le 2 juillet 2016 à Paris, est un haut fonctionnaire et ... """
#category = _rac.predict(text)
#print(text)
#print(category)

正如原文所展示的,即使输入文本是法语, RAC 依然能够给出正确的分类结果(“PrimeMinister”),这充分体现了该方法强大的泛化能力。

实验评估

为了验证 检索增强分类 的有效性,作者在 DBPedia Classes 数据集的 l3 类别上进行了实验评估。该数据集包含超过 200 个类别和 240000 个训练样本。实验结果表明,相比于简单的 KNN 分类器, RAC 的准确率提高了 9%。

具体而言,KNN分类器通过计算文本相似度,然后采用“多数投票”的方式来决定类别。RAC则在此基础上,通过 大模型 的推理能力,对KNN检索到的结果进行加权和修正,避免了KNN的绝对性,从而提高了准确率。

值得注意的是,在 Kaggle Notebooks 上,使用定制机器学习模型在该数据集的 l3 级别上获得的最佳准确率约为 94%。这意味着 RAC 在准确率方面仍然有提升空间,但它在灵活性和易用性方面具有显著优势。

结论与展望

检索增强分类 是一种利用“检索”来提升 大模型 文本分类 能力的有效方法。相比于传统的机器学习 文本分类 器, RAC 具有以下优势:

  • 无需重新训练: 可以动态更改训练数据集,无需重新训练模型,大大节省了时间和资源。
  • 强大的泛化能力: 借助 大模型 的推理能力和通用知识,具有更强的泛化能力,能够处理多种语言和领域的文本。
  • 易于部署: 使用托管 大模型 服务时,部署非常简单。
  • 多任务处理: 可以使用单个 大模型 处理多个分类任务。

当然, RAC 也存在一些局限性,例如:

  • 延迟较高: 相比于定制模型, RAC 的延迟较高,吞吐量较低。
  • 厂商锁定: 存在 大模型 供应商锁定的风险。

因此, RAC 并非解决所有文本分类问题的万能钥匙。但在以下场景中,它可能非常有用:

  • 需要频繁更新训练数据,但又不想每次都重新训练模型。
  • 标注数据有限,难以训练有效的定制模型。
  • 需要在短时间内搭建一个可用的文本分类服务。

未来,我们可以进一步探索以下方向来提升 RAC 的性能:

  • 优化检索策略: 采用更有效的检索算法,提高检索的准确率和效率。
  • 改进提示工程: 设计更有效的提示,引导 大模型 更好地完成分类任务。
  • 知识库构建: 研究如何更有效地构建和维护知识库,使其能够更好地支持 大模型 的推理。
  • 模型蒸馏:大模型 的知识蒸馏到小型模型中,从而降低延迟,提高吞吐量。

总之,检索增强分类大模型 时代的 文本分类 提供了一种新的范式。随着 大模型 技术的不断发展, RAC 有望在更多领域发挥重要作用。