%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%% Let's compare PSISCV, ISCV, and WAIC in the presence of an influential observation
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%% Pareto Smoothed Importance Sampling Cross Validation
%%% or Vehtari-Gelman Importance Sampling Cross Validation
%%% is compared with ISCV and WAIC.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%% Program was made by Sumio Watanabe
%%% ----------------------------------------------------------
%%% PSISCV paper:
%%% Efficient implementation of leave-one-out cross-validation
%%% and WAIC for evaluating fitted Bayesian models.
%%% by Aki Vehtari, Andrew Gelman, Jonah Gabry
%%% http://arxiv.org/abs/1507.04544
%%% ----------------------------------------------------------
%%%
%%% This program calculates PSISCV proposed by Vehtari, Gelman, and Gabry.
%%%
%%% Input X: p(x)=(1-1/n) N(0,1^2) + (1/n) N(0,Leverage^2)
%%%
%%% Regression : Y = ax + N(0,1/s0)
%%%
%%% If Leverage=1, then no leverage sample is contained.
%%% If Leverage=10, then the n-th sample is a leverage sample.
%%% If Leverage=100, then the n-th sample is an extreme leverage sample.
%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
clear;
close all hidden;
%%% ------------------- Experiment configuration -------------------
n        = 30;     %%% number of training samples
ntest    = 50*n;   %%% number of test samples
Leverage = 10;     %%% input standard deviation of the influential observation
%%% Exactly one training input is drawn as a leverage sample.
a0       = 0.1;    %%% true regression coefficient
s0       = 100;    %%% true noise precision (noise variance is 1/s0)
mu       = 0.01;   %%% hyperparameter of the prior
KKK      = 2000;   %%% number of posterior samples
CYCLE    = 1000;   %%% number of independent trials
%%% ------------------- Input distributions ------------------------
%%% Training inputs:
%%%   (n-1) inputs are drawn from N(0,1^2),
%%%   one input is drawn from N(0,Leverage^2).
%%% Test inputs:
%%%   ntest*(1-1/n) inputs are drawn from N(0,1^2),
%%%   ntest/n inputs are drawn from N(0,Leverage^2).
%%% ------------------- Model and prior ----------------------------
%%%   p(y|x,a,s) = (s/(2pi))^{0.5} exp(-(s/2)(y-ax)^2)
%%%   phi(a,s)   = s exp( -(mu/2) s (1+a^2) )
%%% Density of the statistical model, evaluated elementwise in y, a and s.
prob = @(x,y,a,s) exp(-(s/2).*(y-a*x).^2).*(s/(2*pi)).^0.5;
%%% Fixed seed so that every run reproduces the same experiment.
rng(1);
%%% ------------------- Test set for the generalization error ------
nlev = ntest/n;                              %%% number of leverage test inputs
xtest = randn(1,ntest);                      %%% ordinary inputs ...
xtest(1:nlev) = Leverage*randn(1,nlev);      %%% ... the first nlev become leverage inputs
ytest = a0*xtest + (1/s0)^0.5*randn(1,ntest);
%%% Empirical entropy of the true distribution on the test set.
Stest = -mean(log(prob(xtest,ytest,a0,s0)));
%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%% Monte Carlo experiment %%%%%%%%%%%%%%%%%%%%%
%%% For every independent trial:
%%%   1. draw a training set whose n-th input is the leverage sample,
%%%   2. draw KKK samples (ap,sp) from the posterior,
%%%   3. evaluate WAIC, ISCV and PSISCV on the training set,
%%%   4. evaluate the generalization loss on the fixed test set.
%%% WAIC, ISCV, PSIS and GE store the losses after subtracting the empirical
%%% entropy of the true distribution (Sn for training, Stest for test).
%%%%%%%%%%% Preallocation (avoids growing arrays inside the loops)
WAIC=zeros(1,CYCLE);
ISCV=zeros(1,CYCLE);
PSIS=zeros(1,CYCLE);
GE=zeros(1,CYCLE);
waic_each=zeros(1,n);
iscv_each=zeros(1,n);
psis_each=zeros(1,n);
ge_each=zeros(1,ntest);
%%%%%%%%%%% Loop-invariant quantities of the Pareto smoothing
%%% Integer indices are used so that any KKK works, not only values for
%%% which 0.8*KKK happens to be an exact integer in floating point.
ntail=round(0.2*KKK);        %%% number of smoothed weights (largest 20%)
ncut=KKK-ntail;              %%% index of the threshold weight (0.8KKK location)
yyy=((1:ntail)-0.5)/ntail;   %%% midpoint probabilities in (0,1), exactly ntail values
for kur=1:1:CYCLE
    %%%%%%%%%%%%%%%%% Training sample generation
    x=randn(1,n);
    x(n)=Leverage*randn(1,1);   %%% the n-th input is the leverage sample
    y=a0*x+(1/s0)^0.5*randn(1,n);
    Sn=-mean(log(prob(x,y,a0,s0)));   %%% empirical entropy on the training set
    %%%%%%%%%%%%%%%%%%%% constant values of the posterior
    cc1=mu+sum(x.^2);
    mm1=sum(x.*y)/cc1;
    aa1=0.5*(mu+sum(y.^2)-mm1^2*cc1);
    %%%%%%%%%%% Posterior Sampling %%%%%%%%%%%%%%%%
    %%% sp ~ Gamma(shape (n+3)/2, scale 1/aa1),  ap|sp ~ N(mm1, 1/(sp*cc1))
    sp=gamrnd((n+3)/2,1/aa1,1,KKK);
    ap=mm1+(1./(sp*cc1).^0.5).*randn(1,KKK);
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    %%% Remark: Gamma distribution in MATLAB ---------
    %%% X=gamrnd(a,b,1,n);
    %%% gamrnd(a,b) --- f(x)=x^{a-1}exp(-x/b)/(b^a*Gamma(a))
    %%% ----------------------------------------------
    for i=1:1:n
        fff=prob(x(i),y(i),ap,sp);   %%% likelihood of sample i for every posterior draw
        waic_each(i)=-log(mean(fff))+var(log(fff));
        iscv_each(i)=log(mean(1./fff));
        %%%%% hhh is the improved weight by Vehtari-Gelman-Gabry. %%%%%%%%%
        invf=sort(1./fff);           %%% original importance weights, sorted ascending
        uu=invf(ncut);               %%% threshold weight
        ggg=invf(ncut+1:KKK)-uu;     %%% tail exceedances in [0,infty)
        para=gpfit(ggg);             %%% generalized Pareto fit: para(1)=shape, para(2)=scale
        fff2=invf;                   %%% improved importance weights
        %%% The tail is replaced by the quantiles of the fitted distribution.
        %%% gpinv also covers a fitted shape close to zero, where the explicit
        %%% formula (scale/shape)*((1-p)^(-shape)-1) divides by (almost) zero.
        fff2(ncut+1:KKK)=gpinv(yyy,para(1),para(2),uu);
        trunc=KKK^0.75*mean(fff2);   %%% VGG upper bound
        hhh=min(fff2,trunc);         %%% truncation of weights
        %%% invf is sorted like hhh, so hhh./invf is (weight)*(likelihood).
        psis_each(i)=-log(mean(hhh./invf)/mean(hhh));   %%% PSISCV
    end
    for j=1:1:ntest
        ftest=prob(xtest(j),ytest(j),ap,sp);
        ge_each(j)=-log(mean(ftest));   %%% predictive loss of test sample j
    end
    WAIC(kur)=mean(waic_each)-Sn;
    ISCV(kur)=mean(iscv_each)-Sn;
    PSIS(kur)=mean(psis_each)-Sn;
    GE(kur)=mean(ge_each)-Stest;
    %%% fprintf('[%2g]WA=%f, CV=%f, PSIS=%f, GE=%f\n',kur,WAIC(kur),ISCV(kur),PSIS(kur),GE(kur));
    if(mod(kur,10)==0)
        fprintf('%g\n',kur);   %%% progress report every 10 trials
    end
