2.9 实战演练之预训练模型

目录

1 什么是预训练

2 三种模型架构

2.1 掩码语言模型(主要应用于NLU任务)

工作机制:

典型应用:

优点:

缺点:

2.2 因果语言模型(主要应用于LMM)

工作机制:

典型应用:

优点:

缺点:

2.3 序列到序列模型(以Transformer架构为例)

工作机制:

典型应用:

优点:

缺点:

2.4 拓展:前缀语言模型(以GLM为例)

工作机制:

与因果语言模型的区别:

典型应用:

优点:

例子:

总结:

3 代码实战演练——掩码语言模型

3.1 导包

3.2 加载数据集

3.3 数据集处理

3.4 创建模型

3.5 配置训练参数

3.6 创建训练器

3.7 模型训练

3.8 模型推理

4 代码实战演练——因果语言模型

4.1 导包

4.2 加载数据集

4.3 数据集处理

4.4 创建模型

4.5 配置训练参数

4.6 创建训练器

4.7 模型训练

4.8 模型推理


1 什么是预训练

预训练是机器学习中的一种技术,通常用于自然语言处理(NLP)、计算机视觉等领域。它指的是在大型通用数据集上训练一个模型,使其学习到广泛的知识和模式,然后再通过**微调(fine-tuning)**来适应具体任务。预训练的目标是让模型在开始特定任务的学习之前,已经具备对数据的基本理解,从而减少对大量任务特定标注数据的依赖。

2 三种模型架构

掩码语言模型(Masked Language Model)、因果语言模型(Causal Language Model)、编码解码语言模型(Encoder-Decoder Model)是三种在自然语言处理中常用的模型架构,它们有不同的训练方式和应用场景。前两种最为常见,以下是对它们的详细介绍:

2.1 掩码语言模型(主要应用于NLU任务)

掩码语言模型是一种双向的语言模型,常用于模型在大量文本上进行预训练,使得它可以理解上下文中的词汇间关系。它的主要特点是训练时会将输入文本中的某些词进行随机掩码,让模型根据上下文预测这些被掩盖的词。

工作机制:

  • 输入一个完整的句子,然后随机选择某些词进行掩码(例如,用 [MASK] 代替)。
  • 模型根据上下文预测这些掩码词。
  • 这种双向的设计允许模型同时利用句子左侧和右侧的上下文信息。

典型应用:

  • BERT(Bidirectional Encoder Representations from Transformers)是最具代表性的掩码语言模型。在 BERT 训练时,15% 的输入词被随机掩码,模型根据其他未掩码的词预测这些词。

优点:

  • 上下文丰富:由于模型能从句子的两边学习信息,模型可以更好地理解句子中的语义和结构。
  • 广泛应用于多种下游任务:如文本分类、命名实体识别、问答系统等。

缺点:

  • 仅适合预训练:因为训练时会对词进行掩码,在实际生成任务中无法直接使用。

2.2 因果语言模型(主要应用于LMM)

现在我要将1~14输入给模型,我在输入1的时候,其标签就是2;输入2的时候,其标签就是3,以此类推,直到输入14,发现下一个是eos,结束。

因果语言模型是一种单向语言模型,按顺序生成文本,即每个词只能根据它前面的词来预测下一个词。这种模型通常用于文本生成任务,因为它能够一步一步地预测句子的下一个词。

工作机制:

  • 只考虑输入序列的左侧上下文(即已经生成的部分),不使用右侧未来的信息。
  • 在训练时,模型通过先前生成的词预测下一个词。

典型应用:

  • GPT(Generative Pre-trained Transformer)系列模型是因果语言模型的代表。GPT 只考虑输入的前面部分,在生成任务(如对话、文本生成)中逐步生成下一个词。

优点:

  • 适合生成任务:可以逐步生成新内容,特别适用于对话系统、机器翻译、自动写作等生成式任务。
  • 自然生成过程:生成时没有掩码,只是基于已有的上下文生成新内容。

缺点:

  • 缺乏双向上下文信息:因为只考虑左侧上下文,模型可能缺乏对全局信息的理解能力。

