- 相关论文:A deep graph neural network architecture for modelling spatio-temporal dynamics in resting-state functional MRI data
- 相关repo:github.com/tjiagoM/spatio-temporal-brain
- 笔记人:陈亦新
主函数中生成了这样的模型:
# Build the spatio-temporal model with encoding_model=None (no external encoder
# is passed in) and move it to the device named in the run configuration.
model = SpatioTemporalModel(
    encoding_model=None,
    run_cfg=run_cfg,
).to(run_cfg['device_run'])
这个SpatioTemporalModel的代码非常长,和以前解读工程时一样,我们只看forward函数就可以了,下面代码片段中的注释是我的理解:
class SpatioTemporalModel(nn.Module):
    # Excerpt only: the real class is much longer (__init__ and helper methods
    # are omitted in this note); only forward() is reproduced here.

    def forward(self, data):
        """Run one forward pass over a batch of brain graphs.

        Pipeline:
          1. optionally split off and embed the multimodal node features;
          2. encode each node's time series (LSTM / 1D-CNN / TCN, summary
             statistics, or a VAE / AE encoder);
          3. propagate over the graph (GAT / GCN, or a meta layer);
          4. pool the node embeddings into one vector per graph;
          5. apply ``self.final_linear``.

        Args:
            data: a PyG-style batch exposing ``x``, ``edge_index``,
                ``edge_attr`` and ``batch``.

        Returns:
            The output of ``self.final_linear`` (passed through a sigmoid when
            ``self.final_sigmoid`` is set). For the DiffPool-family pooling
            strategies a tuple ``(output, link_loss, ent_loss)`` is returned
            instead, carrying DiffPool's two auxiliary losses.
        """
        # The same three tensors described in the previous section of the note.
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr

        # Pooling strategies that go through DiffPool. They also yield the two
        # auxiliary losses returned at the end, so membership is tested once.
        diffpool_strategies = (PoolingStrategy.DIFFPOOL, PoolingStrategy.DP_MAX,
                               PoolingStrategy.DP_ADD, PoolingStrategy.DP_MEAN,
                               PoolingStrategy.DP_IMPROVED)
        uses_diffpool = self.pooling in diffpool_strategies

        if self.multimodal_size > 0:
            # The first `multimodal_size` columns are extra (non-time-series)
            # node features; the remaining columns are the time series.
            xn, x = x[:, :self.multimodal_size], x[:, self.multimodal_size:]
            xn = self.multimodal_lin(xn)
            xn = self.activation(xn)
            xn = self.multimodal_batch(xn)
            xn = F.dropout(xn, p=self.dropout, training=self.training)

        # --- Temporal part: one embedding per node from its time series ---
        if self.conv_strategy != ConvStrategy.NONE:
            # The LSTM is selected through the same `conv_strategy` switch as
            # the convolutional encoders.
            if self.conv_strategy == ConvStrategy.LSTM:
                # (nodes_in_batch, time, 1): one scalar per time step, matching
                # the LSTM's batch_first=True layout.
                x = x.view(-1, self.num_time_length, 1)
                # Zero initial hidden and cell states (see init_lstm_hidden).
                h0, c0 = self.init_lstm_hidden(x)
                x, (_, _) = self.temporal_conv(x, (h0, c0))
                x = x.contiguous()
            else:
                # Plain 1D-CNN or TCN (see the CNN / TCN setup snippets):
                # a single input channel along the time axis.
                x = x.view(-1, 1, self.num_time_length)
                x = self.temporal_conv(x)

            # Flatten (channels x remaining time steps) into one vector per
            # node. For the plain CNN, four stride-2 convolutions multiply the
            # channel count by 8 and shrink the time axis roughly 16x; the TCN
            # setup sizes this with the full num_time_length instead.
            x = x.view(x.size(0), self.size_before_lin_temporal)
            # A single Linear or a small MLP (see _get_lin_temporal).
            x = self.lin_temporal(x)
            x = self.activation(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        elif self.encoding_strategy == EncodingStrategy.STATS:
            # Summary-statistics features: Linear + activation + 1D BatchNorm.
            x = self.stats_lin(x)
            x = self.activation(x)
            x = self.stats_batch(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        elif self.encoding_strategy == EncodingStrategy.VAE3layers:
            # Features from a VAE encoder: encode, then sample via the
            # reparameterisation trick.
            mu, logvar = self.encoder_model.encode(x)
            x = self.encoder_model.reparameterize(mu, logvar)
        elif self.encoding_strategy == EncodingStrategy.AE3layers:
            # Same idea with a plain autoencoder.
            x = self.encoder_model.encode(x)

        if self.multimodal_size > 0:
            x = torch.cat((xn, x), dim=1)

        # From here on, x holds the features extracted from the time series.
        # --- Graph part ---
        # (Note author: I used GCN in an earlier ISBI paper.)
        if self.sweep_type in [SweepType.GAT, SweepType.GCN]:
            # Graph feature extraction is conceptually close to a transformer
            # attention map; TCN/GNN implementation details come in a later
            # section of the note.
            # NOTE(review): PyG's GATConv takes `edge_attr`, not `edge_weight`;
            # presumably edge_weights is only enabled together with GCN -- confirm.
            if self.edge_weights:
                # Also feed the strength of each connection (flattened to 1-D).
                x = self.gnn_conv1(x, edge_index, edge_weight=edge_attr.view(-1))
            else:
                # Connectivity only: which nodes are linked, not how strongly.
                x = self.gnn_conv1(x, edge_index)
            x = self.activation(x)
            # NOTE(review): no `p=` here, so F.dropout's default (0.5) applies
            # rather than self.dropout as everywhere else -- confirm intent.
            x = F.dropout(x, training=self.training)
            # The GNN is shallow: one or two layers.
            if self.num_gnn_layers == 2:
                if self.edge_weights:
                    x = self.gnn_conv2(x, edge_index, edge_weight=edge_attr.view(-1))
                else:
                    x = self.gnn_conv2(x, edge_index)
                x = self.activation(x)
                x = F.dropout(x, training=self.training)
        elif self.sweep_type == SweepType.META_NODE:
            # Meta layer that returns updated node features only.
            x = self.meta_layer(x, edge_index, edge_attr)
        elif self.sweep_type == SweepType.META_EDGE_NODE:
            # Meta layer that updates node and edge features (third output unused).
            x, edge_attr, _ = self.meta_layer(x, edge_index, edge_attr)

        # --- Graph pooling: mean, add, DiffPool variants, or concatenation ---
        if self.pooling == PoolingStrategy.MEAN:
            x = global_mean_pool(x, data.batch)
        elif self.pooling == PoolingStrategy.ADD:
            x = global_add_pool(x, data.batch)
        elif uses_diffpool:
            # DiffPool needs dense tensors while the graph here is sparse:
            # densify the edge list and pad the node features per graph.
            adj_tmp = pyg_utils.to_dense_adj(edge_index, data.batch, edge_attr=edge_attr)
            if edge_attr is not None:
                # edge_attr has a single feature per edge: drop that trailing axis.
                adj_tmp = adj_tmp[:, :, :, 0]
            x_tmp, batch_mask = pyg_utils.to_dense_batch(x, data.batch)
            # self.diff_pool is the DiffPool component (next section of the note).
            x, link_loss, ent_loss = self.diff_pool(x_tmp, adj_tmp, batch_mask)
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = self.activation(self.pre_final_linear(x))
        elif self.pooling == PoolingStrategy.CONCAT:
            # Concatenate all node embeddings of a graph into one vector
            # (assumes every graph has exactly self.num_nodes nodes).
            x, _ = pyg_utils.to_dense_batch(x, data.batch)
            x = x.view(-1, self.NODE_EMBED_SIZE * self.num_nodes)
            x = self.activation(self.pre_final_linear(x))
            # NOTE(review): indentation was lost in the note; if this dropout
            # is dedented upstream it applies to every pooling strategy -- confirm.
            x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.final_linear(x)
        if self.final_sigmoid:
            x = torch.sigmoid(x)
        return (x, link_loss, ent_loss) if uses_diffpool else x
关于上述代码段的补充说明:
- LSTM-补充1
def init_lstm_hidden(x):
    """Return zero initial states ``[h0, c0]`` for the LSTM temporal encoder.

    Reads ``run_cfg`` from the enclosing scope (presumably this is defined
    inside ``__init__`` and bound as ``self.init_lstm_hidden`` -- confirm
    against the repo).

    Args:
        x: the LSTM input; ``x.size(0)`` is used as the batch dimension
            (the LSTM is built with ``batch_first=True``).

    Returns:
        A list ``[h0, c0]``, each a zero tensor of shape
        ``(tcn_depth, x.size(0), tcn_hidden_units)`` on ``x``'s device and
        with ``x``'s dtype.
    """
    shape = (run_cfg['tcn_depth'], x.size(0), run_cfg['tcn_hidden_units'])
    # Allocate directly on the target device (no host allocation + copy) and
    # match x's dtype so the states always agree with the input.
    return [torch.zeros(shape, dtype=x.dtype, device=x.device) for _ in range(2)]
- LSTM-补充2
# Stacked LSTM used as the temporal encoder: one scalar input per time step,
# inputs laid out as (batch, time, feature) because batch_first=True.
# nn.LSTM applies `dropout` between stacked layers only (not after the last).
self.temporal_conv = nn.LSTM(
    batch_first=True,
    input_size=1,
    num_layers=run_cfg['tcn_depth'],
    hidden_size=run_cfg['tcn_hidden_units'],
    dropout=dropout_perc,
)
- CNN-补充1
# Plain 1D-CNN temporal encoder: four blocks of
#   Conv1d(kernel_size=7, stride=2, padding=3) -> activation -> BatchNorm1d -> Dropout
# with channel widths 1 -> C -> 2C -> 4C -> 8C, where C = self.channels_conv.
stride, padding = 2, 3
# Flattened size after the last block: 8C channels times final_feature_size
# (presumably the time length left after the four stride-2 convs -- confirm).
self.size_before_lin_temporal = self.channels_conv * 8 * self.final_feature_size
self.lin_temporal = nn.Linear(self.size_before_lin_temporal,
                              self.NODE_EMBED_SIZE - self.multimodal_size)
self.conv1d_1 = nn.Conv1d(in_channels=1, out_channels=self.channels_conv,
                          kernel_size=7, padding=padding, stride=stride)
self.conv1d_2 = nn.Conv1d(in_channels=self.channels_conv, out_channels=self.channels_conv * 2,
                          kernel_size=7, padding=padding, stride=stride)
self.conv1d_3 = nn.Conv1d(in_channels=self.channels_conv * 2, out_channels=self.channels_conv * 4,
                          kernel_size=7, padding=padding, stride=stride)
self.conv1d_4 = nn.Conv1d(in_channels=self.channels_conv * 4, out_channels=self.channels_conv * 8,
                          kernel_size=7, padding=padding, stride=stride)
self.batch1 = BatchNorm1d(self.channels_conv)
self.batch2 = BatchNorm1d(self.channels_conv * 2)
self.batch3 = BatchNorm1d(self.channels_conv * 4)
self.batch4 = BatchNorm1d(self.channels_conv * 8)
# Every block reuses the same activation module and gets its own Dropout.
self.temporal_conv = nn.Sequential(*[
    layer
    for conv, norm in ((self.conv1d_1, self.batch1),
                       (self.conv1d_2, self.batch2),
                       (self.conv1d_3, self.batch3),
                       (self.conv1d_4, self.batch4))
    for layer in (conv, self.activation, norm, nn.Dropout(dropout_perc))
])
self.init_weights()
- TCN-补充1
# TCN temporal encoder. Unlike the plain CNN setup (whose flattened size uses
# final_feature_size), the TCN's flattened size uses the full num_time_length:
# (width of the last TCN layer) * num_time_length.
#
# Layer widths: when tcn_hidden_units == 8 the value acts as a switch rather
# than a width -- widths double per layer starting at channels_conv
# (C, 2C, 4C, ...); any other value gives every layer tcn_hidden_units channels.
if run_cfg['tcn_depth'] < 1:
    raise ValueError(f"tcn_depth must be >= 1, got {run_cfg['tcn_depth']!r}")
if run_cfg['tcn_hidden_units'] == 8:
    tcn_layers = [self.channels_conv * (2 ** i) for i in range(run_cfg['tcn_depth'])]
else:
    tcn_layers = [run_cfg['tcn_hidden_units']] * run_cfg['tcn_depth']
# Derived from the last layer so it cannot drift out of sync with tcn_layers.
# Must be assigned before _get_lin_temporal(), which reads it.
self.size_before_lin_temporal = tcn_layers[-1] * self.num_time_length
self.lin_temporal = self._get_lin_temporal(run_cfg)
self.temporal_conv = TemporalConvNet(1,
                                     tcn_layers,
                                     kernel_size=run_cfg['tcn_kernel'],
                                     dropout=self.dropout,
                                     norm_strategy=run_cfg['tcn_norm_strategy'])
- _get_lin_temporal
def _get_lin_temporal(self, run_cfg):
if run_cfg['tcn_final_transform_layers'] == 1:
lin_temporal = nn.Linear(self.size_before_lin_temporal,
self.NODE_EMBED_SIZE - self.multimodal_size)
elif run_cfg['tcn_final_transform_layers'] == 2:
lin_temporal = nn.Sequential(
nn.Linear(self.size_before_lin_temporal, int(self.size_before_lin_temporal / 2)),
self.activation, nn.Dropout(self.dropout),
nn.Linear(int(self.size_before_lin_temporal / 2), self.NODE_EMBED_SIZE - self.multimodal_size))
elif run_cfg['tcn_final_transform_layers'] == 3:
lin_temporal = nn.Sequential(
nn.Linear(self.size_before_lin_temporal, int(self.size_before_lin_temporal / 2)),
self.activation, nn.Dropout(self.dropout),
nn.Linear(int(self.size_before_lin_temporal / 2), int(self.size_before_lin_temporal / 3)),
self.activation, nn.Dropout(self.dropout),
nn.Linear(int(self.size_before_lin_temporal / 3), self.NODE_EMBED_SIZE - self.multimodal_size))
return lin_temporal
- to_dense_adj
# PyG utilities; forward() uses pyg_utils.to_dense_adj and pyg_utils.to_dense_batch
# to turn the sparse edge list / node features into the dense tensors DiffPool expects.
import torch_geometric.utils as pyg_utils
pyg_utils.to_dense_adj  # bare reference for the note; evaluating it has no effect
这个方法的作用是:Converts batched sparse adjacency matrices given by edge indices and edge attributes to a single dense batched adjacency matrix(即把由边索引和边属性表示的批量稀疏邻接矩阵,转换为一个稠密的批量邻接矩阵)。
官方文档的介绍地址在:torch_geometric.utils.to_dense_adj — pytorch_geometric documentation (pytorch-geometric.readthedocs.io)
综上所述,这就是时间序列在这个模型中经过的全部过程:先对时间序列进行编码,也就是提取特征;提取之后,选择合适的图网络在此基础上进一步提取特征;最后通过图池化进行特征整合(本文重点是DiffPool,代码中也支持mean、add和concat等池化方式)。