end
%%% ------------------- Summary over all trials -------------------
%%% Deviations of each criterion from the generalization error.
err_waic=GE-WAIC;
err_iscv=ISCV-GE;
err_psis=PSIS-GE;
fprintf('Leverage=%f\n',Leverage);
%%% Mean and standard deviation of every criterion.
fprintf('WAIC(mean,std) = %f,%f\n',mean(WAIC),std(WAIC));
fprintf('ISCV(mean,std) = %f,%f\n',mean(ISCV),std(ISCV));
fprintf('PSIS(mean,std) = %f,%f\n',mean(PSIS),std(PSIS));
fprintf('GEN (mean,std) = %f,%f\n',mean(GE),std(GE));
%%% Mean absolute deviation from the generalization error.
fprintf('E|GE-WAIC|=%f, E|GE-ISCV|=%f, E|GE-PSIS|=%f\n',...
    mean(abs(err_waic)),mean(abs(err_iscv)),mean(abs(err_psis)));
%%% Mean squared deviation from the generalization error.
fprintf('E(GE-WAIC)^2=%f, E(GE-ISCV)^2=%f, E(GE-PSIS)^2=%f\n',...
    mean(err_waic.^2),mean(err_iscv.^2),mean(err_psis.^2));
count=0;
for kur=1:1:CYCLE
if(abs(PSIS(kur)-GE(kur))