Pytorch怎么实现Transformer

蜗牛 互联网技术资讯 2022-05-16 175 0

本篇内容主要讲解“Pytorch怎么实现Transformer”,感兴趣的朋友不妨来看看。本文介绍的方法操作简单快捷,实用性强。下面就让小编来带大家学习“Pytorch怎么实现Transformer”吧!

一、构造数据

1.1 句子长度

# 关于word embedding,以序列建模为例
# 输入句子有两个,第一个长度为2,第二个长度为4
src_len = torch.tensor([2, 4]).to(torch.int32)
# 目标句子有两个。第一个长度为4, 第二个长度为3
tgt_len = torch.tensor([4, 3]).to(torch.int32)
print(src_len)
print(tgt_len)

输入句子(src_len)有两个,第一个长度为2,第二个长度为4
目标句子(tgt_len)有两个。第一个长度为4, 第二个长度为3

Pytorch怎么实现Transformer  pytorch 第1张

1.2 生成句子

用随机数生成句子,用0填充空白位置,保持所有句子长度一致

src_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, max_num_src_words, (L, )), (0, max(src_len)-L)), 0) for L in src_len])
tgt_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, max_num_tgt_words, (L, )), (0, max(tgt_len)-L)), 0) for L in tgt_len])
print(src_seq)
print(tgt_seq)

src_seq为输入的两个句子,tgt_seq为输出的两个句子。
为什么句子是数字?在做中英文翻译时,每个中文或英文对应的也是一个数字,只有这样才便于处理。

Pytorch怎么实现Transformer  pytorch 第2张

1.3 生成字典

在该字典中,总共有8个字(行),每个字对应8维向量(做了简化了的)。注意在实际应用中,应当有几十万个字,每个字可能有512个维度。

# 构造word embedding
src_embedding_table = nn.Embedding(9, model_dim)
tgt_embedding_table = nn.Embedding(9, model_dim)
# 输入单词的字典
print(src_embedding_table)
# 目标单词的字典
print(tgt_embedding_table)

字典中,需要留一个维度给class token,故是9行。

Pytorch怎么实现Transformer  pytorch 第3张

1.4 得到向量化的句子

通过字典取出1.2中得到的句子

# 得到向量化的句子
src_embedding = src_embedding_table(src_seq)
tgt_embedding = tgt_embedding_table(tgt_seq)
print(src_embedding)
print(tgt_embedding)

Pytorch怎么实现Transformer  pytorch 第4张

该阶段总程序

import torch
# 句子长度
src_len = torch.tensor([2, 4]).to(torch.int32)
tgt_len = torch.tensor([4, 3]).to(torch.int32)
# 构造句子,用0填充空白处
src_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, 8, (L, )), (0, max(src_len)-L)), 0) for L in src_len])
tgt_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, 8, (L, )), (0, max(tgt_len)-L)), 0) for L in tgt_len])
# 构造字典
src_embedding_table = nn.Embedding(9, 8)
tgt_embedding_table = nn.Embedding(9, 8)
# 得到向量化的句子
src_embedding = src_embedding_table(src_seq)
tgt_embedding = tgt_embedding_table(tgt_seq)
print(src_embedding)
print(tgt_embedding)

二、位置编码

位置编码是transformer的一个重点,通过加入transformer位置编码,代替了传统RNN的时序信息,增强了模型的并发度。位置编码的公式如下:(其中pos代表行,i代表列)

Pytorch怎么实现Transformer  pytorch 第5张

2.1 计算括号内的值

# 得到分子pos的值
pos_mat = torch.arange(4).reshape((-1, 1))
# 得到分母值
i_mat = torch.pow(10000, torch.arange(0, 8, 2).reshape((1, -1))/8)
print(pos_mat)
print(i_mat)

Pytorch怎么实现Transformer  pytorch 第6张

2.2 得到位置编码

# 初始化位置编码矩阵
pe_embedding_table = torch.zeros(4, 8)
# 得到偶数行位置编码
pe_embedding_table[:, 0::2] =torch.sin(pos_mat / i_mat)
# 得到奇数行位置编码
pe_embedding_table[:, 1::2] =torch.cos(pos_mat / i_mat)
pe_embedding = nn.Embedding(4, 8)
# 设置位置编码不可更新参数
pe_embedding.weight = nn.Parameter(pe_embedding_table, requires_grad=False)
print(pe_embedding.weight)

Pytorch怎么实现Transformer  pytorch 第7张

三、多头注意力

3.1 self mask

有些位置是空白用0填充的,训练时不希望被这些位置所影响,那么就需要用到self mask。self mask的原理是令这些位置的值为无穷小,经过softmax后,这些值会变为0,不会再影响结果。

3.1.1 得到有效位置矩阵

# 得到有效位置矩阵
vaild_encoder_pos = torch.unsqueeze(torch.cat([torch.unsqueeze(F.pad(torch.ones(L), (0, max(src_len) - L)), 0)for L in src_len]), 2)
valid_encoder_pos_matrix = torch.bmm(vaild_encoder_pos, vaild_encoder_pos.transpose(1, 2))
print(valid_encoder_pos_matrix)

Pytorch怎么实现Transformer  pytorch 第8张

3.1.2 得到无效位置矩阵

invalid_encoder_pos_matrix = 1-valid_encoder_pos_matrix
mask_encoder_self_attention = invalid_encoder_pos_matrix.to(torch.bool)
print(mask_encoder_self_attention)

True代表需要对该位置mask

Pytorch怎么实现Transformer  pytorch 第9张

3.1.3 得到mask矩阵
用极小数填充需要被mask的位置

# 初始化mask矩阵
score = torch.randn(2, max(src_len), max(src_len))
# 用极小数填充
mask_score = score.masked_fill(mask_encoder_self_attention, -1e9)
print(mask_score)

Pytorch怎么实现Transformer  pytorch 第10张

算其softmat

mask_score_softmax = F.softmax(mask_score)
print(mask_score_softmax)

可以看到,已经达到预期效果

Pytorch怎么实现Transformer  pytorch 第11张

到此,相信大家对“Pytorch怎么实现Transformer”有了更深的了解,不妨来实际操作一番吧!这里是蜗牛博客网站,更多相关内容可以进入相关频道进行查询,关注我们,继续学习!

免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:niceseo99@gmail.com进行举报,并提供相关证据,一经查实,将立刻删除涉嫌侵权内容。

评论

有免费节点资源,我们会通知你!加入纸飞机订阅群

×
天气预报查看日历分享网页手机扫码留言评论Telegram