function select_lines
% SELECT_LINES  Greedy feature-selection demo on random line features.
%
%   Samples T random 2-D points, labels each point by whether it lies
%   within distance R of the origin, and builds N random binary
%   "which side of the line" features. K features are then picked one
%   at a time: each round takes the feature with the highest score s(n),
%   then lowers every score to min(s(n), I(Y;Xn|X_picked)) so that
%   features redundant with an already-picked one are penalised.
%
%   Figure 1 shows the points and all candidate lines (picked lines are
%   redrawn thicker, in a shade of red); figure 2 shows the per-round
%   best score and the mutual information between Y and the picked set.
%   Takes no arguments and returns nothing. Waits for a key press at the
%   end, then closes all figures.

% --- Experiment parameters ---
R = 0.5;     % radius of the disc that defines label y = 1
T = 5000;    % number of sample points
N = 500;     % number of candidate line features
K = 15;      % number of features to select

% --- Preallocation ---
s = zeros(1,N);      % current score of each candidate feature
cmi = zeros(K,N);    % cmi(k,n) = I(Y; Xn | X_nu(k))
nu = zeros(K,1);     % index of the feature picked in round k
smax = zeros(K,1);   % score of the feature picked in round k
mi = zeros(K,1);     % I(Y; picked features) after round k

% Seed the global random stream from the wall clock.
% RandStream.setDefaultStream was renamed to setGlobalStream and has been
% removed from current releases; try the current name first and fall back
% to the old one so the demo runs on both old and new MATLAB versions.
stream = RandStream('mt19937ar','seed',sum(100*clock));
try
    RandStream.setGlobalStream(stream);
catch
    RandStream.setDefaultStream(stream);
end

% --- Data and candidate features ---
x = generate_points(T);
y = generate_labels(x, R);
fl = generate_line_features(N);

% --- Figure layout: two windows stacked in the right third of the screen ---
scr = get(0,'ScreenSize');
scrX = scr(3)/3; scrY = scr(4)/2; scrW = scrX; scrH = scrY-50*2;
f1 = figure(1); set(f1,'Position', [2*scrX scrY scrW scrH]);
show_points(x,y); show_lines(fl);
f2 = figure(2); set(f2,'Position', [2*scrX 0 scrW scrH]);

% --- Feature matrix and initial scores ---
X = build_matrix(x,fl);     % X(i,n): value of feature n on point i
Hy = H1(y);                 % entropy of the labels
for n=1:N
    s(n) = mut_inf(y,X(:,n));   % initial score: I(Y; Xn)
end

% --- Greedy selection ---
for k=1:K
    % Pick the feature with the highest current score.
    [smax(k), nu(k)] = max(s);
    disp_scores(s, k, nu, smax(k));

    % Mutual information between Y and everything picked so far.
    mi(k) = calc_mi(X, nu, k, y, Hy, T); show_cmi(smax, mi, Hy);

    % Highlight the picked line on figure 1.
    color = calc_color(k, smax);
    tbest = fl(1,nu(k)); dbest = fl(2,nu(k));
    figure(1); draw_line_tri(tbest,dbest, color,2, -1,1,-1,1);

    % Lower each score to the information the feature still adds given
    % the feature that was just picked.
    for n=1:N
        cmi(k,n) = cond_mut_inf(y, X(:,n), X(:,nu(k)));
        s(n) = min( s(n), cmi(k,n) );
    end
    disp_cmi(k, cmi);
end

