重磅好文透彻了解,异构图上 Node 分类理论与DGL源码实战


书接上文,重视过作者历史文章的读者都知道,图上机器学习/深度学习系列文章 从 一文揭开图机器学习的面纱,你确认不来看看吗 开始,现已陆续和咱们一同了解了 同构图上的链接猜测、节点分类与回归、边分类与回归 等机器学习使命,不熟悉的同学能够去作者的历史文章里查找哦。

如上所说,曾经介绍 图上机器学习使命 的文章, 均是在 同构图 上进行的,疏忽了图上不同节点以及不同边的共同性质,而是把一切节点当作一种节点来看待的。这个尽管能够处理一部分问题,可是该联系建模才能也不足以掩盖现实世界 中复杂多变的多种联系,所以就轮到咱们的 异构图联系建模 文章出马了。

针对 异构图 上联系的建模,由于其 工程完结的复杂性 ,现在的学术界和工业界均存在必定的 完结难度 。我知道的甚至很多图深度学习结构在最新的版本里还 不支持 对异构图的建模。好在亚马逊的DGL结构在最新的几个版本中,现已更新了对异构图的工程完结,下面就让咱们结合DGL的完结源码来一同了解下 异构图上节点分类/回归使命 吧 ~ go go go !!!

留意:咱们的文章里,把分类回归使命一同包含了由于这来那个使命除了 输入和丢失不同 以外,网络结构并没有其他不同,分类回归使命相互修改互用也比较简单,这儿就不再进行区分了。本文说是节点分类使命,可是其实回归使命也差不太多。


(1) 异构图节点分类使命理论基础

依照常规,咱们还是先从基础界说引出下文的话题。

在曾经的文章 一文揭开图机器学习的面纱,你确认不来看看吗 中,咱们说图的分类的时分说到了异构图,文中说:图中节点类型和边类型超越两种的图称为异构图。这意思便是说异构图中的节点和同2个节点的边可能有多种,例如:图中包含用户,产品,IP三种类型的节点,其间用户和产品之间又有加购物车与购买这两种联系的边。本文所说的图便是这种类型的 比较复杂 的图。

同构图推广 来看,已然在异构图中区分了 节点和边 的不同类型,那咱们在处理依据 异构图的局部与全局结构特性 对某个节点进行 定性分析 或则 进行两个节点之间 联系猜测 的时分,就需求从 更细粒度 上去对不同的节点和边的联系进行 区分 。已然2个节点的某一种联系决定了一种类型的边,一种比较好的方式是: 依据联系(边)类型去安排不同类型的节点 ,然后进行异构图卷积操作,得到对各个类型的节点的 Embeding,在依据此终究完结 异构图上的机器学习使命 ,就像DGL官方源码完结的那样。

所谓 异构图卷积,顾名思义: 便是对 各种边的联系各自别离进行卷积 ,然后将这些联系对应的各种类型的同类型节点进行交融,默许是Sum , 得到各种同类型节点的Embeding, 留意这儿每种类型节点只有一个Embeding。 关于 节点分类 使命,终究在异构图卷积层完毕的时分,能够直接接激活函数,然后别离对每种类型的节点核算出一个Logit, 和有监督的某种类型的 label 核算丢失进行回传即可。感兴趣的同学,能够看 DGL完结的RGCB节点分类使命的源码验证明晰 以上所说的逻辑。

这儿需求特别强调留意 的是: 在异构图RGCN采样的时分,采样了几层街坊节点,异构图卷积层就有几层异构卷积layer, 别离有每个异构卷积layer去处理每一层的街坊节点

由于采样是由内向外采样的,而聚合是由外向内聚合的。这儿要引进DGL完结采样得到的Block的概念,浅显了解 Block其实便是采样得到的子图,而这些子图里的边也有对应这开始节点和完毕节点以及边类型等和 全Graph同等 的一些特点

咱们能够这样了解DGL完结的Block能够把看作一个数组,数组里的每一个元素是图上一层街坊的采样,Block内部节点是 从远到近的顺序排列内部的Block的,Block数组的下标从小到大对应着采样规模由外到内、掩盖规模由远及近,而且 blocks[i+1]的 source node 和 blocks[i]的target node是能够对应上的。咱们知道街坊节点采样其实是依照边的联系去采来确认街坊的,所以在DGL的采样过程中,让 blocks[0]的 src node 包含了 blocks[0]的一切dst node,而且dst 节点出现在src 节点序列的前面若干位置

