异构图 Link 猜测 理论与DGL 源码实战
书接上文,在上文 重磅好文透彻了解,异构图上 Node 分类理论与DGL源码实战 中,咱们讲了 异构图节点分类回归 使命,而在曾经的系列文章中,咱们也陆续介绍了 同构图上的节点分类回归使命、边分类回归使命以及链接猜测 使命。接着曾经的写作印记,这一篇就该是 异构图上链接猜测使命 了,一同来看看吧 ~ go go go
(1) 异构图上链接猜测根底了解
链接猜测,顾名思义,便是 图中边是否存在 的猜测,本质上是把建模成二分类使命,来猜测边存在的概率。可是 实际存在的边与随机采样的边 构建成了正负样本,外界不需求输入标签,学习的是 图结构本身 的信息,这儿咱们把归结为 无监督机器学习 。在 GraphSage与DGL完成同构图 Link 猜测,通俗易懂好文强推 中,咱们具体介绍了 同构图上的链接猜测 ,把类推到异构图上即可。在异构图上进行链接猜测,咱们运用考虑边两边的节点的Embeding信息,依据其 类似性与相关性等因素 ,来对 边的存在与否 进行判别。
这儿 需求留意 的是:虽然是用的 边两边的2个节点 的信息,可是从 练习多次序 来看,依据 咱们曾经文章 介绍的知识 来了解,这2个节点也是 均融合的周围节点的部分结构与大局性质 的并结合图上空间结构做出的判别。
从 GraphSage与DGL完成同构图 Link 猜测,通俗易懂好文强推 中,咱们也了解到 图上链接猜测属于 无监督机器学习,这和上一篇文章介绍的异构图上节点分类回归猜测使命的不同十分类似,不同只是是在咱们需求对链接猜测进行 负边的采样。留意这儿是 边采样, 而上文用的是节点采样,接口是不一样的,同时这两个使命的 丢失与猜测打分 函数也是不同的。
上文 咱们已经说过 链接猜测 是无监督机器学习,外界不用输入标签,模型学习的其实是依据图的本身结构与数据特性来判别边是否存在的机器学习使命。依据此,咱们应该了解: 既然学习的是图上的某边是否存在,则咱们只是 建图的时分供给各类节点的联络来建图 即可,能够依据用户前史行为日志来构建异构图的边,例如用户购买了某件产品就有用户-》购买-〉产品的联络存在,就能够构建一条边。
依据实际存在的节点联络建边构成正样本,而图上随机采样的边组成负样本,依据 距离比较近的节点特性也类似的 同源偏好假定 来构建丢失进行模型练习 ,能够说是 图上链接(联络)猜测的精华 了。而链接猜测也是现在的许多互联网大厂 运用的最多的一种 建模办法 ,不需求显示的 标签 就能够学习到 需求 的 各类节点 的 Embeding, 十分 nice !!!
曾经的文章对 同构/异构图 的各种机器学习使命均进行了 具体的论述,本文这儿就不在继续赘述了。
感兴趣的同学能够去 作者大众号 上去 阅览前史文章。这儿咱们直接开端本节 依据DGL和RGCN完成的异构图上 链接猜测 机器学习使命的代码介绍吧~
(2) 代码韶光
为了提高文章的 可读性与下降了解难度 ,也坚持每一篇 文章的 相对对立性 ,让读者从任何一篇文章进来都是一篇完整的文章,本文和上文重复的代码,这儿依然会 赘述 着进行介绍。阅览过上一篇文章的同学能够自行越过哈,下面,就让咱们开端coding 吧~
开篇先吼一喉咙 , talk is cheap , show me the code !!!
本文的代码讲的是 依据DGL和RGCN完成的异构图上链接猜测 使命,整个源码流程是一个 小型的工业可用的工程 ,依据dgl完成,觉得有用赶忙 保藏转发 吧~
(2.1) 数据预备 (和上文相同)
咱们 假定 能够输入类似于这样的数据, 其间每2列对应这一种联络,例如 用户2352193 购买了产品CEEC9EBF7,用户用了IP 174.74.201.9登录了账号,用户用IP 174.74.201.9 购买了产品 CEEC9EBF7, 终究的 链接使命猜测是猜测用户的购买志愿,用户到该产品之间,是否会有边存在
更 惯例的一种用法 是,依据无监督的练习,得到图上 各个节点的 Embeding ,这是十分有价值的中间数据产出,能够为咱们其他的机器学习使命供给有力辅佐。
咱们能够把这样一份数据存入 source_data.csv
文件中,用 pandas 接口把数据读入:
graph_features_pdf = pd.read_csv('./source_data.csv')
由于关于异构图模型,节点和边的类型均有多种,为了处理便利,咱们能够把各种类型的节点进行编码,再到后期对其进行解码,对 pandas 的 dataframe
数据结构的编解码,咱们能够运用下面的代码:
@ 欢迎重视微信大众号:算法全栈之路
#编码办法
def encode_map(input_array):
p_map={}
length=len(input_array)
for index, ele in zip(range(length),input_array):
# print(ele,index)
p_map[str(ele)] = index
return p_map
#解码办法
def decode_map(encode_map):
de_map={}
for k,v in encode_map.items():
# index,ele
de_map[v]=k
return de_map
然后用其间的各列node 进行 编码
@ 欢迎重视微信大众号:算法全栈之路
userid_encode_map=encode_map(set(graph_features_pdf['user_id'].values))
# 解码map
userid_decode_map=decode_map(userid_encode_map)
graph_features_pdf['user_id_encoded'] = graph_features_pdf['user_id'].apply(lambda e: userid_encode_map.get(str(e),-1))
# print unique值的个数
userid_count=len(set(graph_features_pdf['user_id_encoded'].values))
print(userid_count)
# user login ip
u_e_ip_src = final_graph_features_pdf['user_id_encoded'].values
u_e_ip_dst = final_graph_features_pdf['ip_encoded'].values
u_e_ip_count = len(u_e_ip_dst)
print("u_e_ip_count", u_e_ip_count)
# user buy item
u_e_item_src = final_graph_features_pdf['user_id_encoded'].values
u_e_item_dst = final_graph_features_pdf['item_id_encoded'].values
u_e_item_count = len(u_e_item_dst)
print("u_e_item_count", u_e_item_count)
这儿只是以 用户节点编码 为例,itemId和 IP 同理编解码即可。留意: 这儿的 u_e_ip_count,u_e_item_count 在下文有用到。
终究咱们能够把图数据保存,供今后的异构图代码 demo运用。
`@ 欢迎重视微信大众号:算法全栈之路
final_graph_pdf=graph_features_pdf[[‘user_id_encoded’,’ip_encoded’,’item_id_encoded’,’label’]].sort_values(by=’user_id_encoded’, ascending=True) final_graph_pdf.to_csv(‘result_label.csv’,index=False)`
依据此,异构图的根底预备数据就完毕了,下面开端正式的coding了。
(2.2) 导包 (和上文相同)
老规矩,先导包,依据DGL和RGCN完成的异构图上链接猜测使命只需求这些包就能够了。
@ 欢迎重视微信大众号:算法全栈之路
import argparse
import torch
import torch.nn as nn
import dgl
import torch.optim as optim
from dgl.dataloading import MultiLayerFullNeighborSampler, EdgeDataLoader
from dgl.dataloading.negative_sampler import Uniform
import numpy as np
import pandas as pd
import itertools
import os
import tqdm
from dgl import save_graphs, load_graphs
import dgl.function as fn
import torch
import dgl
import torch.nn.functional as F
from dgl.nn.pytorch import GraphConv, SAGEConv, HeteroGraphConv
from dgl.utils import expand_as_pair
import tqdm
from collections import defaultdict
import torch as th
import dgl.nn as dglnn
from dgl.data.utils import makedirs, save_info, load_info
from sklearn.metrics import roc_auc_score
import gc
gc.collect()
推荐一个东西,tqdm 很好用哦,结合 dataloading接口, 能够看到模型练习以及数据处理履行的进展,赶忙用起来吧~
各种模型东西无所谓分类,能处理问题的便是好东西,混用又有何不可呢? 有用就行!
(2.3) 构图
数据有了,接下来便是构图了,咱们构建的是包含三种节点的异构图。
@ 欢迎重视微信大众号:算法全栈之路
# user 登录 ip
u_e_ip_src = final_graph_pdf['user_id_encoded'].values
u_e_ip_dst = final_graph_pdf['ip_encoded'].values
# user 购买 item
u_e_item_src = final_graph_pdf['user_id_encoded'].values
u_e_item_dst = final_graph_pdf['item_id_encoded'].values
# item和ip 共同呈现
ip_e_item_src = final_graph_pdf['ip_encoded'].values
ip_e_item_dst = final_graph_pdf['item_id_encoded'].values
# user 购买 label
user_node_buy_label = final_graph_pdf['label'].values
hetero_graph = dgl.heterograph({
('user', 'u_e_ip', 'ip'): (u_e_ip_src, u_e_ip_dst),
('ip', 'u_eby_ip', 'user'): (u_e_ip_dst, u_e_ip_src),
('user', 'u_e_item', 'item'): (u_e_item_src, u_e_item_dst),
('item', 'u_eby_item', 'user'): (u_e_item_dst, u_e_item_src),
('ip', 'ip_e_item', 'item'): (ip_e_item_src, ip_e_item_dst),
('item', 'item_eby_ip', 'ip'): (ip_e_item_dst, ip_e_item_src)
})
# 留意:这儿的代码本文是不需求的
# 这儿是链接猜测使命,是无监督机器学习,不需求标签,这儿没有删除,只是注释起来,便利对比上一篇文章的代码
# 给 user node 增加标签
# hetero_graph.nodes['user'].data['label'] = torch.tensor(user_node_buy_label)
print(hetero_graph)
这儿异构图是无向图,由于无向,所以双向 ,构图的时分就 需求构建双向的边 ,代码很好了解,就不再赘述了哈。 这儿和上文不同的是,这儿是无监督机器学习使命,不需求对用户节点的边进行 label 赋值 。我这儿只是是把注释起来哈。这儿是不需求的。
(2.4) 模型的自界说函数
这儿界说了 异构图上RGCN 会用到的模型的一系列自界说函数,终点看代码注释,结合上文榜首末节的抽象了解,期望你能看了解哦。
@ 欢迎重视微信大众号:算法全栈之路
class RelGraphConvLayer(nn.Module):
def __init__(self,
in_feat,
out_feat,
rel_names,
num_bases,
*,
weight=True,
bias=True,
activation=None,
self_loop=False,
dropout=0.0):
super(RelGraphConvLayer, self).__init__()
self.in_feat = in_feat
self.out_feat = out_feat
self.rel_names = rel_names
self.num_bases = num_bases
self.bias = bias
self.activation = activation
self.self_loop = self_loop
# 这个当地只是起到核算的效果, 不保存数据
self.conv = HeteroGraphConv({
# graph conv 里边有模型参数weight,假如外边不传进去的话,里边新建
# 相当于模型加了一层全链接, 对每一种类型的边核算卷积
rel: GraphConv(in_feat, out_feat, norm='right', weight=False, bias=False)
for rel in rel_names
})
self.use_weight = weight
self.use_basis = num_bases < len(self.rel_names) and weight
if self.use_weight:
if self.use_basis:
self.basis = dglnn.WeightBasis((in_feat, out_feat), num_bases, len(self.rel_names))
else:
# 每个联络,又一个weight,全连接层
self.weight = nn.Parameter(th.Tensor(len(self.rel_names), in_feat, out_feat))
nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))
# bias
if bias:
self.h_bias = nn.Parameter(th.Tensor(out_feat))
nn.init.zeros_(self.h_bias)
# weight for self loop
if self.self_loop:
self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat))
nn.init.xavier_uniform_(self.loop_weight,
gain=nn.init.calculate_gain('relu'))
self.dropout = nn.Dropout(dropout)
def forward(self, g, inputs):
g = g.local_var()
if self.use_weight:
weight = self.basis() if self.use_basis else self.weight
# 这每个联络对应一个权重矩阵对应输入维度和输出维度
wdict = {self.rel_names[i]: {'weight': w.squeeze(0)}
for i, w in enumerate(th.split(weight, 1, dim=0))}
else:
wdict = {}
if g.is_block:
inputs_src = inputs
inputs_dst = {k: v[:g.number_of_dst_nodes(k)] for k, v in inputs.items()}
else:
inputs_src = inputs_dst = inputs
# 多类型的边结点卷积完成后的输出
# 输入的是blocks 和 embeding
hs = self.conv(g, inputs, mod_kwargs=wdict)
def _apply(ntype, h):
if self.self_loop:
h = h + th.matmul(inputs_dst[ntype], self.loop_weight)
if self.bias:
h = h + self.h_bias
if self.activation:
h = self.activation(h)
return self.dropout(h)
#
return {ntype: _apply(ntype, h) for ntype, h in hs.items()}
class RelGraphEmbed(nn.Module):
r"""Embedding layer for featureless heterograph."""
def __init__(self,
g,
embed_size,
embed_name='embed',
activation=None,
dropout=0.0):
super(RelGraphEmbed, self).__init__()
self.g = g
self.embed_size = embed_size
self.embed_name = embed_name
self.activation = activation
self.dropout = nn.Dropout(dropout)
# create weight embeddings for each node for each relation
self.embeds = nn.ParameterDict()
for ntype in g.ntypes:
embed = nn.Parameter(torch.Tensor(g.number_of_nodes(ntype), self.embed_size))
nn.init.xavier_uniform_(embed, gain=nn.init.calculate_gain('relu'))
self.embeds[ntype] = embed
def forward(self, block=None):
return self.embeds
class EntityClassify(nn.Module):
def __init__(self,
g,
h_dim, out_dim,
num_bases=-1,
num_hidden_layers=1,
dropout=0,
use_self_loop=False):
super(EntityClassify, self).__init__()
self.g = g
self.h_dim = h_dim
self.out_dim = out_dim
self.rel_names = list(set(g.etypes))
self.rel_names.sort()
if num_bases < 0 or num_bases > len(self.rel_names):
self.num_bases = len(self.rel_names)
else:
self.num_bases = num_bases
self.num_hidden_layers = num_hidden_layers
self.dropout = dropout
self.use_self_loop = use_self_loop
self.embed_layer = RelGraphEmbed(g, self.h_dim)
self.layers = nn.ModuleList()
# i2h
self.layers.append(RelGraphConvLayer(
self.h_dim, self.h_dim, self.rel_names,
self.num_bases, activation=F.relu, self_loop=self.use_self_loop,
dropout=self.dropout, weight=False))
# h2h , 这儿不增加隐层,只用2层卷积
# for i in range(self.num_hidden_layers):
# self.layers.append(RelGraphConvLayer(
# self.h_dim, self.h_dim, self.rel_names,
# self.num_bases, activation=F.relu, self_loop=self.use_self_loop,
# dropout=self.dropout))
# h2o
self.layers.append(RelGraphConvLayer(
self.h_dim, self.out_dim, self.rel_names,
self.num_bases, activation=None,
self_loop=self.use_self_loop))
# 输入 blocks,embeding
def forward(self, h=None, blocks=None):
if h is None:
# full graph training
h = self.embed_layer()
if blocks is None:
# full graph training
for layer in self.layers:
h = layer(self.g, h)
else:
# minibatch training
# 输入 blocks,embeding
for layer, block in zip(self.layers, blocks):
h = layer(block, h)
return h
def inference(self, g, batch_size, device="cpu", num_workers=0, x=None):
if x is None:
x = self.embed_layer()
for l, layer in enumerate(self.layers):
y = {
k: th.zeros(
g.number_of_nodes(k),
self.h_dim if l != len(self.layers) - 1 else self.out_dim)
for k in g.ntypes}
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
dataloader = dgl.dataloading.NodeDataLoader(
g,
{k: th.arange(g.number_of_nodes(k)) for k in g.ntypes},
sampler,
batch_size=batch_size,
shuffle=True,
drop_last=False,
num_workers=num_workers)
for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
# print(input_nodes)
block = blocks[0].to(device)
h = {k: x[k][input_nodes[k]].to(device) for k in input_nodes.keys()}
h = layer(block, h)
for k in h.keys():
y[k][output_nodes[k]] = h[k].cpu()
x = y
return y
上面的代码主要分为三大块:别离是 RelGraphConvLayer、 RelGraphEmbed 以及 EntityClassify。
首要便是:RelGraphConvLayer 。咱们能够看到 RelGraphConvLayer 便是咱们的 异构图卷积层layer, 其主要是调用了DGL完成的 HeteroGraphConv算子,从上面榜首末节咱们也具体论述了 异构图卷积算子其实便是 对各种联络别离进行卷积然后进行同类型的节点的融合。
这儿咱们需求要点重视 的是:RelGraphConvLayer层的回来 ,从代码中,咱们能够看到,关于每种节点类型是回来了一个Embeding, 维度是 out_feat。假如是带了激活函数的,则是回来激活后的一定维度的一个tensor。
过来是 RelGraphEmbed。 从代码中能够看到: 这个python类只是回来了一个字典,可是这个字典里却包含了 多个 Embeding Variable, 留意这儿的 Variable 均是能够 随着网络练习变化更新 的。咱们能够依据节点类型,节点ID取得对应元素的 Embeding 。 这种完成办法是不是处理了前文 GraphSage与DGL完成同构图 Link 猜测,通俗易懂好文强推 和 依据GCN和DGL完成的图上 node 分类, 值得一看!!! 所说到的 动态更新的Embeding 的问题呢。
终究便是 EntityClassify类 了,咱们能够看到 这个便是终究的模型RGCN结构了,包含了模型练习的 forward 和用于揣度的inference办法 。这儿的 inference 能够用于 各个节点的embedding的导出, 咱们在后文有实例代码,接着看下去吧~
留意看 forword 办法里的 for layer, block in zip(self.layers, blocks)
这个方位, 这儿便是咱们前一末节所说的 采样层数和模型的卷积层数目是相同的说法 的由来,能够结合上文阐明了解源码哦。
(2.5) 模型采样超参与边采样介绍
留意: 这儿的采样是边采样,和上一篇文章不同 ,具体看代码。
@ 欢迎重视微信大众号:算法全栈之路
# 依据节点类型和节点ID抽取embeding 参与模型练习更新
def extract_embed(node_embed, input_nodes):
emb = {}
for ntype, nid in input_nodes.items():
nid = input_nodes[ntype]
emb[ntype] = node_embed[ntype][nid]
return emb
# 采样界说
neg_sample_count = 1
batch_size=20480
# 采样2层悉数节点
sampler = MultiLayerFullNeighborSampler(2)
# 边的条数,数目比顶点个数多许多.
# 这是 EdgeDataLoader 数据加载器
hetero_graph.edges['u_e_ip'].data['train_mask'] = torch.zeros(u_e_ip_count, dtype=torch.bool).bernoulli(1.0)
train_ip_eids = hetero_graph.edges['u_e_ip'].data['train_mask'].nonzero(as_tuple=True)[0]
ip_dataloader = EdgeDataLoader(
hetero_graph, {'u_e_ip': train_ip_eids}, sampler, negative_sampler=Uniform(neg_sample_count), batch_size=batch_size
)
hetero_graph.edges['u_e_item'].data['train_mask'] = torch.zeros(u_e_item_count, dtype=torch.bool).bernoulli(1.0)
train_item_eids = hetero_graph.edges['u_e_item'].data['train_mask'].nonzero(as_tuple=True)[0]
item_dataloader = EdgeDataLoader(
hetero_graph, {'u_e_item': train_item_eids}, sampler, negative_sampler=Uniform(neg_sample_count), batch_size=batch_size
)
这儿的代码作者花了很多时刻进行优化,注释和组织形式 尽量写的十分清晰,十分简单了解。
咱们这儿挑选了 EdgeDataLoader 来进行练习数据的读入,这其实是一种分batch练习 的办法,而不是一次性把图全读入内存进行练习,而是每次挑选batch的种子节点 真实边构成的pos graph 以及 种子节点和他们随机采样的大局节点组成的neg graph 读入内存参与练习,这也让大的图神经网络练习成为了可能,是 DGL图深度结构 十分优异的完成 !!! 大赞 !
这儿 EdgeDataLoader
采样算法也和 negative_sampler 与 sampler 结合运用,其间 sampler 采样了2层悉数街坊作为正样本,而 negative_sampler 则是对不存在的边进行构建,起点也是种子节点,而终点则是 大局随机采样得到的 。
留意读者在里的边采样能够上一篇文章 重磅好文透彻了解,异构图上 Node 分类理论与DGL源码实战 中的节点采样对比查看,能够加深了解哦~
(2.6) 模型结构界说与 丢失函数阐明
三个类的办法界说,和节点分类使命有差异的当地,能够看看~
@ 欢迎重视微信大众号:算法全栈之路
# Define a Heterograph Conv model
class Model(nn.Module):
def __init__(self, graph, hidden_feat_dim, out_feat_dim):
super().__init__()
self.rgcn = EntityClassify(graph,
hidden_feat_dim,
out_feat_dim)
self.pred = HeteroDotProductPredictor()
def forward(self, h, pos_g, neg_g, blocks, etype):
h = self.rgcn(h, blocks)
return self.pred(pos_g, h, etype), self.pred(neg_g, h, etype)
class MarginLoss(nn.Module):
def forward(self, pos_score, neg_score):
# 求丢失的平均值 , view 改动tensor 的形状
# 1- pos_score + neg_score ,应该是 -pos 符号越大变成越小 +neg_score 越小越好
return (1 - pos_score + neg_score.view(pos_score.shape[0], -1)).clamp(min=0).mean()
class HeteroDotProductPredictor(nn.Module):
def forward(self, graph, h, etype):
# 在核算之外更新h,保存为大局可用
# h contains the node representations for each edge type computed from node_clf_hetero.py
with graph.local_scope():
graph.ndata['h'] = h # assigns 'h' of all node types in one shot
graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype)
return graph.edges[etype].data['score']
这儿的三个类函数 Model、MarginLoss、HeteroDotProductPredictor 均是十分重要的。
首要是 model , 咱们能够看到 这儿的model 别离引入了 EntityClassify 和 HeteroDotProductPredictor
,这两个函数别离界说了 模型的结构与丢失 。EntityClassify 和 上一文 介绍的一模一样,这儿不在赘述了。
接着是 MarginLoss ,能够看到 MarginLoss 便是咱们前文讲过的依据 同源性假定 设计的丢失,HeteroDotProductPredictor
则是依据两头节点信息 核算边是否存在 的函数,能够 从同构图揣度到异构图 中去,和 GraphSage与DGL完成同构图 Link 猜测,通俗易懂好文强推 中一样,本文也不在进行赘述 了。
(2.7) 模型练习超参与单epoch练习
代码是表达程序员思想的最好言语,直接看代码吧!
@ 欢迎重视微信大众号:算法全栈之路
# in_feats = hetero_graph.nodes['user'].data['feature'].shape[1]
hidden_feat_dim = n_hetero_features
out_feat_dim = n_hetero_features
embed_layer = RelGraphEmbed(hetero_graph, hidden_feat_dim)
all_node_embed = embed_layer()
model = Model(hetero_graph, hidden_feat_dim, out_feat_dim)
# 优化模型一切参数,主要是weight以及输入的embeding参数
all_params = itertools.chain(model.parameters(), embed_layer.parameters())
optimizer = torch.optim.Adam(all_params, lr=0.01, weight_decay=0)
loss_func = MarginLoss()
def train_etype_one_epoch(etype, spec_dataloader):
losses = []
# input nodes 为 采样的subgraph中的一切的节点的调集
for input_nodes, pos_g, neg_g, blocks in tqdm.tqdm(spec_dataloader):
emb = extract_embed(all_node_embed, input_nodes)
pos_score, neg_score = model(emb, pos_g, neg_g, blocks, etype)
loss = loss_func(pos_score, neg_score)
losses.append(loss.item())
optimizer.zero_grad()
loss.backward()
optimizer.step()
print('{:s} Epoch {:d} | Loss {:.4f}'.format(etype, epoch, sum(losses) / len(losses)))
这儿咱们界说了模型结构,丢失选用的是 上文界说的 MarginLoss , 这儿需求留意的是 spec_dataloader
回来值,这儿是 边采样 ,回来和节点采样的dataloader是 不一样的。
其他的代码十分简单了解,有问题欢迎去大众号联络讨论~
(2.8) 模型多种节点练习
@ 欢迎重视微信大众号:算法全栈之路
# 开端train 模型
for epoch in range(1):
print("start epoch:", epoch)
model.train()
train_etype_one_epoch('u_e_ip', ip_dataloader)
train_etype_one_epoch('u_e_item', item_dataloader)
从代码中咱们能够知道:
关于异构图,其实咱们也是以 各种类型的节点作为种子节点, 然后进行图上的负边采样,别离进行练习然后更新整个模型结构 的。
(2.7) 模型保存与节点Embeding导出 (和上文相同)
@ 欢迎重视微信大众号:算法全栈之路
# 图数据和模型保存
save_graphs("graph.bin", [hetero_graph])
torch.save(model.state_dict(), "model.bin")
# 每个结点的embeding,自己初始化,由于参与了练习,这个便是终究每个结点输出的embeding
print("node_embed:", all_node_embed['user'][0])
# 模型预估的成果,终究应该运用 inference,这儿得到的是logit
# 留意,这儿传入 all_node_embed,挑选0,选1可能会死锁,终究程序不履行
inference_out = model.inference(hetero_graph, batch_size, 'cpu', num_workers=0, all_node_embed)
print(inference_out["user"].shape)
print(inference_out['user'][0])
这儿咱们能够看到, 咱们运用了 model.inference
接口进行模型的 节点 Embeding导出 。
这儿需求留意的是: 这个当地 num_workers应该设置0 ,即为不用多线程, 不然会互锁, 导致预估使命不履行 。
这儿是深坑啊,反正经过很长时刻的纠结和查找,终究发现是这个原因,期望读者能够防止遇到类似的问题 ~
到这儿,异构图 Link 猜测 理论与DGL 源码实战 的全文就写完了。这一篇文章是为了 图系列文章 的 完整性 而写的一篇文章。信任认真看过作者文章的人,每一篇都不错失的话,到这儿修改下网络对他们来说是十分简单的事情。
可是事实上,也确有同学卡在了 异构图链接猜测的一些自界说函数 上,不知道怎么去完成来进行 链接猜测 使命,那就结合本文与上一篇文章以及曾经的一篇同构图链接猜测的文章一同看看吧,信任你会有很有收获的 ~
上面的代码demo 在环境没问题的情况下,悉数 复制到一个python文件 里,就能够完美运转起来。本文的代码是一个 小型的商业能够用 的工程项目,期望能够对你有参阅效果 ~
码字不易,觉得有收获就动动小手转载一下吧,你的支持是我写下去的最大动力 ~
更多更全更新内容,欢迎重视作者的大众号: 算法全栈之路