2.3 序列到序列模型(以Transformer架构为例)

编码解码模型是一种由两个部分组成的架构:一个编码器(Encoder)负责理解输入,一个解码器(Decoder)负责生成输出。它常用于需要输入和输出不同类型信息的任务,如机器翻译、摘要生成等。

工作机制:

  • 编码器接收输入序列,将其转换为一个上下文向量或隐藏状态序列,提取出输入中的语义信息。
  • 解码器基于编码器生成的隐藏状态,逐步生成目标序列的每一个词。解码器可以是自回归的,即每一步的生成依赖于前一步的结果。

典型应用:

  • Transformer(Vaswani et al., 2017)是编码解码模型的经典代表。它在机器翻译、文本摘要、对话生成等任务中取得了显著成功。
  • T5(Text-To-Text Transfer Transformer)是一个编码解码架构的预训练模型,它将所有任务(如翻译、问答、分类)都转换为文本到文本的格式。

优点:

  • 输入输出结构灵活:可以处理不同长度的输入和输出序列,非常适合翻译、问答、对话等任务。
  • 强大的生成能力:编码器和解码器结合的设计让模型既能理解复杂的输入,又能生成符合上下文的输出。

缺点:

  • 训练复杂度高:由于需要同时训练编码器和解码器,计算资源的需求较大。
  • 延迟:解码过程通常是逐步的,生成速度可能较慢。

2.4 拓展:前缀语言模型(以GLM为例)

文本生成系列之前缀语言模型 - 知乎 (zhihu.com)

前缀语言模型(Prefix Language Model) 是一种生成式模型,常用于文本生成任务。它类似于因果语言模型(Causal Language Model),但区别在于前缀语言模型在生成过程中使用了一个“前缀”作为条件,生成后续的文本。

具体来说,前缀语言模型先给定一个部分已知的文本(即“前缀”),并基于这个前缀生成接下来的文本。这种模型可以被认为是从已知前缀开始进行预测和生成。

工作机制:

  • 前缀语言模型从一个输入序列(前缀)开始,预测下一个词,然后逐步生成接下来的词。
  • 它在生成时将前缀作为输入序列的条件,通过这种方式,模型能够生成与前缀相匹配的、上下文一致的文本。

与因果语言模型的区别:

  • 前缀语言模型通常是给定了一个明确的“起始前缀”,如问题、短语或部分句子,然后模型生成后续文本。可以将其看作是基于已有条件生成的模型。
  • 因果语言模型则是每个词仅依赖于前面出现的词,逐步生成下一个词,但在开始时并不一定有明确的前缀。

典型应用:

前缀语言模型广泛应用于需要条件生成的任务,比如:

  • 对话系统:给定一个问题作为前缀,生成合适的回答。
  • 文本补全:输入一个句子的开头部分,生成完整的句子或段落。
  • 机器翻译:输入句子作为前缀,生成目标语言的译文。

优点:

  • 上下文依赖生成:生成的文本更自然,因为它会根据前缀的内容来生成,保证上下文连贯性。
  • 灵活性:可以在前缀的基础上生成符合语义和上下文逻辑的文本,适用于各种条件生成任务。

例子:

假设你输入一个前缀“今天的天气”,前缀语言模型可能会生成:“今天的天气非常晴朗,适合出去散步。”

总结:

前缀语言模型是一种有条件的生成模型,能够根据给定的前缀来生成符合上下文的文本。它在很多自然语言生成任务中有着广泛的应用,尤其是在对话生成、文本补全和机器翻译等领域。

3 代码实战演练——掩码语言模型

3.1 导包

from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForMaskedLM, DataCollatorForLanguageModeling, TrainingArguments, Trainer

3.2 加载数据集

3.3 数据集处理

这块我们仅仅对数据进行tokenizer处理即可,不需要设置mask,这个会在DataCollator处完成。

设置DataLoader: 

mlm=True:表示做掩码语言模型;

mlm_probability=0.15:掩码概率。

from torch.utils.data import DataLoader

