首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >深入研究一个简单的感知器模式和神经网络。

深入研究一个简单的感知器模式和神经网络。

作者头像
裴来凡
发布2022-05-28 15:13:29
发布2022-05-28 15:13:29
2910
举报

batchPerceptron.m

代码语言:matlab
复制
function [w, y, err] = batchPerceptron(x, t, epochs, r)
% BATCHPERCEPTRON Train a single-layer perceptron with batch (per-epoch) updates.
%   [w, y, err] = batchPerceptron(x, t, epochs, r)
%
%   x      - m-by-n matrix, one sample per row
%   t      - m targets, expected to be +1 or -1 (row or column vector)
%   epochs - maximum number of passes over the data
%   r      - learning rate (step size of each batch update)
%
%   w      - (n+1)-by-1 weight vector; the last entry is the bias
%   y      - m-by-1 predictions (+1 / -1) of the returned w, in the row order of x
%   err    - fraction of samples misclassified by the returned w
%
%   A sample is predicted +1 when [x(i,:) 1] * w > 0 and -1 otherwise.
%   Weights start uniformly random in [-1, 1], so results vary between calls
%   unless the random number generator is seeded.

[m, n] = size(x);
if numel(t) ~= m
    error('batchPerceptron:sizeMismatch', ...
        'x has %d rows but t has %d elements.', m, numel(t));
end
t = t(:);
X = [x ones(m, 1)];  % augment every sample with a constant 1 for the bias

w = rand(n + 1, 1) * 2 - 1;

