前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >基于图卷积网络的节点分类

基于图卷积网络的节点分类

作者头像
裴来凡
发布2024-04-10 14:11:53
830
发布2024-04-10 14:11:53
举报

步骤1:下载并加载QM7数据

下载QM7数据集。

代码语言:javascript
复制
dataURL = "http://quantum-machine.org/data/qm7.mat";
outputFolder = fullfile(tempdir,"qm7Data");
dataFile = fullfile(outputFolder,"qm7.mat");

if ~exist(dataFile,"file")
    mkdir(outputFolder);
    disp("Downloading QM7 data...");
    websave(dataFile, dataURL);
    disp("Done.")
end

从MAT文件加载QM7数据。

代码语言:javascript
复制
data = load(dataFile)

data = 包含以下字段的 struct: X: [7165×23×23 single] R: [7165×23×3 single] Z: [7165×23 single] T: [-417.9600 -712.4200 -564.2100 -404.8800 -808.8700 -677.1600 -796.9800 -860.3300 -1.0085e+03 -861.7300 -708.3700 -725.9300 -879.3800 -618.7200 -871.1900 -653.4400 -1.0109e+03 -1.1594e+03 -1.0039e+03 -1.0184e+03 -1.0250e+03 … ] P: [5×1433 int64]

从加载的结构中提取库仑数据和原子序数。

代码语言:javascript
复制
coulombData = double(permute(data.X, [2 3 1]));
atomData = sort(data.Z,2,'descend');

查看第一次观测的原子。

代码语言:javascript
复制
atomData(1,:)

ans = 1×23 single 行向量 6 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0

步骤2:预处理图形数据

将训练数据中的库仑矩阵转换为邻接矩阵。

代码语言:javascript
复制
adjacencyData = coulomb2Adjacency(coulombData,atomData);

将图中的前几个分子可视化。

代码语言:javascript
复制
figure
tiledlayout("flow")

for i = 1:9
    % 提取邻接矩阵
    atomicNumbers = nonzeros(atomData(i,:));
    numNodes = numel(atomicNumbers);
    A = adjacencyData(1:numNodes,1:numNodes,i);

    % 将邻接矩阵转换为图
    G = graph(A);

    % 将原子序数转换为符号
    symbols = atomicSymbol(atomicNumbers);

    % 绘图
    nexttile
    plot(G,NodeLabel=symbols,Layout="force")
    title("Molecule " + i)
end

使用直方图可视化每个标签类别的频率。

代码语言:javascript
复制
figure
histogram(categorical(atomicSymbol(atomData)))
xlabel("Node Label")
ylabel("Frequency")
title("Label Counts")

将数据划分为训练、验证和测试分区,分别包含80%、10%和10%的数据。

代码语言:javascript
复制
numObservations = size(adjacencyData,3);
[idxTrain,idxValidation,idxTest] = trainingPartitions(numObservations,[0.8 0.1 0.1]);

adjacencyDataTrain = adjacencyData(:,:,idxTrain);
adjacencyDataValidation = adjacencyData(:,:,idxValidation);
adjacencyDataTest = adjacencyData(:,:,idxTest);

coulombDataTrain = coulombData(:,:,idxTrain);
coulombDataValidation = coulombData(:,:,idxValidation);
coulombDataTest = coulombData(:,:,idxTest);

atomDataTrain = atomData(idxTrain,:);
atomDataValidation = atomData(idxValidation,:);
atomDataTest = atomData(idxTest,:);

预处理训练和验证数据。

代码语言:javascript
复制
[ATrain,XTrain,labelsTrain] = preprocessData(adjacencyDataTrain,coulombDataTrain,atomDataTrain);
size(XTrain)

ans = 88456 1

代码语言:javascript
复制
size(labelsTrain)

ans =

88456 1

代码语言:javascript
复制
[AValidation,XValidation,labelsValidation] = preprocessData(adjacencyDataValidation,coulombDataValidation,atomDataValidation);

使用训练特征的均值和方差对特征进行归一化。

代码语言:javascript
复制
muX = mean(XTrain);
sigsqX = var(XTrain,1);

XTrain = (XTrain - muX)./sqrt(sigsqX);
XValidation = (XValidation - muX)./sqrt(sigsqX);

步骤3:定义深度学习模型

(1)初始化模型参数

创建包含模型参数的结构参数。

代码语言:javascript
复制
parameters = struct;

初始化第一次乘法运算的权重。

代码语言:javascript
复制
numHiddenFeatureMaps = 32;
numInputFeatures = size(XTrain,2);

sz = [numInputFeatures numHiddenFeatureMaps];
numOut = numHiddenFeatureMaps;
numIn = numInputFeatures;
parameters.mult1.Weights = initializeGlorot(sz,numOut,numIn,"double");