dl = DataLoader(tokenized_ds, batch_size=2, collate_fn=DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15))

next(enumerate(dl))

可以观察到label中大多数位置都被设置为-100,一少部分被选中为mask的label得以保留(设置为 -100 的标签将被忽略,不会对损失计算产生影响)

labels:

input_ids:

对应到input_ids中,103的部位其实就是被选中作为mask的部位

3.4 创建模型

3.5 配置训练参数

3.6 创建训练器

3.7 模型训练

3.8 模型推理

[[{'score': 0.9948391318321228,
   'token': 1920,
   'token_str': '大',
   'sequence': "[CLS] 西 安 交 通 大 [MASK] 博 物 馆 ( xi'an jiaotong university museum ) 是 一 座 位 于 西 安 交 通 大 学 的 博 物 馆 [SEP]"},
  {'score': 0.0021944143809378147,
   'token': 2110,
   'token_str': '学',
   'sequence': "[CLS] 西 安 交 通 学 [MASK] 博 物 馆 ( xi'an jiaotong university museum ) 是 一 座 位 于 西 安 交 通 大 学 的 博 物 馆 [SEP]"},
  {'score': 0.0005357894115149975,
   'token': 2339,
   'token_str': '工',
   'sequence': "[CLS] 西 安 交 通 工 [MASK] 博 物 馆 ( xi'an jiaotong university museum ) 是 一 座 位 于 西 安 交 通 大 学 的 博 物 馆 [SEP]"},
  {'score': 0.0003660529328044504,
   'token': 7770,
   'token_str': '高',
   'sequence': "[CLS] 西 安 交 通 高 [MASK] 博 物 馆 ( xi'an jiaotong university museum ) 是 一 座 位 于 西 安 交 通 大 学 的 博 物 馆 [SEP]"},
  {'score': 0.0003039983275812119,
   'token': 4906,
   'token_str': '科',
   'sequence': "[CLS] 西 安 交 通 科 [MASK] 博 物 馆 ( xi'an jiaotong university museum ) 是 一 座 位 于 西 安 交 通 大 学 的 博 物 馆 [SEP]"}],
 [{'score': 0.9971689581871033,
   'token': 2110,
   'token_str': '学',
   'sequence': "[CLS] 西 安 交 通 [MASK] 学 博 物 馆 ( xi'an jiaotong university museum ) 是 一 座 位 于 西 安 交 通 大 学 的 博 物 馆 [SEP]"},
  {'score': 0.0011966773308813572,
   'token': 7368,
   'token_str': '院',
   'sequence': "[CLS] 西 安 交 通 [MASK] 院 博 物 馆 ( xi'an jiaotong university museum ) 是 一 座 位 于 西 安 交 通 大 学 的 博 物 馆 [SEP]"},
  {'score': 0.0008860519737936556,
   'token': 1920,
   'token_str': '大',
   'sequence': "[CLS] 西 安 交 通 [MASK] 大 博 物 馆 ( xi'an jiaotong university museum ) 是 一 座 位 于 西 安 交 通 大 学 的 博 物 馆 [SEP]"},
  {'score': 0.0001224309962708503,
   'token': 3318,
   'token_str': '术',
   'sequence': "[CLS] 西 安 交 通 [MASK] 术 博 物 馆 ( xi'an jiaotong university museum ) 是 一 座 位 于 西 安 交 通 大 学 的 博 物 馆 [SEP]"},
  {'score': 8.154550596373156e-05,
   'token': 4289,
   'token_str': '物',
   'sequence': "[CLS] 西 安 交 通 [MASK] 物 博 物 馆 ( xi'an jiaotong university museum ) 是 一 座 位 于 西 安 交 通 大 学 的 博 物 馆 [SEP]"}]]