for epoch = 1:epochs
    y = 2 * (X * w > 0) - 1;   % +1 where the response is positive, -1 otherwise
    wrong = (y ~= t);
    if ~any(wrong)
        break;  % every sample is classified correctly, so the batch update is zero
    end
    % One step along the sum of t(i) * [x(i,:) 1] over all misclassified
    % samples. That sum does not depend on sample order, so no shuffling.
    w = w + r * (X(wrong, :)' * t(wrong));
end

% Report predictions and error for the weights that are actually returned,
% in the caller's original sample order.
y = 2 * (X * w > 0) - 1;
err = nnz(y ~= t) / m;

end

plotSet.m

代码语言:matlab
复制
function [] = plotSet(x,t,w)
% PLOTSET Scatter a two-feature labelled data set and draw a perceptron's decision boundary.
%   plotSet(x, t, w)
%
%   x - m-by-2 (or wider) sample matrix; columns 1 and 2 are plotted
%   t - m labels; +1 samples are drawn as blue '+', -1 samples as red 'o'
%   w - weights [w1; w2; bias]; the boundary is w1*x1 + w2*x2 + bias = 0
%
%   The boundary is drawn in green over at least [-1, 1] on the first axis,
%   widened to cover the data. A vertical boundary (w2 == 0) is drawn across
%   the current y-limits. The caller's hold state is restored afterwards.

wasHeld = ishold;
hold on

plot(x(t == 1, 1), x(t == 1, 2), 'b+');
plot(x(t == -1, 1), x(t == -1, 2), 'ro');

x1 = [min([-1; x(:, 1)]), max([1; x(:, 1)])];
if w(2) ~= 0
    % Solve w1*x1 + w2*x2 + bias = 0 for x2.
    x2 = - (w(3) + x1 * w(1)) / w(2);
    plot(x1, x2, 'g');
elseif w(1) ~= 0
    % Vertical boundary: w1*x1 + bias = 0.
    xv = -w(3) / w(1);
    plot([xv xv], ylim, 'g');
end
% When w1 and w2 are both zero there is no boundary to draw.

if ~wasHeld
    hold off
end

end

run.m

代码语言:matlab
复制
% Script: train the batch perceptron on the saved data set and draw the result.
clear;
load('x.mat');   % expected to define the sample matrix x
load('t.mat');   % expected to define the target vector t

[w, y, err] = batchPerceptron(x, t, 1000, 0.5);   % 1000 epochs, learning rate 0.5

clf;             % start from an empty figure before plotting
plotSet(x, t, w);

trainPerceptron.m

代码语言:matlab
复制
function [w, y, err] = trainPerceptron(x, t, epochs, r)
% TRAINPERCEPTRON Train a single-layer perceptron with online (per-sample) updates.
%   [w, y, err] = trainPerceptron(x, t, epochs)
%   [w, y, err] = trainPerceptron(x, t, epochs, r)
%
%   x      - m-by-n matrix, one sample per row
%   t      - m targets, expected to be +1 or -1 (row or column vector)
%   epochs - maximum number of passes over the data
%   r      - optional learning rate (default 1, i.e. w = w + t(i) * xi')
%
%   w      - (n+1)-by-1 weight vector; the last entry is the bias
%   y      - m-by-1 predictions (+1 / -1) of the returned w, in the row order of x
%   err    - fraction of samples misclassified by the returned w
%
%   Samples are visited in a new random order every epoch, and the weights
%   are corrected immediately after each misclassified sample. Training stops
%   early once a full pass makes no mistakes. Weights start uniformly random
%   in [-1, 1], so results vary between calls unless the generator is seeded.

if nargin < 4
    r = 1;
end

[m, n] = size(x);
if numel(t) ~= m
    error('trainPerceptron:sizeMismatch', ...
        'x has %d rows but t has %d elements.', m, numel(t));
end
t = t(:);
X = [x ones(m, 1)];  % augment every sample with a constant 1 for the bias

w = rand(n + 1, 1) * 2 - 1;

for epoch = 1:epochs
    mistakes = 0;
    for i = randperm(m)   % shuffle indices only; x and t keep the caller's order
        xi = X(i, :);
        if xi * w > 0
            yi = 1;
        else
            yi = -1;
        end
        if yi ~= t(i)
            w = w + r * t(i) * xi';   % online learning: correct w right away
            mistakes = mistakes + 1;
        end
    end
    if mistakes == 0
        break;  % converged: a whole pass needed no correction
    end
end

% Report predictions and error for the weights that are actually returned.
y = 2 * (X * w > 0) - 1;
err = nnz(y ~= t) / m;

end

run.m

代码语言:matlab
复制
% Script: train an online perceptron on the saved data set and draw the
% samples together with the learned decision boundary.
clear;

load('t.mat');   % expected to define the target vector t (+1 / -1 labels)
load('x.mat');   % expected to define the sample matrix x (features in columns 1 and 2)

hold('on');

% Positive samples as blue '+', negative samples as red 'o'.
plot(x(t==1,1), x(t==1,2), 'b+');
plot(x(t==-1,1), x(t==-1,2), 'ro');

[w, ~, ~] = trainPerceptron(x, t, 100);   % 100 epochs; only the weights are needed

% Decision boundary w(1)*x1 + w(2)*x2 + w(3) = 0, solved for x2 at x1 = -1, 0, 1.
x1 = -1:1:1;
x2 = -(w(1) * x1 + w(3)) / w(2);
plot(x1, x2, 'g');

hold('off');

run.m

代码语言:matlab
复制


%% LOAD DATASET
clear;
% Silence every warning for the rest of the session; presumably meant to hide
% toolbox deprecation notices (e.g. from newff) - TODO confirm.
warning('off','all');
fileID = fopen('tic-tac-toe.data');
if fileID < 0
    error('Could not open tic-tac-toe.data');
end
board = textscan(fileID,'%s\n');   % read the file as whitespace-delimited strings
fclose(fileID);                    % release the file handle once it has been read

%% TRANSFORM DATASET TO A VALID ONE
% board2mat (defined elsewhere) builds the inputs P and targets T; the code
% that follows treats each column of P / T as one sample.
[P, T] = board2mat(board);
clearvars -except P T;

%% FINDING THE BEST ALGORITHM
% trainings = {'traingd' 'traingdm' 'traincgf' 'traincgp' 'traingda' 'trainrp' 'traingdm' 'traingdx' 'trainbfg' 'trainoss' 'trainlm'};
% colors = {'red', 'green', 'blue', 'yellow', [0.8 0.4 0.6], [0.5 0.5 0.5], 'magenta', 'cyan', 'black', [0.1 0.6 1], [0.3 0.6 0.3]};
%                                            %a dark pink      %grey                              %kind of dark blue  %dark green
% [R,V] = size(P); 
% [Q,~] = size(T);
% 
% %clf;
% for i = 1:length(trainings)
%     net = newff(minmax(P),[R Q],{'tansig' 'logsig'}, trainings{i});
%     net.performFcn = 'sse';
%     
%     %% Method I - efficiency of every algorithm in a number of given epochs
% %     net.trainParam.epochs = 100;
% %     [net, tr] = train(net, P, T);
% %     hold on;
% %     plot(tr.epoch, tr.perf, 'Color', colors{i});
% %     hold off;
%     
%     %% Method II - execution time of every algorithm until reaches the goal
%     net.trainParam.epochs = 10000;
%     net.trainParam.goal = 0.1;
%     [net, tr] = train(net, P, T);
%     %fprintf('Time %s: %s; SSE: %g\n', trainings{i}, datestr(tr.time(end) / 86400, 'MM:SS.FFF'), tr.perf(end));
%       % OR
%     fprintf('Epoci %s: %g; SSE: %g\n', trainings{i}, tr.epoch(end), tr.perf(end));
% end



%% TRAIN NETWORK (with Levenberg-Marquardt)
% Shuffle the samples (columns of P / T) so the odd/even split below is random.
p = randperm(size(P,2));
P = P(:,p);
T = T(:,p);
Q = size(P, 2);

% Odd-numbered columns are held out for testing, even-numbered ones train the net.
idx_testing = 1:2:Q;
idx_training = 2:2:Q;

training.P = P(:, idx_training);
training.T = T(:, idx_training);
testing.P = P(:, idx_testing);
testing.T = T(:, idx_testing);

% Feed-forward net: layers of 10 (tansig), 5 (logsig) and size(T,1) (logsig) neurons.
net = newff(minmax(P), [10 5 size(T,1)], {'tansig' 'logsig' 'logsig'}, 'trainlm');
net.performFcn = 'sse';
net.trainParam.epochs = 100;
net.trainParam.mu = 10;

% Train on the training half, then measure the error on the held-out half.
net = train(net, training.P, training.T);
Y = round(sim(net, testing.P));   % threshold the logsig outputs to 0/1
% Fraction of wrong outputs (equals the per-sample error for a single-output net).
E = nnz(testing.T ~= Y) / numel(Y);
fprintf('Test error: %g\n', E);

% Train again on the whole data set, starting from the current weights.
net = train(net, P, T);
% Y = sim(net, P);
% Y = round(Y);
% E = [E numel(find((T ~= Y) > 0)) / length(Y)];

% mean(E)
%% SET UP THE GAME
% The board is a 3x3 matrix addressed by linear (column-major) index 1..9:
% 1 = X (chosen by the network), -1 = O (entered by the user), 0 = empty.
% hasWon(b, p) is true when player p fills a whole column, row or diagonal of b.
hasWon = @(b, p) any(all(b == p, 1)) || any(all(b == p, 2)) || ...
    all(diag(b) == p) || all(diag(fliplr(b)) == p);

board = zeros(3);
board(randi(9)) = 1; % X opens on a uniformly random square
mat2board(board)

%% PLAY X AND O

while true
    %% READ O POSITION
    % Ask until the user enters a free square 1..9, or 'stop' to end the game.
    stopGame = false;
    while true
        result = input('Write O position:\n', 's');
        if strcmp(result, 'stop')
            stopGame = true;
            break;
        end
        if ~isempty(result) && all(isstrprop(result, 'digit'))
            poz = str2double(result); % extract position
            if poz >= 1 && poz <= 9 && board(poz) == 0
                break;
            end
        end
    end
    if stopGame
        break;
    end

    board(poz) = -1; % set O there

    %% Find if O wins
    % Checked directly on the board: the network only scores boards for X,
    % so it cannot be relied on to detect an O victory.
    if hasWon(board, -1)
        fprintf('\n*** O wins ***\n');
        mat2board(board)
        break;
    end

    %% Find best position for putting X
    % Take an immediately winning square if there is one; otherwise pick the
    % free square whose resulting board the network scores highest.
    I = 0;
    bestScore = 0;
    % X moved first and a draw is checked after every X move, so after O's
    % move at least one square is still free.
    for i = find(board == 0)'
        board(i) = 1; % try X there
        if hasWon(board, 1)
            board(i) = 0;
            I = i;
            break;
        end
        y = sim(net, board(:)); % network score for the resulting board
        if y > bestScore
            bestScore = y;
            I = i;
        end
        board(i) = 0;
    end

    if I == 0 % no square scored above 0: play a random free square
        freeSquares = find(board == 0);
        I = freeSquares(randi(numel(freeSquares)));
    end

    board(I) = 1; % put X there

    if hasWon(board, 1)
        fprintf('\n*** X wins ***\n');
        mat2board(board)
        break;
    end

    %% Show Result
    mat2board(board)

    if all(board(:) ~= 0)
        fprintf('\n*** Draw ***\n');
        break;
    end
end
本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2020-05-07,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 图像处理与模式识别研究所 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档