所以咱们在代码完结的时分,将 外层对应节点的Embeding作为内层节点的输入,构成两个相互挨着的卷积层 ,这儿采样与工程完结是 完美相互契合 的。有疑问的同学,能够去看源码验证哦 ~

好吧,全体对异构图的节点分类使命 抽象 一下: 已然咱们要对异构图上某节点进行分类,那咱们就需求归纳异构图上该节点街坊节点的信息,得出所求节点的Embeding 信息。 而该节点周围有多种类别联系的节点,则咱们就对各个联系别离进行卷积,求得各个联系里边各个节点的Embeding, 然后将多种联系包含的多类同类节点 Embeding进行聚合,后面能够接全链接层,也能够不接全链接层直接接激活函数,得到各个节点类型的成果作为输出。关于异构图,终究 节点分类使命的 Logit 也是 依照节点类其他个数有多个

当然针对异构图,咱们能够选用 GraphSage还是HAN ?吐血力作综述Graph Embeding 经典好文 文章后半部分里介绍的,运用 MetaPath 结合 Attention 进行 Node 节点等级 与 path语义级其他交融,类似于 HAN 的处理方式。可是 万丈高楼平地起写代码和写文章,也得慢慢来一点一点儿完结不是~

异构图RGCN节点分类使命 全体的流程解析就到这儿吧,感觉这个当地,还是得看源码才能说清楚。由于整个源码流程比较长,也为了让终究整个代码demo能够完美的运转起来,本篇文章的代码将从 讲述一个工程的完结 开始。

所以,本文 就让咱们一同完结 依据DGL和异构图的RGCN来进行节点分类回归使命 。下面就让咱们开始 coding 吧 ~


(2) 代码时光

开篇先吼一嗓子 , talk is cheap , show me the code !!!

本文的代码讲的是 依据DGL和RGCN完结的异构图上节点分类使命,整个源码流程是一个 小型的工业可用的工程,依据dgl完结,觉得有用赶忙保藏转发吧~

life is short , i use python !!!

(2.1) 数据预备

咱们假定能够输入类似于这样的数据, 其间每2列对应这一种联系,例如 用户2352193 购买了产品CEEC9EBF7,用户用了IP 174.74.201.9登录了账号,用户用IP 174.74.201.9 购买了产品 CEEC9EBF7, label 表明着该用户真的购买产品,终究的节点分类使命是猜测用户的购买意愿,是否是咱们的高意图潜在用户,二分类。

咱们能够把这样一份数据存入 source_data.csv 文件中,用 pandas 接口把数据读入: raw_pdf = pd.read_csv('./source_data.csv')

重磅好文透彻理解,异构图上 Node 分类理论与DGL源码实战

由于关于 异构图 模型,节点和边的类型均有多种,为了处理便利,咱们能够把各种类型的节点进行编码,再到后期对其进行解码,对 pandas 的 dataframe 数据结构的编解码,咱们能够运用下面的代码:

@欢迎重视微信大众号:算法全栈之路
#编码办法
defencode_map(input_array):
p_map={}
length=len(input_array)
forindex,eleinzip(range(length),input_array):
#print(ele,index)
p_map[str(ele)]=index
returnp_map
#解码办法
defdecode_map(encode_map):
de_map={}
fork,vinencode_map.items():
#index,ele
de_map[v]=k
returnde_map

然后用其间的各列node 进行 编码

@欢迎重视微信大众号:算法全栈之路
userid_encode_map=encode_map(set(graph_features_pdf['user_id'].values))
#解码map
userid_decode_map=decode_map(userid_encode_map)
graph_features_pdf['user_id_encoded']=graph_features_pdf['user_id'].apply(lambdae:userid_encode_map.get(str(e),-1))
#printunique值的个数
userid_count=len(set(graph_features_pdf['user_id_encoded'].values))
print(userid_count)

这儿仅仅以 用户节点编码 为例,itemId和 IP同理编解码即可。 终究咱们能够把图数据保存,供今后的异构图代码 demo运用。

