前言
在传统的会话推荐模型中一般采用注意力机制来建模用户的偏好嵌入,但是作者认为用户的偏好是更加复杂的关系,单纯靠交互的物品间的联系不足以充分表达。因此,本文主要通过建模会话图来研究物品之间的联系模式,提出了一种模型框架来协同处理会话图中的序列顺序和潜在顺序,具体来说,本文提出了一个加权注意力图结构和一个读出函数来学习物品和会话的嵌入。
论文链接:https://arxiv.org/pdf/1911.11942v1.pdf
github:https://github.com/RuihongQiu/FGNN
1. FGNN
1.1 Preliminaries
令
G
(
V
,
E
)
G(V,E)
G(V,E) 表示为图,其中
v
∈
V
v \in V
v∈V 表示为节点集,
X
v
\mathbf{X}_v
Xv 表示为节点特征矩阵,
e
∈
E
e\in E
e∈E 表示边的集合。有两种经常被研究的任务分别是节点分类任务和图分类任务,在本文的工作中完成的是图分类任务,因此 FGNN 最终得到的是会话的嵌入表示。对于图分类任务,给定图的集合
{
G
1
,
G
2
,
…
,
G
N
}
⊆
G
\{G_1,G_2,\dots,G_N\}\subseteq\mathcal{G}
{G1,G2,…,GN}⊆G 和对应的标签
{
y
1
,
…
,
y
N
}
∈
Y
\{y_1,\dots,y_N\} \in\mathcal{Y}
{y1,…,yN}∈Y。对于 FGNN 来说期望学习到会话图的嵌入表示
h
G
\mathbf{h}_G
hG 并预测正确的标签,也就是
y
G
=
g
(
h
G
)
y_G=g(\mathbf{h}_G)
yG=g(hG)。在这点上,FGNN 确实与主流的会话预测模型的出发点不同,将整个问题转换成了图分类任务。其实主要的过程就是信息传播过程和将节点表示转换成整张会话图的嵌入表示的读出函数。表示成如下形式:
a
v
(
k
)
=
A
g
g
(
{
h
u
(
k
−
1
)
,
u
∈
N
(
v
)
}
)
,
h
v
(
k
)
=
M
a
p
(
h
v
(
k
−
1
)
,
a
v
(
k
)
)
(1)
\mathbf{a}_v^{(k)}=Agg(\{\mathbf{h}_u^{(k-1)},u\in N(v)\}),\mathbf{h}_v^{(k)}=Map(\mathbf{h}_v^{(k-1)},\mathbf{a}_v^{(k)})\tag{1}
av(k)=Agg({hu(k−1),u∈N(v)}),hv(k)=Map(hv(k−1),av(k))(1)
h G = R e a d o u t ( { h v ( k ) , v ∈ V } ) (2) \mathbf{h}_G=Readout(\{\mathbf{h}_v^{(k)},v \in V\})\tag{2} hG=Readout({hv(k),v∈V})(2)
1.2 Method
1.2.1 Session Graph
建模方式可以参考图1。其中值得注意的是,作者在建模的时候选择将会话建模成有向,有权的图表示。其中方向代表交互物品的先后顺序,权重则代表这条边在会话中出现的频率。
在这里作者简单举出了一个两层的 GNN 传播方式,其中传播方式比较经典的是 GCN 和 GAT,作者在这两者的基础上选择了加权的图注意力机制来完成信息传播任务称为 WGAT。作者提出这种机制的出发点是传统的 GCN 和 GAT 是用于处理无向无权的简单图的,而在 FGNN 中作者选择将会话建模成有向有权的图,因此需要优化后的传播模型来完成任务。
1.2.2 Weighted Graph Attentional Layer
WGAT 的输入是节点特征矩阵
x
=
{
x
0
,
x
1
,
x
2
,
…
,
x
n
−
1
}
,
x
i
∈
R
d
\mathbf{x}=\{\mathbf{x_0},\mathbf{x_1},\mathbf{x_2},\dots,\mathbf{x_{n-1}}\},\mathbf{x_i}\in\mathbb{R}^d
x={x0,x1,x2,…,xn−1},xi∈Rd,其中
n
n
n 和
d
d
d 分别代表会话图中节点的个数和初始特征嵌入的维度。经过 WGAT 可以得到更新后的节点表示
x
′
=
{
x
0
′
,
x
1
′
,
x
2
′
,
…
,
x
n
−
1
′
}
,
x
i
′
∈
R
d
′
\mathbf{x}^{'}=\{\mathbf{x_0}^{'},\mathbf{x_1}^{'},\mathbf{x_2}^{'},\dots,\mathbf{x_{n-1}^{'}}\},\mathbf{x_i}^{'}\in\mathbb{R}^{d^{'}}
x′={x0′,x1′,x2′,…,xn−1′},xi′∈Rd′。对于会话推荐任务来说,首层的的输入就是根据 id 序号得到的 one-hot 向量:
x
i
0
=
E
m
b
e
d
(
v
i
)
(3)
\mathbf{x}_i^0=Embed(v_i)\tag{3}
xi0=Embed(vi)(3)
在第一阶段,根据
x
i
,
x
j
,
w
i
j
\mathbf{x}_i, \mathbf{x}_j,w_{ij}
xi,xj,wij计算注意力系数
e
i
j
e_{ij}
eij,它决定了节点
j
j
j 对节点
i
i
i 的重要性:
e
i
j
=
A
t
t
(
W
x
i
,
W
x
j
,
w
i
j
)
(4)
e_{ij}=Att(\mathbf{W}\mathbf{x}_i,\mathbf{Wx}_j,w_{ij})\tag{4}
eij=Att(Wxi,Wxj,wij)(4)
其中的
A
t
t
(
⋅
)
Att(\cdot)
Att(⋅) 的维度变换为
R
d
×
R
d
×
R
1
→
R
1
\mathbb{R}^d\times\mathbb{R}^d\times\mathbb{R}^1\rightarrow\mathbb{R}^1
Rd×Rd×R1→R1。
W
\mathbf{W}
W 是一个在所有节点上执行线性映射的共享参数矩阵,
w
i
j
w_{ij}
wij 表示节点之间边的权重。实际上,一个节点
i
i
i 的注意可以扩展到每个节点,这是一种特殊情况,就像 STAMP 如何使序列的最后一个节点的注意一样。这里限制关注的范围是一阶邻域内的节点。对于 STAMP 的序列结构注意力处理方法,这里选择建模成图同时存在对于范围的限制是一节邻域内的所有节点。利用会话的固有结构图直接比较不同节点的重要性,softmax 对注意力系数进行归一化处理:
α
i
j
=
s
o
f
t
m
a
x
(
e
i
j
)
=
e
x
p
(
e
i
j
)
∑
k
∈
N
(
i
)
e
x
p
(
e
i
k
)
(5)
\alpha_{ij}=softmax(e_{ij})=\frac{exp(e_{ij})}{\sum_{k \in \mathcal{N}(i)}exp(e_{ik})}\tag{5}
αij=softmax(eij)=∑k∈N(i)exp(eik)exp(eij)(5)
对于
A
t
t
(
⋅
)
Att(\cdot)
Att(⋅) 的选择有很多种,在本文中作者选用 MLP,其中
W
a
t
t
∈
R
2
d
+
1
\mathbf{W}_{att} \in \mathbb{R}^{2d+1}
Watt∈R2d+1。最终注意力归一化系数的计算过程表示如下:
α
i
j
=
e
x
p
(
L
e
a
k
y
R
e
l
u
(
W
a
t
t
[
W
x
i
∣
∣
W
x
j
∣
∣
w
i
j
]
)
)
∑
k
∈
N
(
i
)
e
x
p
(
L
e
a
k
t
R
e
l
u
(
W
a
t
t
[
W
x
i
∣
∣
W
x
j
∣
∣
w
i
j
]
)
)
(6)
\alpha_{ij}=\frac{exp(LeakyRelu(\mathbf{W}_{att}[\mathbf{Wx}_i||\mathbf{Wx}_j||w_{ij}]))}{\sum_{k\in\mathcal{N}_{(i)}}exp(LeaktRelu(\mathbf{W}_{att}[\mathbf{Wx}_i||\mathbf{Wx}_j||w_{ij}]))}\tag{6}
αij=∑k∈N(i)exp(LeaktRelu(Watt[Wxi∣∣Wxj∣∣wij]))exp(LeakyRelu(Watt[Wxi∣∣Wxj∣∣wij]))(6)
其中
⋅
∣
∣
⋅
\cdot||\cdot
⋅∣∣⋅ 代表两向量的级联。基于此最后的节点表示就是归一化的注意力系数
α
i
j
\alpha_{ij}
αij 与特征嵌入
x
j
\mathbf{x}_j
xj 以及转换矩阵
W
\mathbf{W}
W 的乘积:
x
i
′
=
σ
(
∑
j
∈
N
(
i
)
α
i
j
W
x
j
)
(7)
\mathbf{x}_i^{'}=\sigma(\sum_{j\in\mathcal{N}_{(i)}}\alpha_{ij}\mathbf{W}\mathbf{x}_j)\tag{7}
xi′=σ(j∈N(i)∑αijWxj)(7)
其中
σ
(
⋅
)
\sigma(\cdot)
σ(⋅) 代表非线性激活函数,在本文中使用的是 Relu。
并且应用了 GAT 中的多头注意力机制,表示形式为:
x
i
′
=
∣
∣
k
=
1
K
σ
(
∑
i
∈
N
(
i
)
α
i
j
k
W
k
x
j
)
(8)
\mathbf{x}_i^{'}=||_{k=1}^{K}\sigma(\sum_{i\in\mathcal{N}_{(i)}}\alpha_{ij}^{k}\mathbf{W}^k\mathbf{x}_j)\tag{8}
xi′=∣∣k=1Kσ(i∈N(i)∑αijkWkxj)(8)
其中
K
K
K 代表 head 的数量,每一个 head 对应不同的参数,最终
⋅
∣
∣
⋅
\cdot||\cdot
⋅∣∣⋅ 代表级联函数,因此经过式(8)得到的节点特征维度为
x
i
′
∈
R
K
d
′
\mathbf{x}^{'}_{i}\in\mathbb{R}^{Kd^{'}}
xi′∈RKd′。最终经过取平均的操作得到最后的嵌入表示:
x
i
′
=
σ
(
1
K
∑
k
=
1
K
∑
j
∈
N
(
i
)
α
i
j
k
W
k
x
j
)
(9)
\mathbf{x}^{'}_{i}=\sigma(\frac{1}{K}\sum_{k=1}^K\sum_{j\in\mathcal{N}_{(i)}}\alpha_{ij}^k\mathbf{W}^k\mathbf{x}_j)\tag{9}
xi′=σ(K1k=1∑Kj∈N(i)∑αijkWkxj)(9)
当多个 WGAT 层的正向传播完成后,得到所有节点的最终特征向量,即物品特征嵌入。这些嵌入将作为会话嵌入计算阶段的输入。
1.2.3 Readout Function
读出函数的目的是在对 GNN 层进行前向计算后,基于节点特征给出整个图的表示。如上所述,Readout函数需要学习物品转换模式,以避免时间顺序的偏差和对最后一个输入项目的自我注意的不准确性。为了方便起见,一些算法使用简单的排列不变操作,例如,对所有节点特征的平均值、最大值或总和。虽然上述方法简单且完全不违反排列不变性的约束,但是上述操作不能提供所需的会话表示的表达能力。相反,Set2Set是一个图级特征提取器。作者修改这个方法以适应会话图的设置。计算过程如下:
q
t
=
G
R
U
(
q
t
−
1
∗
)
(10)
\mathbf{q}_t=GRU(\mathbf{q}_{t-1}^*)\tag{10}
qt=GRU(qt−1∗)(10)
e
i
,
t
=
f
(
x
i
,
q
t
)
(11)
e_{i,t}=f(\mathbf{x}_i,\mathbf{q}_t)\tag{11}
ei,t=f(xi,qt)(11)
a
i
,
t
=
e
x
p
(
e
i
,
t
)
∑
j
e
x
p
(
e
j
,
t
)
(12)
a_{i,t}=\frac{exp(e_{i,t})}{\sum_jexp(e_{j,t})}\tag{12}
ai,t=∑jexp(ej,t)exp(ei,t)(12)
r
t
=
∑
i
a
i
,
t
x
i
(13)
\mathbf{r}_t=\sum_ia_{i,t}\mathbf{x}_i\tag{13}
rt=i∑ai,txi(13)
q
t
∗
=
q
∣
∣
r
t
(14)
\mathbf{q}^*_t=\mathbf{q}||\mathbf{r}_t\tag{14}
qt∗=q∣∣rt(14)
其中 i i i 是图中节点的索引, q t ∈ R d \mathbf{q}_t \in \mathbb{R}^d qt∈Rd 代表目前所有节点级联表示的查询向量,可以用来计算当前计算节点 x i \mathbf{x}_i xi 对于之前所有节点表示级联后结果的注意力系数 e i , t e_{i,t} ei,t 也就是函数 f ( ⋅ ) f(\cdot) f(⋅) 的功能,之后计算注意力系数的归一化表示 a i , t a_{i,t} ai,t,得到当前节点的信息传播结果 r t ∈ R d \mathbf{r}_t \in \mathbb{R}^d rt∈Rd,将之前所有结点的查询向量和当前节点的信息传播进行拼接得到当前所有结点的级联表示 q t ∗ ∈ R 2 d \mathbf{q}^*_t \in \mathbb{R}^{2d} qt∗∈R2d。
基于会话图的所有节点嵌入,上述过程得到了一个图级嵌入,其中包含一个查询向量 q t \mathbf{q}_t qt 以及语义嵌入向量 r t \mathbf{r}_t rt。查询向量 q t \mathbf{q}_t qt 控制从节点嵌入中读取什么,如果递归地应用 Readout 函数,它实际上可以处理会话图中的所有的节点。因此最后得到的会话图嵌入表示为 q t ∗ ∈ R 2 d \mathbf{q}^*_t \in \mathbb{R}^{2d} qt∗∈R2d。
1.2.4 Recommendation
在得到当前会话图的嵌入表示之后,基于此得到所有候选物品的得分向量
z
^
\hat{\mathbf{z}}
z^:
z
^
=
(
W
o
u
t
q
t
∗
)
T
X
0
(15)
\hat{\mathbf{z}}=(\mathbf{W}_{out}\mathbf{q}_t^*)^T\mathbf{X}^0\tag{15}
z^=(Woutqt∗)TX0(15)
其中
W
o
u
t
∈
R
d
×
2
d
\mathbf{W}_{out}\in\mathbb{R}^{d\times2d}
Wout∈Rd×2d 作为转换矩阵的可学习参数,
X
0
\mathbf{X}^0
X0 代表所有候选物品的根据 id 得到的初始 one-hot 嵌入。基于此可以得到归一化的概率预测向量
y
^
\hat{\mathbf{y}}
y^:
y
^
=
s
o
f
t
m
a
x
(
z
^
)
(16)
\hat{\mathbf{y}}=softmax(\hat{\mathbf{z}})\tag{16}
y^=softmax(z^)(16)
基于此,选择 top-K 个物品完成推荐任务。
1.2.5 Objective Function
对于一个图分类问题,作者将
y
^
\hat{\mathbf{y}}
y^ 和
v
l
a
b
e
l
,
i
{v}_{label,i}
vlabel,i 的 one-hot 编码之间的多类交叉熵损失作为目标函数。对于一个 batch:
L
=
∑
i
=
1
l
o
n
e
−
h
o
t
(
v
l
a
b
e
l
,
i
)
l
o
g
(
y
^
)
(17)
L=\sum_{i=1}^l one-hot(v_{label,i})log(\hat{\mathbf{y}})\tag{17}
L=i=1∑lone−hot(vlabel,i)log(y^)(17)
最后使用时间反向传播(BPTT)算法对整个 FGNN 模型进行训练。
作者将该模型描述成图级别的分类任务,我没看出来原因是什么,和其他的会话推荐模型的计算过程一致。因此我觉得没什么必要,如果我的理解有误,欢迎大家指出。
2. Experiments