当前位置: 移动技术网 > IT编程>开发语言>C/C++ > 单层GNN完成cora数据集节点分类任务

单层GNN完成cora数据集节点分类任务

2020年07月16日  | 移动技术网IT编程  | 我要评论

import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import matplotlib.pyplot as plt
#使用上一篇博客中用到数据集处理函数
from coraDatasetsProcess import main
node_nums,feature_dims,label_list,feat_Matrix,degree_list,cites,X_Node,X_Neis=main()

#定义图卷积模型
class gnnModel(torch.nn.Module):
    def __init__(self):
        super(gnnModel, self).__init__()
        self.lin1=nn.Linear(in_features=1433,out_features=7)
    def forward(self,x,dig_list,A):
        print(A.shape)
        N=len(x)
        I_list=[1 for i in range(N)]
        I=np.diag(I_list)#单位矩阵
        A=A+I+I#添加自循环
        diags=np.diag(dig_list**(-0.5))
        pre=np.dot(np.dot(np.dot(diags,A),diags),x)
        pre = pre.astype(np.float32)
        pre=torch.from_numpy(pre)#从numpy的ndarray格式转化为tensor张量
        x=self.lin1(pre)
        return x
#数据预处理
def processData():
    A=np.zeros((node_nums,node_nums))
    for i in range(len(X_Node)):
        source=X_Node[i].item()
        target=X_Neis[i].item()
        A[source][target]=1#构造邻接矩阵
    return A

#实例化模型并训练
def modelStart(A):
    net=gnnModel()#实例化图神经网络模型
    net.train()#训练模式

    #优化器
    optimizer=torch.optim.Adam(net.parameters(),lr=0.01)
    #损失函数
    loss_function=nn.CrossEntropyLoss()
    out=0
    lossList=[]
    for epoch in range(200):
        optimizer.zero_grad()
        out=net(feat_Matrix[:500],degree_list[:500],A[:500,:500])
        loss=loss_function(out,label_list[:500])
        lossList.append(loss)
        print("epoch:",epoch," loss:",loss)
        loss.backward()
        optimizer.step()

    with torch.no_grad():
        out = net(feat_Matrix, degree_list, A)
        max_value,max_index=torch.max(out.data,1)
        correct=max_index.eq(label_list).sum().item()
        print("the accuracy of node classification is:",correct/len(label_list))
    return lossList

#plot画loss曲线
def plotCora(lossList):
    N=len(lossList)
    y=lossList
    x=[i for i in range(N)]
    p=plt.plot(x,y)
    plt.xlabel("iterations")
    plt.ylabel("loss")
    plt.title("the loss curve for node classification")
    plt.show(p)
def mainP():
    A=processData()
    lossList=modelStart(A)
    plotCora(lossList)
if __name__ == '__main__':
    mainP()

 ,公式如图所示

使用前500个节点数据集做训练,用整个数据集做测试,最终经过200次迭代后,节点分类的准确率是75%

loss曲线如下:

本文地址:https://blog.csdn.net/just_so_so_fnc/article/details/107356309

如对本文有疑问, 点击进行留言回复!!

相关文章:

验证码:
移动技术网