@欢迎重视微信大众号:算法全栈之路
final_graph_pdf=graph_features_pdf[['user_id_encoded','ip_encoded','item_id_encoded','label']].sort_values(by='user_id_encoded',ascending=True)
final_graph_pdf.to_csv('result_label.csv',index=False)

依据此,异构图的基础预备数据就完毕了,下面开始正式的coding了。


(2.2) 导包

老规矩,先导包,依据DGL和RGCN完结的异构图上节点分类使命只需求这些包就能够了。

@欢迎重视微信大众号:算法全栈之路
importargparse
importtorch
importtorch.nnasnn
importdgl
importtorch.optimasoptim
fromdgl.dataloadingimportMultiLayerFullNeighborSampler,EdgeDataLoader
fromdgl.dataloading.negative_samplerimportUniform
importnumpyasnp
importpandasaspd
importitertools
importos
importtqdm
fromdglimportsave_graphs,load_graphs
importdgl.functionasfn
importtorch
importdgl
importtorch.nn.functionalasF
fromdgl.nn.pytorchimportGraphConv,SAGEConv,HeteroGraphConv
fromdgl.utilsimportexpand_as_pair
importtqdm
fromcollectionsimportdefaultdict
importtorchasth
importdgl.nnasdglnn
fromdgl.data.utilsimportmakedirs,save_info,load_info
fromsklearn.metricsimportroc_auc_score
importgc
gc.collect()

推荐一个东西,tqdm 很好用 哦,结合 dataloading接口 , 能够看到模型练习以及数据处理履行的进度,赶忙用起来吧~

这儿的 sklearn 东西 的导入,仅仅是为了调用他来进行分类模型的离线目标评价,得到AUC等目标罢了。

各种模型东西无所谓分类,能处理问题的便是好东西,混用又有何不可呢? 有用就行


(2.3) 构图

数据有了,接下来便是构图了,咱们构建的是包含 三种节点的异构图

@欢迎重视微信大众号:算法全栈之路
#user登录ip
u_e_ip_src=final_graph_pdf['user_id_encoded'].values
u_e_ip_dst=final_graph_pdf['ip_encoded'].values
#user购买item
u_e_item_src=final_graph_pdf['user_id_encoded'].values
u_e_item_dst=final_graph_pdf['item_id_encoded'].values
#item和ip共同出现
ip_e_item_src=final_graph_pdf['ip_encoded'].values
ip_e_item_dst=final_graph_pdf['item_id_encoded'].values
#user购买label
user_node_buy_label=final_graph_pdf['label'].values
hetero_graph=dgl.heterograph({
('user','u_e_ip','ip'):(u_e_ip_src,u_e_ip_dst),
('ip','u_eby_ip','user'):(u_e_ip_dst,u_e_ip_src),
('user','u_e_item','item'):(u_e_item_src,u_e_item_dst),
('item','u_eby_item','user'):(u_e_item_dst,u_e_item_src),
('ip','ip_e_item','item'):(ip_e_item_src,ip_e_item_dst),
('item','item_eby_ip','ip'):(ip_e_item_dst,ip_e_item_src)
})
#给usernode增加标签
hetero_graph.nodes['user'].data['label']=torch.tensor(user_node_buy_label)
print(hetero_graph)

这儿的 异构图是 无向图 ,由于无向,所以双向。 构图的时分就需求构建 双向的边。 代码很好了解,就不再赘述了哈。


(2.4) 模型的自界说函数

这儿界说了 异构图上RGCN 会用到的模型的一系列自界说函数,归纳看代码注释,结合上文榜首末节的抽象了解,期望能了解的愈加深入哦。

