%%% Variational Bayes of Normal Mixture %%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Script overview: draw NNN samples from a 2-D normal mixture with K0
% components, fit a K-component isotropic normal mixture (fixed component
% standard deviation STD) by variational Bayes, print the free energy and
% mixture ratios, and plot samples / centers / densities.
clear; clf;
%close all
%------------------------------ model sizes ------------------------------
K0 = 3;          % number of true clusters (the sampler below builds exactly 3)
STDTRUE = 0.3;   % true standard deviation of each cluster
K = 3;           % number of components in the learning machine
STD = 0.3;       % component std in the learning machine (values tried: 0.1 ... 0.5)
%------------------------- sampling / VB settings ------------------------
NNN = 100;         % number of samples
KURIKAESHI = 100;  % number of VB iterations ("kurikaeshi" = repetition)
PRIORSIG = 0.01;   % prior pseudo-count on each component mean (added to ETA0);
                   % original note: 1/PRIORSIG = variance of the prior
PHI0 = 0.5;        % Dirichlet hyperparameter of the mixture ratio; the original
                   % note cites 3/2 as "Kazuho's critical point" -- presumably the
                   % VB phase-transition value for this model, confirm before relying on it
%-------------------------- true mixture ratios --------------------------
KP1 = 0.2;
KP2 = 0.3;
KP3 = 1-KP1-KP2;
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% make samples %%%%%%%%%%%%%%%%%%%%%%%%%%%
% The columns of X0 are the true cluster centers. truecase selects one of
% three configurations (case 3 puts all centers at the same point).
truecase=1;
switch truecase
  case 1
    X0=[0, 0, 1;0, 1, 1];
  case 2
    X0=[0, 0.0, 0.5;0, 0.5, 0.0];
  case 3
    X0=[0.5, 0.5, 0.5;0.5, 0.5, 0.5];