[[{'score': 0.09348516166210175,
   'token': 2031,
   'token_str': '娱',
   'sequence': '[CLS] 下 面 是 一 则 娱 [MASK] 新 闻 。 小 编 报 道 , 近 日 , 游 戏 产 业 发 展 的 非 常 好 ! [SEP]'},
  {'score': 0.043486643582582474,
   'token': 7028,
   'token_str': '重',
   'sequence': '[CLS] 下 面 是 一 则 重 [MASK] 新 闻 。 小 编 报 道 , 近 日 , 游 戏 产 业 发 展 的 非 常 好 ! [SEP]'},
  {'score': 0.04063192754983902,
   'token': 2207,
   'token_str': '小',
   'sequence': '[CLS] 下 面 是 一 则 小 [MASK] 新 闻 。 小 编 报 道 , 近 日 , 游 戏 产 业 发 展 的 非 常 好 ! [SEP]'},
  {'score': 0.03394673392176628,
   'token': 3173,
   'token_str': '新',
   'sequence': '[CLS] 下 面 是 一 则 新 [MASK] 新 闻 。 小 编 报 道 , 近 日 , 游 戏 产 业 发 展 的 非 常 好 ! [SEP]'},
  {'score': 0.03209644556045532,
   'token': 5381,
   'token_str': '网',
   'sequence': '[CLS] 下 面 是 一 则 网 [MASK] 新 闻 。 小 编 报 道 , 近 日 , 游 戏 产 业 发 展 的 非 常 好 ! [SEP]'}],
 [{'score': 0.07695978879928589,
   'token': 4829,
   'token_str': '磅',
   'sequence': '[CLS] 下 面 是 一 则 [MASK] 磅 新 闻 。 小 编 报 道 , 近 日 , 游 戏 产 业 发 展 的 非 常 好 ! [SEP]'},
  {'score': 0.06807652860879898,
   'token': 6380,
   'token_str': '讯',
   'sequence': '[CLS] 下 面 是 一 则 [MASK] 讯 新 闻 。 小 编 报 道 , 近 日 , 游 戏 产 业 发 展 的 非 常 好 ! [SEP]'},
  {'score': 0.06187007576227188,
   'token': 7319,
   'token_str': '闻',
   'sequence': '[CLS] 下 面 是 一 则 [MASK] 闻 新 闻 。 小 编 报 道 , 近 日 , 游 戏 产 业 发 展 的 非 常 好 ! [SEP]'},
  {'score': 0.040275026112794876,
   'token': 2141,
   'token_str': '实',
   'sequence': '[CLS] 下 面 是 一 则 [MASK] 实 新 闻 。 小 编 报 道 , 近 日 , 游 戏 产 业 发 展 的 非 常 好 ! [SEP]'},
  {'score': 0.033520422875881195,
   'token': 2767,
   'token_str': '戏',
   'sequence': '[CLS] 下 面 是 一 则 [MASK] 戏 新 闻 。 小 编 报 道 , 近 日 , 游 戏 产 业 发 展 的 非 常 好 ! [SEP]'}]]

4 代码实战演练——因果语言模型

4.1 导包

from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForLanguageModeling, TrainingArguments, Trainer, BloomForCausalLM

4.2 加载数据集

4.3 数据集处理

需要在句子末尾加上一个eos_token

因为我们这块做的是因果语言模型,不再是掩码语言模型,因此要设置mlm=False 

开始的3是pad_id(这个是在开头padding的);

结束的2就是eos_token_id