初始化第二次乘法运算的权重。

代码语言:javascript
复制
sz = [numHiddenFeatureMaps numHiddenFeatureMaps];
numOut = numHiddenFeatureMaps;
numIn = numHiddenFeatureMaps;
parameters.mult2.Weights = initializeGlorot(sz,numOut,numIn,"double");

初始化第三次乘法运算的权重。

代码语言:javascript
复制
classes = categories(labelsTrain);
numClasses = numel(classes);

sz = [numHiddenFeatureMaps numClasses];
numOut = numClasses;
numIn = numHiddenFeatureMaps;
parameters.mult3.Weights = initializeGlorot(sz,numOut,numIn,"double");

查看参数结构。

代码语言:javascript
复制
parameters

parameters = 包含以下字段的 struct: mult1: [1×1 struct] mult2: [1×1 struct] mult3: [1×1 struct]

查看第一次乘法运算的参数。

代码语言:javascript
复制
parameters.mult1

ans = 包含以下字段的 struct: Weights: [1×32 dlarray]

(2)定义函数模型

定义模型损失函数,指定训练选项。训练1500次,并将Adam优化器的学习率设置为0.01。

代码语言:javascript
复制
numEpochs = 1500;
learnRate = 0.01;

每隔300次验证一次网络。

代码语言:javascript
复制
validationFrequency = 300;

(3)训练模型

初始化Adam的参数。

代码语言:javascript
复制
trailingAvg = [];
trailingAvgSq = [];

将训练和验证功能数据转换为dlarray对象。

代码语言:javascript
复制
XTrain = dlarray(XTrain);
XValidation = dlarray(XValidation);

如果GPU可用,则在GPU上进行训练。

代码语言:javascript
复制
if canUseGPU
XTrain = gpuArray(XTrain);
end

使用onehotencode函数将训练和验证标签转换为一个热编码向量。

代码语言:javascript
复制
TTrain = onehotencode(labelsTrain,2,ClassNames=classes);
TValidation = onehotencode(labelsValidation,2,ClassNames=classes);

初始化TrainingProgressMonitor对象。

代码语言:javascript
复制
monitor = trainingProgressMonitor( ...
    Metrics=["TrainingLoss","ValidationLoss"], ...
    Info="Epoch", ...
    XLabel="Epoch");

groupSubPlot(monitor,"Loss",["TrainingLoss","ValidationLoss"])

使用自定义训练循环训练模型。

代码语言:javascript
复制
epoch = 0;

while epoch < numEpochs && ~monitor.Stop
epoch = epoch + 1;

    % 评估模型损失函数和梯度下降
    [loss,gradients] = dlfeval(@modelLoss,parameters,XTrain,ATrain,TTrain);

    % 使用Adam优化器更新网络参数
    [parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients, ...
trailingAvg,trailingAvgSq,epoch,learnRate);

    % 训练损失和迭代次数
recordMetrics(monitor,epoch,TrainingLoss=loss);
updateInfo(monitor,Epoch=(epoch+" of "+numEpochs));

    % 验证指标
if epoch == 1 || mod(epoch,validationFrequency) == 0
YValidation = model(parameters,XValidation,AValidation);
lossValidation = crossentropy(YValidation,TValidation,DataFormat="BC");

    % 验证损失
recordMetrics(monitor,epoch,ValidationLoss=lossValidation);
end

monitor.Progress = 100*(epoch/numEpochs);
end

(4)测试模型

使用测试数据测试模型。

代码语言:javascript
复制
[ATest,XTest,labelsTest] = preprocessData(adjacencyDataTest,coulombDataTest,atomDataTest);
XTest = (XTest - muX)./sqrt(sigsqX);

将测试功能数据转换为dlarray对象。

代码语言:javascript
复制
XTest = dlarray(XTest);

对数据进行预测,并使用onehotdecode函数将概率转换为分类标签。

代码语言:javascript
复制
YTest = model(parameters,XTest,ATest);
YTest = onehotdecode(YTest,classes,2);

计算精度。

代码语言:javascript
复制
accuracy = mean(YTest == labelsTest)

accuracy = 0.9001

使用混淆图函数计算混淆矩阵。

代码语言:javascript
复制
figure
cm = confusionchart(labelsTest,YTest, ...
    ColumnSummary="column-normalized", ...
    RowSummary="row-normalized");
title("GCN QM7 Confusion Chart");

(5)使用新数据进行预测

对未标记的数据进行一些预测。

代码语言:javascript
复制
numObservationsNew = 4;
adjacencyDataNew = adjacencyDataTest(:,:,1:numObservationsNew);
coulombDataNew = coulombDataTest(:,:,1:numObservationsNew);