end
% Draw a label for each sample: 1 if YP<KP1, 3 if YP>KP1+KP2, otherwise 2.
% This relies on KP2>=0, which holds for the ratios set above.
YP=rand(1,NNN);
Y0=1+(YP>=KP1)+(YP>KP1+KP2);
% Each sample = its cluster center plus isotropic Gaussian noise (std STDTRUE).
XX=STDTRUE*randn(2,NNN)+X0(:,Y0);
%%%%%%%%%%%%%%%%%%%%%%%% make data end %%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Digamma function psi(x) for x>0, elementwise.
% The bare asymptotic series log(x)-1/(2x)-1/(12x^2) is poor for the small
% arguments VB reaches (PHI can approach PHI0=0.5). So first shift the
% argument with the recurrence psi(x) = psi(x+6) - sum_{j=0..5} 1/(x+j),
% then apply the asymptotic series up to the x^-6 term. The resulting
% absolute error is about 1e-9 for all x>0.
digamma=@(x)(log(x+6)-0.5./(x+6)-1./(12*(x+6).^2)+1./(120*(x+6).^4) ...
  -1./(252*(x+6).^6)-1./x-1./(x+1)-1./(x+2)-1./(x+3)-1./(x+4)-1./(x+5));
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%% Initialize VB
% Equal share NNN/K for every component. Each initial center is the
% per-coordinate sample mean plus a small N(0,0.1^2) jitter. The first
% coordinate's jitter is drawn before the second's.
NPERK=NNN/K;
PHI=repmat(NPERK,1,K);
ETA0=PHI;
ETA1=NPERK*(mean(XX(1,:))+0.1*randn(1,K));
ETA2=NPERK*(mean(XX(2,:))+0.1*randn(1,K));
% Preallocate responsibilities (K x NNN) and mixture ratios (1 x K).
YYY=zeros(K,NNN);
MR=zeros(1,K);
%%%%%%%%%% Recursive VB Start
% Alternate the VB E-step (responsibilities YYY) and the M-step (posterior
% hyperparameters PHI, ETA0, ETA1, ETA2) for KURIKAESHI sweeps.
for kuri=1:1:KURIKAESHI
  % E-step: DDD(k) = E[log a_k] - E[|x_i-b_k|^2]/(2*STD^2), up to a constant.
  % E[log a_k] = psi(PHI_k) - psi(sum(PHI)). Using sum(PHI) replaces the old
  % hard-coded NNN+3*PHI0, which was only right for K=3. The term is the same
  % for every k, so it cancels in the normalisation.
  PSUM=digamma(sum(PHI));
  for i=1:1:NNN
    DD1=ETA1./ETA0-XX(1,i);
    DD2=ETA2./ETA0-XX(2,i);
    % -1./ETA0: posterior uncertainty of the mean (2 coords x STD^2/ETA0),
    % scaled by 1/(2*STD^2).
    DDD=digamma(PHI)-PSUM-1./ETA0-(DD1.*DD1+DD2.*DD2)/(2*STD*STD);
    WWW=exp(DDD-max(DDD));      % subtract max so exp cannot overflow
    YYY(:,i)=WWW/sum(WWW);
  end
  % M-step: sufficient statistics of the responsibilities.
  NK=sum(YYY,2)';               % soft counts, 1 x K
  PHI=PHI0+NK;
  ETA0=PRIORSIG+NK;
  ETA1=(YYY*XX(1,:)')';         % responsibility-weighted coordinate sums, 1 x K
  ETA2=(YYY*XX(2,:)')';
end
%%%%%%%%%%%%%%%%%Free Energy
% VB free energy (negative evidence lower bound) at the final iterate:
%   F = -sum_k gammaln(PHI_k)
%       + sum_k [log ETA0_k - |ETA_k|^2/(2 STD^2 ETA0_k)]
%       + sum_i [|x_i|^2/(2 STD^2) + log STD^2]
%       + sum_{k,i} y_ki log y_ki
% Some additive constants are omitted: the Dirichlet normalisers (these
% depend on K, PHI0, NNN), -K*log(PRIORSIG), and NNN*log(2*pi). None depend
% on STD, so values compare across STD but not across K.
FF1=-sum(gammaln(PHI(:)));
FF2=sum(log(ETA0)-(ETA1.*ETA1+ETA2.*ETA2)./(2*STD*STD*ETA0));
FF3=sum((XX(1,:).*XX(1,:)+XX(2,:).*XX(2,:))/(2*STD*STD)+log(STD*STD));
% SSS = entropy -sum y log y of q(Y). Entries with y==0 are skipped
% (0*log 0 := 0); otherwise an underflowed responsibility would give NaN.
YPOS=YYY(YYY>0);
SSS = -sum(YPOS.*log(YPOS));
% The bound contains +sum y log y, i.e. the entropy is SUBTRACTED.
FreeEnergy = FF1+FF2+FF3-SSS;
%%%%%%%%%%%%%%%%%
% Estimated mixture ratios: ETA0 normalised by NNN+K*PRIORSIG. After an
% M-step, sum(ETA0) equals this denominator, so MR sums to 1.
MR=ETA0/(NNN+PRIORSIG*K);
% Posterior-mean component centers, used by the plots below.
Y01=ETA1./ETA0;
Y02=ETA2./ETA0;
% fprintf reuses the '%.2f ' format for every element of MR.
fprintf('Free Energy=%.2f, Mixture Ratio=(',FreeEnergy);
fprintf('%.2f ',MR);
fprintf(')\n');
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Isotropic 2-D Gaussian density at (x,y) with mean (a,b). VA2 = 2*variance.
% The normaliser of a 2-D isotropic Gaussian is 2*pi*variance = pi*VA2.
% The old sqrt(2*pi*VA2) did not integrate to 1. Works elementwise on grids.
probgauss=@(x,y,a,b,VA2)(exp(-((x-a).^2+(y-b).^2)/VA2)/(pi*VA2));
%%%%%%%%%%%%%%%%%%%%%%% plot samples %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Panel 1: scatter of the generated samples on the window [-1,2] x [-1,2].
subplot(2,2,1);
plot(XX(1,:),XX(2,:),'bo');
hold on
title('Samples');
xlim([-1,2]); ylim([-1,2]);
%%%%%%%%%%%%%%%%%%%% plot true and estimated %%%%%%%%%%%%%
% Panel 2: true centers (red squares) vs posterior-mean centers (blue +).
% One plot call per set; marker-only styles draw no connecting lines.
subplot(2,2,2);
plot(X0(1,1:K0),X0(2,1:K0),'rs'); hold on
plot(Y01(1:K),Y02(1:K),'b+');
title('True:Red Squares, Estimated:Blue +');
xlim([-1,2]);
ylim([-1,2]);
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Panel 3: true mixture density on a 0.1-spaced grid over [-1,2]^2.
% The sum starts from zero and adds components 1,2,3 in that order.
subplot(2,2,3);
va2=2*STDTRUE*STDTRUE;
[x1,y1]=meshgrid(-1:0.1:2,-1:0.1:2);
KPTRUE=[KP1,KP2,KP3];
zzz=0*x1;
for jj=1:3
  zzz=zzz+KPTRUE(jj)*probgauss(x1,y1,X0(1,jj),X0(2,jj),va2);
end
mesh(x1,y1,zzz);
zlim([0,1]);
title('True Probability Density Function'); hold on
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Panel 4: estimated mixture density. It uses the ratios MR, the centers
% (Y01,Y02) and the learning machine's component std STD.
subplot(2,2,4);
va2=2*STD*STD;
[x1,y1]=meshgrid(-1:0.1:2,-1:0.1:2);
zzz=zeros(size(x1));
kk=0;
while kk<K
  kk=kk+1;
  zzz=zzz+MR(kk)*probgauss(x1,y1,Y01(kk),Y02(kk),va2);
end
mesh(x1,y1,zzz);
zlim([0,1]);
title('Estimated Probability Density Function'); hold on
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%