(0,
 {'input_ids': tensor([[    3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3,     3,     3,
              3,     3,     3,     3,     3,     3,     3,     3, 13110, 34800,
          13535,   916, 33156,    10,   256,   576,   387,   479,   681,  5453,
          10955,   915, 24124,  5317, 13110,  6573, 20757, 13535,   355,  5358,
           1490,   583, 28056,  1407,  3855,   671,  6113,   189,  6732,  4302,
           9488,  3434,  6900,  1322,   355, 37336,  9825,  4608, 13461,  1359,
           5358,   355,  5317, 13110, 34800,  4433,  7189, 25722, 29747, 13110,
           1498, 12047,  6347, 23563,  2139,  2066,   420, 29288,    25,    15,
           7635, 39288,   355,  1484,  5835,  6272,    23,    15,  4180, 39288,
            355,  5358,  4516, 11621,    23,    15, 10641,  4887,  1712,   420,
           2450, 31163,  8085, 11621,  5358,   553,  9888,  2731, 21335,  5358,
            553,  9876, 14011,  4434,  5358,   553, 21484,  4514, 17170, 25871,
           8085,  5358,   553, 17489,  6945, 11097, 13535,   641, 33623,  1484,
           5835,  1689,  2063,  5358,   569,  5835,   671, 16225,  3422,   189,
             13, 23158, 27894, 33227,  1022, 11396,  3347,  1813,  1504,  6566,
           1813,   355,  9155,  8633,  1504,  2063,  1813,   189,    13, 23158,
            813,  7817,  5358,     2],
         [  586, 11835,   739,   355,  8417,  2300,   916, 16892,  5077,   580,
          15549,  1434,   996,   307,   387,   355,  1997,  8236,   775,  8417,
           2152,  6496, 12176,  3728,  7692,  1503,  1528,  1014, 17887, 18001,
           3250,   355,  6250,  3896,  2670, 20882, 16080, 14005,  3993,  1503,
          12295,  8930,  3250,   420,  4208,  4326,  5367,  8417,  2152,  4632,
          37591,   355,  1379,  8417,  2300, 32824,  3250,  9015, 18945, 16714,
           5908, 13551, 30330, 23756,  2152, 15976,   657,  5923,  8417,   586,
          16080,  1528,  8941,  5917, 14035,  5895, 14092,  4353, 29445,   355,
           8790, 21595,  1450, 15911, 31116, 10345, 29940, 10874,  1125,  4829,
          16080,  7093, 22939,   737,   262,   387,    72,   272,   831,   455,
             72,   915,  8860, 20725,  1934,  1084,  5478,   420,  4415,  8417,
          26263, 12726,   553, 10875, 34820,   355,  1266,  5498,   586, 14907,
          32795, 11835,   904, 19934,  1528, 19531, 20517,  1349, 19472, 28879,
            671,  8417, 26263, 17020,  5963, 22388, 11900, 12669, 13240,  1042,
           9783,   355,  7242,   714,  1806,   775,  6500,   355, 11526, 10185,
           1293,  1665, 15984,  7092,  1617,  8417,  2300,   420, 19972, 25622,
          10875, 17500, 26523,  2391,  8417,  2300,   355,  6751,  2836, 13539,
           8247,   373, 30201,  5498,   420,  8417,  2300,   586,  8523, 19358,
           1298, 12176, 30939, 10739,   964,  4318, 10875,   420, 11900, 16080,
            904,  9783,   355, 22464,  9658,   355,  8417,  2300, 13561,  2054,
           4983,  4829, 30800,  7262,   420, 11394,  8417, 35143, 11937, 15682,
           8417,  2300,  7283, 10875,   355,  1016,  4179,  5039, 14027, 26215,
          26835,   671, 15095,   189,  1165, 15095, 11900,  6184,  1125,  3244,
           3687,   622,  8785,  1121,   891, 13765,   671, 10199,   189,    13,
            210,  6940,  3728,  9552,  6082,  8417,  2300,   916,  4375,   714,
           3679,  1806, 10567,   915,   189,    13,   210, 16131, 11835,  8417,
           2300,   189,    13, 24075,  3728,  8417,  2300,   916,  4829,  3687,
           9000, 27689,  8417,  2300,  1300, 11243,  2062, 28431, 27689, 11835,
           8417,  2300, 13224,  4829,  3687, 15964,   915,   189,    13, 27340,
          11835,  8417,  2300,   189,    13, 27340,  3982,  3728,  8417,  2300,
            916,  4375,  2450,  3272,  1234, 19083,   553,  3512,  3121,  1728,
            641,  3092,  2113,  7843,   915,   189,    13,   210, 15402,  8417,
           2300,   189,    13, 10057,   108, 12693, 14624, 29379,  4719,  6533,
            739,   916, 16148, 10981, 21350,  9067,  1203,  8931,  1258, 11835,
           4719,  6533,   739, 20393,   189,    13,   210, 19546,  1517, 11835,
           8417,  2300,   189,    13,   210, 23928,   168,   117,   245,  6279,
            114,   240,   170,   100,   124,   168,   117,   228,  6279,   100,
            124,   168,   117,   228,   171,   238,   224, 41356,   236, 24175,
          11082, 10981, 21350,  9067]]), 'attention_mask': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'labels': tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100, 13110, 34800,
          13535,   916, 33156,    10,   256,   576,   387,   479,   681,  5453,
          10955,   915, 24124,  5317, 13110,  6573, 20757, 13535,   355,  5358,
           1490,   583, 28056,  1407,  3855,   671,  6113,   189,  6732,  4302,
           9488,  3434,  6900,  1322,   355, 37336,  9825,  4608, 13461,  1359,
           5358,   355,  5317, 13110, 34800,  4433,  7189, 25722, 29747, 13110,
           1498, 12047,  6347, 23563,  2139,  2066,   420, 29288,    25,    15,
           7635, 39288,   355,  1484,  5835,  6272,    23,    15,  4180, 39288,
            355,  5358,  4516, 11621,    23,    15, 10641,  4887,  1712,   420,
           2450, 31163,  8085, 11621,  5358,   553,  9888,  2731, 21335,  5358,
            553,  9876, 14011,  4434,  5358,   553, 21484,  4514, 17170, 25871,
           8085,  5358,   553, 17489,  6945, 11097, 13535,   641, 33623,  1484,
           5835,  1689,  2063,  5358,   569,  5835,   671, 16225,  3422,   189,
             13, 23158, 27894, 33227,  1022, 11396,  3347,  1813,  1504,  6566,
           1813,   355,  9155,  8633,  1504,  2063,  1813,   189,    13, 23158,
            813,  7817,  5358,     2],
         [  586, 11835,   739,   355,  8417,  2300,   916, 16892,  5077,   580,
          15549,  1434,   996,   307,   387,   355,  1997,  8236,   775,  8417,
           2152,  6496, 12176,  3728,  7692,  1503,  1528,  1014, 17887, 18001,
           3250,   355,  6250,  3896,  2670, 20882, 16080, 14005,  3993,  1503,
          12295,  8930,  3250,   420,  4208,  4326,  5367,  8417,  2152,  4632,
          37591,   355,  1379,  8417,  2300, 32824,  3250,  9015, 18945, 16714,
           5908, 13551, 30330, 23756,  2152, 15976,   657,  5923,  8417,   586,
          16080,  1528,  8941,  5917, 14035,  5895, 14092,  4353, 29445,   355,
           8790, 21595,  1450, 15911, 31116, 10345, 29940, 10874,  1125,  4829,
          16080,  7093, 22939,   737,   262,   387,    72,   272,   831,   455,
             72,   915,  8860, 20725,  1934,  1084,  5478,   420,  4415,  8417,
          26263, 12726,   553, 10875, 34820,   355,  1266,  5498,   586, 14907,
          32795, 11835,   904, 19934,  1528, 19531, 20517,  1349, 19472, 28879,
            671,  8417, 26263, 17020,  5963, 22388, 11900, 12669, 13240,  1042,
           9783,   355,  7242,   714,  1806,   775,  6500,   355, 11526, 10185,
           1293,  1665, 15984,  7092,  1617,  8417,  2300,   420, 19972, 25622,
          10875, 17500, 26523,  2391,  8417,  2300,   355,  6751,  2836, 13539,
           8247,   373, 30201,  5498,   420,  8417,  2300,   586,  8523, 19358,
           1298, 12176, 30939, 10739,   964,  4318, 10875,   420, 11900, 16080,
            904,  9783,   355, 22464,  9658,   355,  8417,  2300, 13561,  2054,
           4983,  4829, 30800,  7262,   420, 11394,  8417, 35143, 11937, 15682,
           8417,  2300,  7283, 10875,   355,  1016,  4179,  5039, 14027, 26215,
          26835,   671, 15095,   189,  1165, 15095, 11900,  6184,  1125,  3244,
           3687,   622,  8785,  1121,   891, 13765,   671, 10199,   189,    13,
            210,  6940,  3728,  9552,  6082,  8417,  2300,   916,  4375,   714,
           3679,  1806, 10567,   915,   189,    13,   210, 16131, 11835,  8417,
           2300,   189,    13, 24075,  3728,  8417,  2300,   916,  4829,  3687,
           9000, 27689,  8417,  2300,  1300, 11243,  2062, 28431, 27689, 11835,
           8417,  2300, 13224,  4829,  3687, 15964,   915,   189,    13, 27340,
          11835,  8417,  2300,   189,    13, 27340,  3982,  3728,  8417,  2300,
            916,  4375,  2450,  3272,  1234, 19083,   553,  3512,  3121,  1728,
            641,  3092,  2113,  7843,   915,   189,    13,   210, 15402,  8417,
           2300,   189,    13, 10057,   108, 12693, 14624, 29379,  4719,  6533,
            739,   916, 16148, 10981, 21350,  9067,  1203,  8931,  1258, 11835,
           4719,  6533,   739, 20393,   189,    13,   210, 19546,  1517, 11835,
           8417,  2300,   189,    13,   210, 23928,   168,   117,   245,  6279,
            114,   240,   170,   100,   124,   168,   117,   228,  6279,   100,
            124,   168,   117,   228,   171,   238,   224, 41356,   236, 24175,
          11082, 10981, 21350,  9067]])})