fprintf('\nPress any key to continue...\n');
pause; close all; clear all;
function x=generate_points(T)
% GENERATE_POINTS  Draw T random 2-D points.
%   x is T-by-2 (one point per row); each coordinate is a rand() sample
%   rescaled from the unit interval to the range -1..1.
unitSamples = rand(T,2);
x = 2*unitSamples - 1;
function y=generate_labels(x, radius)
% GENERATE_LABELS  Label points by distance from the origin.
%   x      : points, one per row (columns 1 and 2 are the coordinates)
%   radius : threshold distance
%   y      : logical column, true where the Euclidean distance of the
%            point from the origin is <= radius
distFromOrigin = sqrt(x(:,1).^2 + x(:,2).^2);
y = (distFromOrigin <= radius);
function f=generate_line_features(N)
% GENERATE_LINE_FEATURES  Draw N random line features.
%   f is 3-by-N; column n describes one feature:
%     f(1,n)  angle,  pi * rand
%     f(2,n)  offset, 2^0.5 * (2*rand - 1)
%     f(3,n)  sign of (2*rand - 1), i.e. which side of the line counts
% The three rows are drawn in this order (angle, offset, side) so the
% sequence of random draws is the same for a given stream state.
angles  = pi*rand(1,N);
offsets = 2^0.5*(2*rand(1,N)-1);
sides   = sign((2*rand(1,N)-1));
f = vertcat(angles, offsets, sides);
function show_points(x, y)
% SHOW_POINTS  Scatter the sample points in the current axes.
%   Points with y==1 are drawn as green dots, points with y==0 as black
%   dots. The axes are set to equal scaling over [-1,1] x [-1,1].
isOne  = (y==1);
isZero = (y==0);
plot(x(isOne,1), x(isOne,2), 'g.', x(isZero,1), x(isZero,2), 'k.');
axis equal;
axis([-1 1 -1 1]);
legend('y=1','y=0');
function show_lines(f)
% SHOW_LINES  Draw every candidate line as a thin blue line.
%   f holds one feature per column; rows 1 and 2 are passed to
%   draw_line_tri as the line's angle and offset, clipped to the box
%   [-1,1] x [-1,1].
for feature = f          % iterates over the columns of f
    draw_line_tri(feature(1), feature(2), 'b',1, -1,1,-1,1);
end
function show_cmi(smax, mi, Hy)
% SHOW_CMI  Refresh the two bar charts in figure 2.
%   smax : column of per-round best scores (top chart, red bars)
%   mi   : per-round mutual information values (bottom chart, blue bars)
%   Hy   : value drawn as a horizontal red reference line on the bottom
%          chart

% Reference line: constant height Hy from x = 0 to x = K + 0.5.
nRounds = size(smax,1);
refX = linspace(0, nRounds+0.5, 10);
refY = linspace(Hy, Hy, 10);

figure(2);

% Top panel: score of the feature picked in each round.
subplot(2,1,1);
bar(smax,'r');
xlabel('k');
ylabel('max_{Xn} min_{1<=k<=K} I(Y;Xn|X_{nu(k)})');
title('Conditional Mutual Information of each best feature Xk|1<=k<=K');

% Bottom panel: mutual information of the picked set, plus reference line.
subplot(2,1,2);
bar(mi,'b');
xlabel('k');
ylabel('I(Y;{Xk}1<=k<=K)');
title('Global Mutual Information I(Y;{Xk}1<=k<=K) evolution');
hold on;
plot(refX, refY, 'r');
hold off;
function X = build_matrix(x,f)
% BUILD_MATRIX  Evaluate every line feature on every point.
%   x : T-by-2 points, one per row (columns are the two coordinates)
%   f : 3-by-N features, one per column: [angle; offset; side]
%   X : T-by-N logical; X(i,j) is true when
%         sign(-x(i,1)*sin(angle_j) + x(i,2)*cos(angle_j) - offset_j)
%       equals side_j.
%
% Vectorised with bsxfun instead of a T*N scalar double loop. Each element
% is computed with the same operations in the same order as the scalar
% formula, so the result is unchanged.
px = x(:,1);                 % T-by-1
py = x(:,2);                 % T-by-1
sinT = sin(f(1,:));          % 1-by-N
cosT = cos(f(1,:));          % 1-by-N

% Signed position of every point relative to every line (T-by-N).
signedPos = bsxfun(@times, -px, sinT) + bsxfun(@times, py, cosT);
signedPos = bsxfun(@minus, signedPos, f(2,:));

X = bsxfun(@eq, sign(signedPos), f(3,:));
function disp_scores(s, k, nu, smax)
% DISP_SCORES  Print the score vector for round k and the chosen feature.
%   s    : scores, one per column
%   k    : round number
%   nu   : vector of picked feature indices (nu(k) is printed)
%   smax : score of the picked feature
% Output goes to standard output (file id 1): a dashed rule, the scores
% with three decimals, and a summary line.
nScores = size(s,2);

% Rule: 9 dashes plus 6 per score column, then a newline.
fprintf(1,'%s\n', ['---------', repmat('------',1,nScores)]);

fprintf(1,'s [%d] = ', k);
if nScores > 0
    % fprintf recycles the format over every element passed.
    fprintf(1,'%0.3f ', s(1:nScores));
end
fprintf(1,'\n');

