上一讲里我们看到大模型的确有效。在进行情感分析的时候,我们通过 OpenAI 的 API 拿到的 Embedding,比 T5-base 这样单机可以运行的小模型,效果还是好很多的。
不过,我们之前选用的问题的确有点太简单了。我们把 5 个不同的分数分成了正面、负面和中性,还去掉了相对难以判断的“中性”评价,这样我们判断的准确率高的确是比较好实现的。但如果我们想要准确地预测出具体的分数呢?
利用 Embedding,训练机器学习模型
最简单的办法就是利用我们拿到的文本 Embedding 的向量。这一次,我们不直接用向量之间的距离,而是使用传统的机器学习的方法来进行分类。毕竟,如果只是用向量之间的距离作为衡量标准,就没办法最大化地利用已经标注好的分数信息了。
事实上,OpenAI 在自己的官方教程里也直接给出了这样一个例子。在这里也放上了相应的 GitHub 的代码链接,可以去看一下。不过,为了避免 OpenAI 王婆卖瓜自卖自夸,我们也希望能和其他人用传统的机器学习方式得到的结果做个比较。
因此重新找了一个中文的数据集来试一试。这个数据集是在中文互联网上比较容易找到的一份今日头条的新闻标题和新闻关键词,在 GitHub 上可以直接找到数据,把链接也放在这里。用这个数据集的好处是,有人同步放出了预测的实验效果。我们可以拿自己训练的结果和他做个对比。
数据处理,小坑也不少
在训练模型之前,我们要先获取每一个新闻标题的 Embedding。我们通过 Pandas 这个 Python 数据处理库,把对应的文本加载到内存里。接着去调用之前我们使用过的 OpenAI 的 Embedding 接口,然后把返回结果一并存下来就好了。这个听起来非常简单直接,我也把对应的代码先放在下面,不过你先别着急运行。
注:因为后面的代码可能会耗费比较多的 Token 数量,如果你使用的是免费的 5 美元额度的话,可以直接去拿放在 Github 里的数据文件,用已经处理好的数据。
import pandas as pd
import tiktoken
import openai
import os
from openai.embeddings_utils import get_embedding, get_embeddings
openai.api_key = os.environ.get("OPENAI_API_KEY")
# embedding model parameters
embedding_model = "text-embedding-ada-002"
embedding_encoding = "cl100k_base" # this the encoding for text-embedding-ada-002
max_tokens = 8000 # the maximum for text-embedding-ada-002 is 8191
# import data/toutiao_cat_data.txt as a pandas dataframe
df = pd.read_csv('data/toutiao_cat_data.txt', sep='_!_', names=['id', 'code', 'category', 'title', 'keywords'])
df = df.fillna("")
df["combined"] = (
"标题: " + df.title.str.strip() + "; 关键字: " + df.keywords.str.strip()
)
print("Lines of text before filtering: ", len(df))
encoding = tiktoken.get_encoding(embedding_encoding)
# omit reviews that are too long to embed
df["n_tokens"] = df.combined.apply(lambda x: len(encoding.encode(x)))
df = df[df.n_tokens <= max_tokens]
print("Lines of text after filtering: ", len(df))
注:这个是加载数据并做一些简单预处理的代码,可以直接运行。
此外,我们也需要和前几讲的代码一样,定义一个 get_embedding 的函数,方便后面调用。这个函数原本在早期的 openai 库里是直接提供的,但是随着 API 的不断更新,这个库已经被移除了,不过代码非常简单,我们自己来定义一下就好。
from openai import OpenAI
import os
client = OpenAI(api_key=os.environ['OPENAI_API_KEY'])
EMBEDDING_MODEL = "text-embedding-ada-002"
def get_embeddin