@欢迎重视微信大众号:算法全栈之路
classRelGraphConvLayer(nn.Module):
def__init__(self,
in_feat,
out_feat,
rel_names,
num_bases,
*,
weight=True,
bias=True,
activation=None,
self_loop=False,
dropout=0.0):
super(RelGraphConvLayer,self).__init__()
self.in_feat=in_feat
self.out_feat=out_feat
self.rel_names=rel_names
self.num_bases=num_bases
self.bias=bias
self.activation=activation
self.self_loop=self_loop
#这个当地仅仅起到核算的作用,不保存数据
self.conv=HeteroGraphConv({
#graphconv里边有模型参数weight,假如外边不传进去的话,里边新建
#相当于模型加了一层全链接,对每一种类型的边核算卷积
rel:GraphConv(in_feat,out_feat,norm='right',weight=False,bias=False)
forrelinrel_names
})
self.use_weight=weight
self.use_basis=num_bases<len(self.rel_names)andweight
ifself.use_weight:
ifself.use_basis:
self.basis=dglnn.WeightBasis((in_feat,out_feat),num_bases,len(self.rel_names))
else:
#每个联系,又一个weight,全连接层
self.weight=nn.Parameter(th.Tensor(len(self.rel_names),in_feat,out_feat))
nn.init.xavier_uniform_(self.weight,gain=nn.init.calculate_gain('relu'))
#bias
ifbias:
self.h_bias=nn.Parameter(th.Tensor(out_feat))
nn.init.zeros_(self.h_bias)
#weightforselfloop
ifself.self_loop:
self.loop_weight=nn.Parameter(th.Tensor(in_feat,out_feat))
nn.init.xavier_uniform_(self.loop_weight,
gain=nn.init.calculate_gain('relu'))
self.dropout=nn.Dropout(dropout)
defforward(self,g,inputs):

g=g.local_var()
ifself.use_weight:
weight=self.basis()ifself.use_basiselseself.weight
#这每个联系对应一个权重矩阵对应输入维度和输出维度
wdict={self.rel_names[i]:{'weight':w.squeeze(0)}
fori,winenumerate(th.split(weight,1,dim=0))}
else:
wdict={}
ifg.is_block:
inputs_src=inputs
inputs_dst={k:v[:g.number_of_dst_nodes(k)]fork,vininputs.items()}
else:
inputs_src=inputs_dst=inputs
#多类型的边结点卷积完结后的输出
#输入的是blocks和embeding
hs=self.conv(g,inputs,mod_kwargs=wdict)
def_apply(ntype,h):
ifself.self_loop:
h=h+th.matmul(inputs_dst[ntype],self.loop_weight)
ifself.bias:
h=h+self.h_bias
ifself.activation:
h=self.activation(h)
returnself.dropout(h)
#
return{ntype:_apply(ntype,h)forntype,hinhs.items()}
classRelGraphEmbed(nn.Module):
r"""Embeddinglayerforfeaturelessheterograph."""
def__init__(self,
g,
embed_size,
embed_name='embed',
activation=None,
dropout=0.0):
super(RelGraphEmbed,self).__init__()
self.g=g
self.embed_size=embed_size
self.embed_name=embed_name
self.activation=activation
self.dropout=nn.Dropout(dropout)
#createweightembeddingsforeachnodeforeachrelation
self.embeds=nn.ParameterDict()
forntypeing.ntypes:
embed=nn.Parameter(torch.Tensor(g.number_of_nodes(ntype),self.embed_size))
nn.init.xavier_uniform_(embed,gain=nn.init.calculate_gain('relu'))
self.embeds[ntype]=embed
defforward(self,block=None):

returnself.embeds
classEntityClassify(nn.Module):
def__init__(self,
g,
h_dim,out_dim,
num_bases=-1,
num_hidden_layers=1,
dropout=0,
use_self_loop=False):
super(EntityClassify,self).__init__()
self.g=g
self.h_dim=h_dim
self.out_dim=out_dim
self.rel_names=list(set(g.etypes))
self.rel_names.sort()
ifnum_bases<0ornum_bases>len(self.rel_names):
self.num_bases=len(self.rel_names)
else:
self.num_bases=num_bases
self.num_hidden_layers=num_hidden_layers
self.dropout=dropout
self.use_self_loop=use_self_loop
self.embed_layer=RelGraphEmbed(g,self.h_dim)
self.layers=nn.ModuleList()
#i2h
self.layers.append(RelGraphConvLayer(
self.h_dim,self.h_dim,self.rel_names,
self.num_bases,activation=F.relu,self_loop=self.use_self_loop,
dropout=self.dropout,weight=False))
#h2h,这儿不增加隐层,只用2层卷积
#foriinrange(self.num_hidden_layers):
#self.layers.append(RelGraphConvLayer(
#self.h_dim,self.h_dim,self.rel_names,
#self.num_bases,activation=F.relu,self_loop=self.use_self_loop,
#dropout=self.dropout))
#h2o
self.layers.append(RelGraphConvLayer(
self.h_dim,self.out_dim,self.rel_names,
self.num_bases,activation=None,
self_loop=self.use_self_loop))
#输入blocks,embeding
defforward(self,h=None,blocks=None):
ifhisNone:
#fullgraphtraining
h=self.embed_layer()
ifblocksisNone:
#fullgraphtraining
forlayerinself.layers:
h=layer(self.g,h)
else:
#minibatchtraining
#输入blocks,embeding
forlayer,blockinzip(self.layers,blocks):
h=layer(block,h)
returnh