fprintf(1,'-> max s(n) = %0.3f - nu(%d) = %d\n', smax, k, nu(k));
function disp_cmi(k, cmi)
% DISP_CMI  Print row k of the cmi matrix with three decimals.
%   Output goes to standard output (file id 1) on a single line.
nCols = size(cmi,2);
fprintf(1,'cmi[%d] = ', k);
if nCols > 0
    % fprintf recycles the format over every element passed.
    fprintf(1,'%0.3f ', cmi(k,1:nCols));
end
fprintf(1,'\n');
function card = card1(x,u)
% CARD1  Count the entries of x that equal u.
isMatch = (x == u);
card = sum(isMatch);
function card = card2(x,y,u,v)
% CARD2  Count the positions where x equals u and y equals v.
xMatch = (x == u);
yMatch = (y == v);
card = sum(xMatch & yMatch);
function card = card3(x,y,z,u,v,w)
% CARD3  Count the positions where x equals u, y equals v and z equals w.
xMatch = (x == u);
yMatch = (y == v);
zMatch = (z == w);
card = sum(xMatch & yMatch & zMatch);
function val = xi(x,T)
% XI  Entropy helper: x*log2(x)/T, defined as 0 when x is 0.
%   The special case avoids evaluating 0*log2(0), which would be NaN.
if x==0
    val = 0;
    return;
end
val = x*log2(x)/T;
function H = H1(Y)
% H1  Empirical entropy (bits) of a binary column vector Y.
%   With T = size(Y,1) samples and counts c_v of each value v in {0,1}:
%     H = log2(T) - sum_v c_v*log2(c_v)/T
nSamples = size(Y,1);
acc = 0;
for v = 0:1
    acc = acc + xi(card1(Y,v), nSamples);
end
H = log2(nSamples) - acc;
function H = H2(Y,Xn)
% H2  Empirical joint entropy (bits) of two binary column vectors.
%   Sums the xi term over the four value pairs (0,0), (0,1), (1,0), (1,1)
%   and subtracts the total from log2 of the sample count size(Y,1).
nSamples = size(Y,1);
acc = 0;
for a = 0:1
    for b = 0:1
        acc = acc + xi(card2(Y,Xn,a,b), nSamples);
    end
end
H = log2(nSamples) - acc;
function H = H3(Y,Xn,Xm)
% H3  Empirical joint entropy (bits) of three binary column vectors.
%   Sums the xi term over the eight value triples, from (0,0,0) to
%   (1,1,1) with the last variable varying fastest, and subtracts the
%   total from log2 of the sample count size(Y,1).
nSamples = size(Y,1);
acc = 0;
for a = 0:1
    for b = 0:1
        for c = 0:1
            acc = acc + xi(card3(Y,Xn,Xm,a,b,c), nSamples);
        end
    end
end
H = log2(nSamples) - acc;
function I = mut_inf(Y,Xn)
% MUT_INF  Empirical mutual information I(Y;Xn) in bits.
%   I = H(Y) + H(Xn) - H(Y,Xn)
hY     = H1(Y);
hX     = H1(Xn);
hJoint = H2(Y,Xn);
I = hY + hX - hJoint;
function I = cond_mut_inf(Y,Xn,Xm)
% COND_MUT_INF  Empirical conditional mutual information I(Y;Xn|Xm) in bits.
%   I = H(Y,Xm) - H(Xm) - H(Y,Xn,Xm) + H(Xn,Xm)
hYXm   = H2(Y,Xm);
hXm    = H1(Xm);
hYXnXm = H3(Y,Xn,Xm);
hXnXm  = H2(Xn,Xm);
I = hYXm - hXm - hYXnXm + hXnXm;
function MI = calc_mi(X, nu, k, y, Hy, T)
% CALC_MI  Mutual information between y and the first k picked features.
%   X  : feature matrix, one feature per column
%   nu : picked feature indices; only nu(1:k) are used
%   y  : label column
%   Hy : entropy of y, supplied by the caller
%   T  : divisor that turns pattern counts into probabilities
%   MI = Hy + H(picked features) - H(y, picked features)
picked = X(:,nu(1:k));

% Entropy of the distinct row patterns of the picked features.
[featPatterns, ~, featIdx] = unique(picked, 'rows');
featProb = histc(featIdx, 1:size(featPatterns,1))/T;
Hx = Hp(featProb);