predictions = modelPredictions(parameters,coulombDataNew,adjacencyDataNew,muX,sigsqX,classes);

在绘图中可视化预测。

代码语言:javascript
复制
figure
tiledlayout("flow")

for i = 1:numObservationsNew
    % 提取未添加的数据
    numNodes = find(any(adjacencyDataTest(:,:,i)),1,"last");

    A = adjacencyDataTest(1:numNodes,1:numNodes,i);

    % 绘图
    nexttile
    G = graph(A);
    plot(G,NodeLabel=string(predictions{i}),Layout="force")
    title("Observation " + i + " Prediction")
end

(6)支持功能

预处理数据功能。

代码语言:javascript
复制
function [adjacency,features,labels] = preprocessData(adjacencyData,coulombData,atomData)

[adjacency, features] = preprocessPredictors(adjacencyData,coulombData);
labels = [];

% 将标签转换为类别标签
for i = 1:size(adjacencyData,3)
    % 提取并添加未添加的数据
    T = nonzeros(atomData(i,:));
    labels = [labels; T];
end

labels2 = nonzeros(atomData);
assert(isequal(labels2,labels2))

atomicNumbers = unique(labels);
atomNames =  atomicSymbol(atomicNumbers);
labels = categorical(labels, atomicNumbers, atomNames);

end

预处理预测器函数。

代码语言:javascript
复制
function [adjacency,features] = preprocessPredictors(adjacencyData,coulombData)

adjacency = sparse([]);
features = [];

for i = 1:size(adjacencyData, 3)
% 提取未添加的数据
numNodes = find(any(adjacencyData(:,:,i)),1,"last");

A = adjacencyData(1:numNodes,1:numNodes,i);
X = coulombData(1:numNodes,1:numNodes,i);

% 从库仑矩阵的对角线中提取特征向量
X = diag(X);

% 添加提取的数据
adjacency = blkdiag(adjacency,A);
features = [features; X];
end

end

函数模型。

代码语言:javascript
复制
function Y = model(parameters,X,A)

ANorm = normalizeAdjacency(A);

Z1 = X;

Z2 = ANorm * Z1 * parameters.mult1.Weights;
Z2 = relu(Z2) + Z1;

Z3 = ANorm * Z2 * parameters.mult2.Weights;
Z3 = relu(Z3) + Z2;

Z4 = ANorm * Z3 * parameters.mult3.Weights;
Y = softmax(Z4,DataFormat="BC");

end

模型损失函数。

代码语言:javascript
复制
function [loss,gradients] = modelLoss(parameters,X,A,T)

Y = model(parameters,X,A);
loss = crossentropy(Y,T,DataFormat="BC");
gradients = dlgradient(loss, parameters);

end

模型预测函数。

代码语言:javascript
复制
function predictions = modelPredictions(parameters,coulombData,adjacencyData,mu,sigsq,classes)

predictions = {};
numObservations = size(coulombData,3);

for i = 1:numObservations
% 提取未添加的数据
numNodes = find(any(adjacencyData(:,:,i)),1,"last");
A = adjacencyData(1:numNodes,1:numNodes,i);
X = coulombData(1:numNodes,1:numNodes,i);

% 预处理数据
[A,X] = preprocessPredictors(A,X);
X = (X - mu)./sqrt(sigsq);
X = dlarray(X);

% 预测
Y = model(parameters,X,A);
Y = onehotdecode(Y,classes,2);
predictions{end+1} = Y;
end

end

归一化邻接矩阵函数。

代码语言:javascript
复制
function ANorm = normalizeAdjacency(A)

% 将自连接添加到邻接矩阵
A = A + speye(size(A));

% 计算度数的平方倒数
degree = sum(A, 2);
degreeInvSqrt = sparse(sqrt(1./degree));

% 归一化邻接矩阵
ANorm = diag(degreeInvSqrt) * A * diag(degreeInvSqrt);

end

链接:http://quantum-machine.org/data/qm7.mat

https://ww2.mathworks.cn/help/deeplearning/ug/node-classification-using-graph-convolutional-network.html?s_tid=srchtitle_site_search_10_LORENZ

本文参与 腾讯云自媒体分享计划,分享自微信公众号。
原始发表:2024-04-01,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 图像处理与模式识别研究所 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体分享计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
相关产品与服务
腾讯云服务器利旧
云服务器(Cloud Virtual Machine,CVM)提供安全可靠的弹性计算服务。 您可以实时扩展或缩减计算资源,适应变化的业务需求,并只需按实际使用的资源计费。使用 CVM 可以极大降低您的软硬件采购成本,简化 IT 运维工作。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档