definference(self,g,batch_size,device="cpu",num_workers=0,x=None):
ifxisNone:
x=self.embed_layer()
forl,layerinenumerate(self.layers):
y={
k:th.zeros(
g.number_of_nodes(k),
self.h_dimifl!=len(self.layers)-1elseself.out_dim)
forking.ntypes}

sampler=dgl.dataloading.MultiLayerFullNeighborSampler(1)
dataloader=dgl.dataloading.NodeDataLoader(
g,
{k:th.arange(g.number_of_nodes(k))forking.ntypes},
sampler,
batch_size=batch_size,
shuffle=True,
drop_last=False,
num_workers=num_workers)

forinput_nodes,output_nodes,blocksintqdm.tqdm(dataloader):
#print(input_nodes)
block=blocks[0].to(device)

h={k:x[k][input_nodes[k]].to(device)forkininput_nodes.keys()}
h=layer(block,h)
forkinh.keys():
y[k][output_nodes[k]]=h[k].cpu()
x=y
returny

上面的代码首要分为三大块:别离是 RelGraphConvLayerRelGraphEmbed 以及 EntityClassify

首先便是:RelGraphConvLayer 。咱们能够看到 RelGraphConvLayer 便是咱们的 异构图卷积层layer , 其首要是调用了DGL完结的 HeteroGraphConv算子,从上面榜首末节咱们也具体论述了异构图卷积算子其实便是: 对各种联系别离进行卷积然后进行同类型的节点的交融

这儿咱们需求重点重视的是:RelGraphConvLayer层的回来,从代码中,咱们能够看到,关于每种节点类型是回来了一个Embeding, 维度是 out_feat。假如是带了激活函数的,则是回来激活后的必定维度的一个tensor。

过来是 RelGraphEmbed。 从代码中能够看到: 这个python类仅仅回来了一个字典,可是这个字典里却包含了 多个 Embeding Variable, 留意这儿的 Variable 均是能够 随着网络练习变化更新 的。咱们能够依据节点类型,节点ID获得对应元素的 Embeding 。 这种完结办法是不是处理了 前文 GraphSage与DGL完结同构图 Link 猜测,浅显易懂好文强推 和 依据GCN和DGL完结的图上 node 分类, 值得一看!!! 所说到的 动态更新的Embeding 的问题呢。

终究便是 EntityClassify类 了,咱们能够看到 这个便是终究的 模型RGCN结构 了,包含了 模型练习的 forward 和用于揣度的inference办法

。这儿的 inference 能够用于 各个节点的embedding的导出, 咱们在后文有实例代码,接着看下去吧~

留意看 forword 办法里 的 for layer, block in zip(self.layers, blocks) 这个位置, 这儿便是咱们前一末节所说的 采样层数和模型的卷积层数目是相同的说法的由来,能够结合上文说明了解源码哦。


(2.5) 模型采样超参加节点采样介绍

先上代码。