% Entropy of the distinct (label, picked features) row patterns.
[jointPatterns, ~, jointIdx] = unique([y, picked], 'rows');
jointProb = histc(jointIdx, 1:size(jointPatterns,1))/T;
Hyx = Hp(jointProb);

MI = Hy + Hx - Hyx;
function H = Hp(P)
% HP  Entropy (bits) of a probability vector P.
%   Raises an error when the entries of P do not sum to 1 within
%   1000*eps. Entries that are not strictly between eps and 1-eps are
%   left out of the sum.
total = sum(P);
if abs(total-1) > 1000*eps
    error('Sum of probabilities not equal to 1: %0.4f', total);
end
inRange = (P > eps) & (P < 1-eps);
p = P(inRange);
H = sum(p.*log2(1./p));
function color = calc_color(k, smax)
% CALC_COLOR  RGB colour for the feature picked in round k.
%   k     : current round
%   smax  : per-round best scores; only smax(1:k) are used
%   color : [r 0 0], where r is the round-k score relative to the largest
%           score so far (1 = as strong as the best round).
%
% The red component is forced into [0,1]: if the scores sum to zero the
% ratio is NaN, and a score that is slightly negative from rounding gives
% a negative ratio. plot() rejects both as colour values.
spct = smax(1:k)/sum(smax(1:k));
red = spct/max(spct);
r = red(k);
if ~isfinite(r)
    r = 0;                  % degenerate scores (0/0): fall back to black
end
r = min(max(r, 0), 1);      % clamp into the valid RGB range
color = [r 0 0];
function draw_line_tri(t,d, color,linewidth, xm,xM,ym,yM)
% DRAW_LINE_TRI  Plot one line segment in the current axes.
%   t, d        : angle and offset of the line; in the general case the
%                 code evaluates y = d/cos(t) + x*tan(t) and its inverse
%                 x = y/tan(t) - d/sin(t)
%   color       : value passed to plot's 'color' property
%   linewidth   : value passed to plot's 'linewidth' property
%   xm,xM,ym,yM : x and y limits used to choose the segment end points
%
% The segment runs from x = xm to x = xM; an end point whose y value falls
% outside [ym, yM] is moved onto the y limit it crossed. The plot is added
% to the existing axes content (hold on ... hold off).
if (t == 0)
% Horizontal line: y = d across the full x range.
xstart = xm; xend = xM;
ystart = d; yend = d;
elseif (t == pi/2)
% Vertical line: x = -d across the full y range.
xstart = -d; xend = -d;
ystart = ym; yend = yM;
else
% General case. Start point: evaluate the line at x = xm ...
ystart = d/cos(t) + xm*tan(t);
if ystart < ym
% ... below the lower y limit: take the point where the line meets y = ym.
ystart = ym; xstart = ym/tan(t)-d/sin(t);
elseif ystart > yM
% ... above the upper y limit: take the point where the line meets y = yM.
ystart = yM; xstart = yM/tan(t)-d/sin(t);
else
xstart = xm;
end
% End point: same treatment at x = xM.
yend = d/cos(t) + xM*tan(t);
if yend > yM
yend = yM; xend = yM/tan(t)-d/sin(t);
elseif yend < ym
yend = ym; xend = ym/tan(t)-d/sin(t);
else
xend = xM;
end
end
% Sample the segment at 20 evenly spaced points and draw it on top of
% whatever the axes already contain.
x = linspace(xstart, xend, 20);
y = linspace(ystart, yend, 20);
hold on; plot(x,y,'color',color,'linewidth',linewidth); hold off;
00001000101111110111010100001011011111010 ... 100 0
00001010101111110111010100001011011011010 ... 000 0
10001010100101110010010000000110011001010 ... 100 1
......................................... ... ... .
10001010100101110011010000000110011001010 ... 110 0
H(Y) = 0.6930 bits
----------------------------------------- ... -----
s [1] = 0.038 0.037 0.012 0.021 0.000 ... 0.065
-> max s(n) = 0.092 - nu(1) = 122
cmi[1] = 0.068 0.086 0.000 0.000 0.000 ... 0.120
----------------------------------------- ... -----
...
----------------------------------------- ... -----
s [15] = 0.001 0.001 -0.000 -0.000 0.000 ... 0.003
-> max s(n) = 0.013 - nu(15) = 7
cmi[15] = 0.006 0.000 0.016 0.029 0.000 ... 0.010
Press any key to continue...