BERT模型的应用与微调实战
立即解锁
发布时间: 2025-08-31 02:05:57 阅读量: 9 订阅数: 16 AIGC 

# BERT模型的应用与微调实战
## 1. 生成文本嵌入
在处理特定的Yelp评论子集时,我们可以使用`sentence-transformers`库中的“all-mpnet-base-v2”模型来生成文本嵌入。该模型在处理句子和较长文本的语义搜索任务时表现出色。以下是具体的代码实现:
```python
dataset = dataset.map(embed, batch_size=batch_size, batched=True)
dataset.set_format(type='numpy', columns=['embedding'], output_all_columns=True)
```
通过以上代码,我们完成了对数据集的嵌入计算,并设置了输出格式。
## 2. 主题建模
### 2.1 构建主题
在计算完文本嵌入后,我们可以使用`BERTopic`库进行主题建模。以下是具体步骤:
```python
from bertopic import BERTopic
topic_model = BERTopic(n_gram_range=(1, 3))
topics, probs = topic_model.fit_transform(
dataset["text"],
np.array(dataset["embedding"])
)
print(f"Number of topics: {len(topic_model.get_topics())}")
```
在上述代码中,我们首先导入了`BERTopic`库,然后创建了一个`BERTopic`模型实例,并设置了`n_gram_range`参数。接着,我们使用`fit_transform`方法对数据集进行主题建模,并打印出主题的数量。
### 2.2 主题大小分布
为了了解每个主题下的评论数量,我们可以查看主题的大小分布。以下是具体代码:
```python
topic_sizes = topic_model.get_topic_freq()
topic_sizes
```
运行上述代码后,会输出一个包含主题大小信息的表格,表格中显示了每个主题的大小(即包含该主题的评论数量)。需要注意的是,主题ID为 -1 的主题对应着`HDBSCAN`算法输出的未分配簇,该簇包含了所有无法分配到其他簇的内容。一般情况下,这个簇可以忽略,但如果其规模过大,则可能意味着我们选择的参数不适合当前数据。我们可以查看未分配簇的内容,以验证其中是否包含不相关的词汇。
### 2.3 主题可视化
`BERTopic`库提供了多种可视化主题分布的方法,以下是两种常见的可视化方式:
#### 2.3.1 可视化前10大主题
```python
topic_model.visualize_barchart(top_n_topics=10, n_words=5, width=1000, height=800)
```
上述代码将绘制一个柱状图,展示前10大主题,每个主题显示前5个相关词汇。
#### 2.3.2 可视化前20大主题的余弦相似度
```python
topic_model.visualize_heatmap(top_n_topics=20, n_clusters=5)
```
此代码将生成一个热力图,展示前20大主题嵌入向量之间的余弦相似度,反映了主题之间的重叠程度。
### 2.4 主题内容查看
为了查看特定主题的大小和代表性词汇,我们可以定义一个辅助函数:
```python
def dump_topic_and_docs(text, topic_id):
print(f"{text} size: {topic_sizes['Count'][topic_id + 1]}\n")
n = len(topic_sizes) - 1
if topic_id != -1:
reviews = topic_model.get_representative_docs(topic_id)
for review in reviews:
print(review, "\n")
return topic_model.get_topic(topic_id)[:10]
```
以下是查看不同主题内容的示例代码:
```python
# 查看未分配簇
dump_topic_and_docs("Unassigned cluster", -1)
# 查看最大主题
dump_topic_and_docs("Largest topic", 0)
# 查看最小主题
dump_topic_and_docs("Smallest topic", n - 1)
# 查看中位数主题
dump_topic_and_docs("Median topic", n // 2)
```
通过上述代码,我们可以查看未分配簇、最大主题、最小主题和中位数主题的大小、代表性评论以及前10个相关词汇。
## 3. 案例研究:微调BERT模型进行情感分类
### 3.1 目标
本案例研究旨在提供一个逐步演示,展示如何微调标准的BERT模型以进行句子分类任务,这里我们选择情感分类作为示例任务。
### 3.2 数据、工具和库
我们选择了Google Play应用评论数据集,该数据集包含15,746个样本,分为负面、中性和正面三个类别。我们使用`Huggingface`的`transformer`库进行微调任务,并使用标准的Python数据科学库进行数据处理和可视化。
### 3.3 实验、结果和分析
#### 3.3.1 加载预训练模型和分词器
```python
PRE_TRAINED_MODEL_NAME = 'bert-base-cased'
tokenizer = BertTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)
```
在上述代码中,我们指定了预训练模型的名称,并使用`BertTokenizer`从`transformers`库中加载预训练的分词器。
#### 3.3.2 分词和编码示例
```python
sample_txt = 'When was I last outside? I am stuck at home for 2 weeks.'
tokens = tokenizer.tokenize(sample_txt)
token_ids = tokenizer.convert_tokens_to_ids(tokens)
print(f' Sentence: {sample_txt}')
print(f' Tokens: {tokens}')
print(f'Token IDs: {token_ids}')
encoding = tokenizer.encode_plus(
sample_txt,
max_length=32,
truncation=True,
add_special_tokens=True, # Add '[CLS]' and '[SEP]'
return_token_type_ids=False,
pad_to_max_length=True,
return_attention_mask=True,
return_tensors='pt' # Return PyTorch tensors
)
print(f'Encoding keys: {encoding.keys()}')
print(len(encoding['input_ids'][0]))
print(encoding['input_ids'][0])
print(len(encoding['attention_mask'][0]))
print(encoding['attention_mask'])
print(tokenizer.convert_ids_to_tokens(encoding['input_ids'][0]))
```
上述代码展示了如何使用分词器对样本文本进行分词和编码,并打印出分词结果、编码后的ID、编码字典的键以及编码后的输入ID和注意力掩码。
#### 3.3.3 创建数据集和数据加载器
```python
MAX_LEN = 160
class GPReviewDataset(Dataset):
def __init__(self, reviews, targets, tokenizer, max_len):
self.reviews = reviews
self.targets = targets
self.tokenizer = tokenizer
self.max_len = max_len
def __len__(self):
return len(self.reviews)
def __getitem__(self, item):
review = str(self.reviews[item])
target = self.targets[item]
encoding = self.tokenizer.encode_plus(
review,
add_special_tokens=True,
max_length=self.max_len,
return_token_type_ids=False,
return_attention_mask=True,
truncation=True,
pad_to_max_length=True,
return_tensors='pt'
)
return {
'review_text': review,
'input_ids': encoding['input_ids'].flatten(),
'attention_mask': encoding['attention_mask'].flatten(),
'targets': torch.tensor(target, dtype=torch.long)
```
0
0
复制全文
相关推荐