@欢迎重视微信大众号:算法全栈之路
#依据节点类型和节点ID抽取embeding参加模型练习更新
defextract_embed(node_embed,input_nodes):
emb={}
forntype,nidininput_nodes.items():
nid=input_nodes[ntype]
emb[ntype]=node_embed[ntype][nid]
returnemb
#采样界说,有监督采样和无监督采样不一样
batch_size=20480
neg_sample_count=1
#采样2层悉数节点
sampler=MultiLayerFullNeighborSampler(2)
#用户节点采样,这儿是对用户的一切街坊采样了2层节点
hetero_graph.nodes['user'].data['train_mask']=torch.zeros(unique_userid_count,dtype=torch.bool).bernoulli(1.0)
all_userid_idx=torch.nonzero(hetero_graph.nodes['user'].data['train_mask'],as_tuple=False).squeeze()
user_loader=dgl.dataloading.NodeDataLoader(hetero_graph,{"user":train_userid_nodeids},sampler,batch_size=batch_size,shuffle=True,num_workers=0)
#练习集和测试集split
train_count=(int)(len(all_userid_idx)*0.9)
print(train_count)
train_userid_nodeids=all_userid_idx[:train_count]
test_userid_nodeids=all_userid_idx[train_count:]
#IP节点的街坊采样
hetero_graph.nodes['ip'].data['train_mask']=torch.zeros(unique_ip_count,dtype=torch.bool).bernoulli(1.0)
train_ip_nodeids=hetero_graph.nodes['ip'].data['train_mask'].nonzero(as_tuple=True)[0]
ip_loader=dgl.dataloading.NodeDataLoader(hetero_graph,{"ip":train_ip_nodeids},sampler,
batch_size=batch_size,shuffle=True,num_workers=0)
#item街坊节点采样
hetero_graph.nodes['item'].data['train_mask']=torch.zeros(unique_ip_prefix_count,dtype=torch.bool).bernoulli(1.0)
train_ipprefix_nodeids=hetero_graph.nodes['item'].data['train_mask'].nonzero(as_tuple=True)[0]
ipprefix_loader=dgl.dataloading.NodeDataLoader(hetero_graph,{"item":train_ipprefix_nodeids},sampler,batch_size=batch_size,shuffle=True,num_workers=0)

这儿的代码作者花了大量时间进行优化注释和安排形式 尽量写的十分明晰,十分简单了解。

咱们这儿挑选了 NodeDataLoader 来进行练习数据的读入,这其实是一种 分batch练习 的办法,而 不是一次性把图全读入内存 进行练习,而是每次挑选 batch的种子节点以及他们采样的街坊节点 读入内存参加练习,这也让大的图神经网络练习成为了可能,是 DGL图深度结构 十分优异 的完结 !!! 大赞 !

需求 留意的是 : extract_embed 这个办法能够抽取出对应类别对应节点的 Embeding。 咱们这儿用了 MultiLayerFullNeighborSampler 这个接口,对每个种子节点采样了2层的悉数街坊参加练习,中心由于是节点分类使命,这儿需求将该街坊采样算子 和 dgl.dataloading.NodeDataLoader 结合运用。

NodeDataLoader 的第二个参数属于一个字典,其间能够放多个 节点类型以及对应的种子nids , 这儿为了便利了解,把拆解成了多个 data_loader,来别离对多个类型的节点在图上进行悉数街坊的采样,这儿的 完结是等价 的。

作者亲测,图练习的 batch_size 能挑选大尽可能大一些 吧,不然练习模型会十分慢的~


(2.6) 模型练习超参加单epoch练习

@欢迎重视微信大众号:算法全栈之路
#模型界说
num_class=2
n_hetero_features=16
labels=hetero_graph.nodes['user'].data['label']
hidden_feat_dim=n_hetero_features
embed_layer=RelGraphEmbed(hetero_graph,hidden_feat_dim)
all_node_embed=embed_layer()
model=EntityClassify(hetero_graph,hidden_feat_dim,num_class)
#优化模型一切参数,首要是weight以及输入的embeding参数
all_params=itertools.chain(model.parameters(),embed_layer.parameters())
optimizer=torch.optim.Adam(all_params,lr=0.01,weight_decay=0)
deftrain_nodetype_one_epoch(ntype,spec_dataloader):
losses=[]
#input_nodes代表核算output_nodes的表明所需的节点,input_nodes包含了output_nodes。
#块包含了每个GNN层要核算哪些节点表明作为输出,要将哪些节点表明作为输入,以及来自输入节点的表明如何传播到输出节点。
forinput_nodes,output_nodes,blocksintqdm.tqdm(spec_dataloader):
emb=extract_embed(all_node_embed,input_nodes)
batch_tic=time.time()
seeds=output_nodes[ntype]
lbl=labels[seeds]#只取output_nodes部分结点参加练习
logits=model(emb,blocks)[ntype]