那么错位计算loss在哪里体现的呢?这个其实是在模型中自己完成的,可以观察BloomForCausalLM源码中forward部分

  • shift_logits: 将模型的预测结果 lm_logits 的最后一个时间步(token)去掉。这是因为每个 token 的预测是基于前面的 token,因此需要移位。
  • shift_labels: 将标签 labels 的第一个 token 去掉。这样,shift_labels 中的每个 token 对应于其后的预测 token。

4.4 创建模型

4.5 配置训练参数

4.6 创建训练器

4.7 模型训练

4.8 模型推理

在使用 do_sample=True 参数时,意味着生成文本时会使用采样策略来决定下一个单词。这与使用贪婪解码或束搜索等确定性方法不同。启用采样时,模型不会总是选择概率最高的下一个词,而是会根据其概率分布随机选择一个词。这使得生成的文本更加多样化和自然。 

如果设置 do_sample=False,那么就会使用贪心算法(greedy decoding)来生成文本,即每次选择模型输出概率最大的 token 作为下一个 token,这种方法生成的文本可能会比较单一和呆板。

[{'generated_text': "西安交通大学博物馆(Xi'an Jiaotong University Museum)是一座位于西安交通大学西安学院西楼三期大楼的二层混合式博物馆。该建筑占地约2000平方米,于2012年落成开馆,是西安交通大学博物馆的一个组成部分,由西安交通大学及陕西省西安市设计与艺术设计研究院联合发起筹建。该建筑是西安交通大学学生宿舍(西安理工大学建筑学院宿舍)及学生食堂(西安理工大学建筑学院食堂)的一部份,建成后将极大地方便西安交通大学所有学生的住宿和出行。\n博物馆馆舍\n博物馆位于西安交通大学西安学院西楼3号楼,由陕西西安设计院"}]

[{'generated_text': '下面是一则游戏新闻。小编报道,近日,游戏产业发展的非常之盛。而随着这几年游戏的高速发展,玩家对于游戏的了解越来越强烈,以至于在游戏行业里,“烧钱”或“低成本”的定义也逐渐淡出了人们的视线。\n虽然,这是指游戏产业,但这仅仅是以游戏厂商,而不是游戏开发商为准。\n娱乐产业\n近期,随着游戏产业的发展,大量有潜力的电子游戏公司,如电子游戏工作室Game Factory已经或正在开发大型商业化游戏,并且已经确定了。这些游戏产业的公司,有的是以研发游戏,而有一部分,如N'}] 

以上这些模型是自然语言处理(NLP)任务的基础。理解它们的工作原理和设计思想有助于之后章节更好地理解大模型的架构和功能。 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值