loss=F.cross_entropy(logits,lbl)
loss.backward()
optimizer.step()

train_acc=torch.sum(logits.argmax(dim=1)==lbl).item()/len(seeds)

print('AUC',roc_auc_score(lbl,logits.argmax(dim=1)))
print("Epoch{:05d}|TrainAcc:{:.4f}|TrainLoss:{:.4f}|Time:{:.4f}".
format(epoch,train_acc,loss.item(),time.time()-batch_tic))

从上面的代码咱们能够看到: 终究咱们是进行了 2分类 ,中心的调用了上面模型界说类 EntityClassify 来界说 异构图上RGCN的模型 结构,由于是分类问题,丢失函数挑选了 穿插熵丢失

需求留意的是: all_params = itertools.chain(model.parameters(), embed_layer.parameters()) 这一行代码,咱们界说优化器的参数时,将咱们自界说的 可随网络更新的 Variable 加入了 itertools.chain 参加模型的练习。

另一个需求留意的点是: spec_dataloader 这个当地,它的回来是 input_nodes, output_nodes和 blocks 这三个元素的tuple 。 其间,input_nodes 代表核算 output_nodes 的表明所需的节点,input_nodes包含了output_nodes。块 包含了每个GNN层要核算哪些节点表明作为输出,要将哪些节点表明作为输入,以及来自输入节点的表明如何传播到输出节点

这就有了咱们进行模型练习所需求的图上结构的悉数信息了。


(2.6) 模型多种节点练习

@欢迎重视微信大众号:算法全栈之路
#开始train模型
forepochinrange(20):
print("startepoch:",epoch)
model.train()
train_nodetype_one_epoch('user',user_loader)
train_nodetype_one_epoch('user',user_loader)
train_nodetype_one_epoch('user',user_loader)

从代码中咱们能够知道: 关于异构图,其实咱们也是以 各种类型的节点作为种子节点, 然后进行图上的街坊采样,别离进行练习然后更新整个模型结构 的。


(2.7) 模型保存与节点Embeding导出

@欢迎重视微信大众号:算法全栈之路
#图数据和模型保存
save_graphs("graph.bin",[hetero_graph])
torch.save(model.state_dict(),"model.bin")
#每个结点的embeding,自己初始化,由于参加了练习,这个便是终究每个结点输出的embeding
print("node_embed:",all_node_embed['user'][0])
#模型预估的成果,终究应该运用inference,这儿得到的是logit
#留意,这儿传入all_node_embed,挑选0,选1可能会死锁,终究程序不履行
inference_out=model.inference(hetero_graph,batch_size,'cpu',num_workers=0,all_node_embed)
print(inference_out["user"].shape)
print(inference_out['user'][0])

这儿咱们能够看到, 咱们运用了 model.inference 接口进行模型的节点 Embeding导出。

这儿需求留意的是: 这个当地 num_workers应该设置0 ,即为不用多线程, 不然会互锁,导致预估使命不履行。这儿是 深坑 啊,横竖经过很长时间的纠结和查找,终究发现是这个原因,期望读者能够防止遇到类似的问题 ~

其实关于异构图,要写出对它的一些使用的了解,我也是怯生生的。可是,凡事必先骑上虎背 。管它呢,上吧,能写到哪一步是哪一步吧! 欢迎重视作者并留言和我一同讨论,相互一同学习交流 ~

到这儿,重磅好文透彻了解, 异构图上 Node 分类理论与DGL源码实战 的全文就写完了。上面的代码demo 在环境没问题的情况下,悉数复制到一个python文件里,就能够完美运转起来。本文的 代码是一个小型的商业能够用的工程项目,期望能够对你有参考作用 ~


码字不易,觉得有收获就动动小手转载一下吧,你的支持是我写下去的最大动力 ~

更多更全更新内容,欢迎重视作者的大众号: 算法全栈之路

重磅好文透彻理解,异构图上 Node 分类理论与DGL源码实战

  • END –