% CMIM - Select Lines - 02/03/12
% Jean-Marc Berthommé
%
% - 01/05/11:
%   . file created to illustrate François Fleuret's paper (2004):
%     "Fast Binary Feature Selection with Conditional Mutual Information"
%   . sample of points inside or outside a circle on which we want to
%     select discriminant lines
% - 01/06/11:
%   . naive algorithm final implementation
% - 01/10/11:
%   . data display improved
%   . selected lines color goes from flashy red (1st) to dark black (Kth)
%     in reference to the simulated annealing
% - 01/28/11:
%   . entropies expressed in bits
%   . H1(y) drawn in red on figure(2) to show the learning limit
% - 02/03/12:
%   . mutual information between the labels and the selected features added

%****************************  Main Loop  ********************************%
function select_lines
% SELECT_LINES  Demo of CMIM feature selection on random line features.
%
% Draws T random 2D points labelled by membership in a disc of radius R,
% generates N random oriented lines (binary half-plane features) and
% greedily selects K of them. At each step the feature with the highest
% current score is picked; every score is then lowered to the minimum of
% its previous value and the conditional mutual information I(Y;Xn|Xnu(k))
% given the feature just selected. Scores are printed to the console and
% plotted in figures 1 (points and lines) and 2 (score bars).
R = 0.5;  % radius of the positive points
T = 5000; % number of points - 20  - 500 - 10000
N =  500; % number of features - 8 -  50 -  1000
K =   15; % number of features to select - 4-100

% K features to select
s    = zeros(1,N); % scores up to date
cmi  = zeros(K,N); % conditional mutual informations
nu   = zeros(K,1); % index  of the selected features
smax = zeros(K,1); % values of the selected features
mi   = zeros(K,1); % global mutual informations y <~> Xnu

% Generate data
% Seed the global generator from the clock. rng replaces the former
% RandStream.setDefaultStream call, which recent MATLAB releases reject.
rng('shuffle');
x  = generate_points(T);
y  = generate_labels(x, R);
fl = generate_line_features(N);

% Screen size - Show crude data
scr  = get(0,'ScreenSize');
scrX = scr(3)/3; scrY = scr(4)/2; scrW = scrX; scrH = scrY-50*2;
f1 = figure(1); set(f1,'Position', [2*scrX scrY scrW scrH]);
show_points(x,y); show_lines(fl);
f2 = figure(2); set(f2,'Position', [2*scrX   0  scrW scrH]);

% Boolean matrix of the points belonging or not to the features: TxN
X = build_matrix(x,fl);
Hy = H1(y); %disp_data(X, y, Hy);

% Initial score of each feature: its mutual information with the labels
for n=1:N
    s(n) = mut_inf(y,X(:,n));
end

for k=1:K
    % Select best informative feature
    [smax(k), nu(k)] = max(s);
    disp_scores(s, k, nu, smax(k));
    mi(k) = calc_mi(X, nu, k, y, Hy, T); show_cmi(smax, mi, Hy);

    % Draw the best selected feature
    color = calc_color(k, smax);
    tbest = fl(1,nu(k)); dbest = fl(2,nu(k));
    figure(1); draw_line_tri(tbest,dbest, color,2, -1,1,-1,1);

    % Update every score with the information left once nu(k) is known
    for n=1:N
        cmi(k,n) = cond_mut_inf(y, X(:,n), X(:,nu(k)));
        s(n) = min( s(n), cmi(k,n) );
    end
    disp_cmi(k, cmi);
end

fprintf('\nPress any key to continue...\n');
pause; close all; clear all;

%****************************  Functions  ********************************%
function x=generate_points(T)
% Draw T points uniformly in the open square ]-1;1[ x ]-1;1[.
% Returns a T-by-2 matrix: one point per row, columns are the (x,y)
% coordinates in the frame R(O,x,y).
u = rand(T,2);  % uniform samples in ]0;1[
x = 2*u - 1;    % rescaled to ]-1;1[

function y=generate_labels(x, radius)
% Label each row of x by its distance OM to the origin:
%   y = 1 (true)  when OM <= radius
%   y = 0 (false) when OM >  radius
% x is read column-wise: x(:,1) and x(:,2) are the two coordinates.
px = x(:,1);
py = x(:,2);
om = sqrt(px.^2 + py.^2);  % Euclidean distance to O
y  = (om <= radius);

function f=generate_line_features(N)
% Draw N random oriented lines, each defining a half-plane feature
%   - x.sin(t) + y.cos(t) - d < 0 or > 0
% Returns a 3-by-N matrix whose rows are:
%   row 1: angle t      in ]0;pi[
%   row 2: offset d     in ]-sqrt(2);sqrt(2)[
%   row 3: side sign    in {-1;+1}
% The three rand calls are kept in this order (t, d, sign).
angles  = pi*rand(1,N);
offsets = 2^0.5*(2*rand(1,N)-1);
sides   = sign((2*rand(1,N)-1));

f = [angles; offsets; sides];

%*******************************  Graphics  ******************************%
function show_points(x, y)
% Plot the points x in the current axes: green dots where y==1, black
% dots where y==0, on a fixed [-1;1]x[-1;1] view with a legend.
isPos = (y==1);
isNeg = (y==0);
plot(x(isPos,1), x(isPos,2),'g.', x(isNeg,1),x(isNeg,2),'k.');
axis equal;
axis([-1 1 -1 1]);
legend('y=1','y=0');

function show_lines(f)
% Draw every line feature (one per column of f: row 1 = angle,
% row 2 = offset) in thin blue inside the [-1;1]x[-1;1] window.
nLines = size(f,2);
for col = 1:nLines
    draw_line_tri(f(1,col), f(2,col), 'b',1, -1,1,-1,1);
end

function show_cmi(smax, mi, Hy)
% Refresh figure 2 with two bar charts:
%   top    - the score smax(k) of each selected feature (red bars)
%   bottom - the global mutual information mi(k) between the labels and
%            the features selected so far (blue bars), with the label
%            entropy Hy drawn as a red horizontal line.
nSel   = size(smax,1);
lineX  = linspace(0,nSel+0.5,10);  % abscissae of the Hy line
lineY  = linspace(Hy,Hy,10);       % constant ordinate Hy

figure(2);

% Top panel: score of each selected feature
subplot(2,1,1);
bar(smax,'r');
xlabel('k');
ylabel('max_{Xn} min_{1<=k<=K} I(Y;Xn|X_{nu(k)})');
title('Conditional Mutual Information of each best feature Xk|1<=k<=K');

% Bottom panel: global mutual information and its Hy boundary
subplot(2,1,2);
bar(mi,'b');
xlabel('k');
ylabel('I(Y;{Xk}1<=k<=K)');
title('Global Mutual Information I(Y;{Xk}1<=k<=K) evolution');
hold on;
plot(lineX, lineY,'r');
hold off;

%*******************************   Data  *********************************%
function X = build_matrix(x,f)
% Build the T-by-N logical matrix telling, for each point (row of x) and
% each line feature (column of f), whether the point lies on the side of
% the line designated by the feature:
%   X(i,j) = ( sign(-xi.sin(tj) + yi.cos(tj) - dj) == sj )
% with tj = f(1,j), dj = f(2,j), sj = f(3,j).
nPts  = size(x,1);
nFeat = size(f,2);
X = false(nPts,nFeat);

% One feature at a time, all points at once
for j = 1:nFeat
    theta  = f(1,j);
    offset = f(2,j);
    side   = f(3,j);
    signedDist = -x(:,1)*sin(theta) + x(:,2)*cos(theta) - offset;
    X(:,j) = (sign(signedDist) == side);
end

% function disp_data(X, Y, HY)
%
% for i=1:size(X,1)
%    for j=1:size(X,2)
%        fprintf(1,'%d',X(i,j));
%    end
%    fprintf(1,' %d\n',Y(i));
% end
% fprintf(1,'\nH(Y) = %0.4f bits \n\n', HY);

function disp_scores(s, k, nu, smax)
% Print, on standard output, a dashed rule sized to the number of scores,
% the current score vector s for step k, and the selected maximum smax
% together with the index nu(k) of the feature that reached it.
nScores = size(s,2);

% Rule: a fixed prefix plus one dash group per score
rule = ['---------', repmat('------', 1, nScores)];
fprintf(1,'%s', rule);
fprintf(1,'\n');

fprintf(1,'s  [%d] = ', k);
for v = s(1:nScores)
    fprintf(1,'%0.3f ', v);
end
fprintf(1,'\n');
fprintf(1,'-> max s(n) = %0.3f - nu(%d) = %d\n',  smax, k, nu(k));

function disp_cmi(k, cmi)
% Print row k of the conditional mutual information table cmi on
% standard output, three decimals per value.
fprintf(1,'cmi[%d] = ', k);
for v = cmi(k, 1:size(cmi,2))
    fprintf(1,'%0.3f ', v);
end
fprintf(1,'\n');

%******************************  Entropy  ********************************%
function card = card1(x,u)
% Count the entries of x equal to u.
% x : boolean vector
% u : boolean value (0/1)
hit  = (x==u);
card = sum(hit);

function card = card2(x,y,u,v)
% Count the positions where x equals u and y equals v simultaneously.
% x,y : boolean vectors
% u,v : boolean values (0/1)
hitX = (x==u);
hitY = (y==v);
card = sum(hitX & hitY);

function card = card3(x,y,z,u,v,w)
% Count the positions where x equals u, y equals v and z equals w
% simultaneously.
% x,y,z : boolean vectors
% u,v,w : boolean values (0/1)
hitX = (x==u);
hitY = (y==v);
hitZ = (z==w);
card = sum(hitX & hitY & hitZ);

function val = xi(x,T)
% Entropy building block: x*log2(x)/T, extended by continuity to 0 at
% x = 0 (x*log(x) -> 0 when x -> 0+).
% x : a count
% T : number of points
val = 0;
if x==0
    return;
end
val = x*log2(x)/T;

function H = H1(Y)
% Entropy, in bits, of a boolean column vector Y, computed from the
% counts of its 0 and 1 entries:
%   H = log2(T) - sum_u n_u*log2(n_u)/T,  T = number of rows of Y
T = size(Y,1);
acc = 0;
for u = 0:1
    acc = acc + xi(card1(Y,u),T);
end
H = log2(T) - acc;

function H = H2(Y,Xn)
% Joint entropy, in bits, of two boolean column vectors Y and Xn,
% computed from the counts of the four (Y,Xn) value pairs.
% T is the number of rows of Y.
T = size(Y,1);
acc = 0;
for u = 0:1
    for v = 0:1
        acc = acc + xi(card2(Y,Xn,u,v),T);
    end
end
H = log2(T) - acc;

function H = H3(Y,Xn,Xm)
% Joint entropy, in bits, of three boolean column vectors Y, Xn and Xm,
% computed from the counts of the eight (Y,Xn,Xm) value triples.
% T is the number of rows of Y.
T = size(Y,1);
acc = 0;
for u = 0:1
    for v = 0:1
        for w = 0:1
            acc = acc + xi(card3(Y,Xn,Xm,u,v,w),T);
        end
    end
end
H = log2(T) - acc;

function I = mut_inf(Y,Xn)
% Mutual Information I(Y;Xn) = H(Y) + H(Xn) - H(Y,Xn)
hY     = H1(Y);
hXn    = H1(Xn);
hJoint = H2(Y,Xn);
I = hY + hXn - hJoint;

function I = cond_mut_inf(Y,Xn,Xm)
% Conditional Mutual Information
%   I(Y;Xn|Xm) = H(Y,Xm) - H(Xm) - H(Y,Xn,Xm) + H(Xn,Xm)
hYXm   = H2(Y,Xm);
hXm    = H1(Xm);
hYXnXm = H3(Y,Xn,Xm);
hXnXm  = H2(Xn,Xm);
I = hYXm - hXm - hYXnXm + hXnXm;

function MI = calc_mi(X, nu, k, y, Hy, T)
% Global mutual information between the labels y and the first k selected
% features X(:,nu(1:k)):
%   MI = Hy + H(Xsel) - H(y,Xsel)
% Each entropy is estimated from the frequencies (counts / T) of the
% distinct rows observed in the data; Hy is supplied by the caller.
Xsel = X(:,nu(1:k));

% Entropy of the selected features
[rowsX, ~, whichX] = unique(Xsel, 'rows');
freqX = histc(whichX, 1:size(rowsX,1))/T;
Hx = Hp(freqX);

% Joint entropy of the labels and the selected features
[rowsYX, ~, whichYX] = unique([y, Xsel], 'rows');
freqYX = histc(whichYX, 1:size(rowsYX,1))/T;
Hyx = Hp(freqYX);

MI = Hy + Hx - Hyx;

function H = Hp(P)
% Entropy, in bits, of a distribution given by its probability vector P.
% Raises an error when the probabilities do not sum to 1 (tolerance
% 1000*eps). Entries within eps of 0 or 1 are left out of the sum: their
% p.log(p) contribution is 0 and skipping them avoids NaNs.
total = sum(P);
if abs(total-1) > 1000*eps
    error('Sum of probabilities not equal to 1: %0.4f', total);
end
keep = (P > eps) & (P < 1-eps);
p = P(keep);
H = sum(p.*log2(1./p));

%****************************   Drawing   ********************************%
function color = calc_color(k, smax)
% Calculate the kth color with the appropriate contrast
% 100% : a flashy red - 0% : dark black
%
% k     : index of the feature being drawn
% smax  : scores of the selected features; only smax(1:k) is read
% color : RGB triplet [r 0 0] with r = smax(k)/max(smax(1:k)) in [0;1]
%
% The red level is the kth score relative to the best score so far. When
% that best score is not strictly positive (no informative feature) the
% ratio is undefined, so black is returned. The result is clamped to
% [0;1] because a score can be a tiny negative rounding residue and an
% RGB component outside [0;1] is rejected by plot.
peak = max(smax(1:k));
if peak > 0
    red = smax(k)/peak;
else
    red = 0;
end
red   = min(max(red, 0), 1);
color = [red 0 0];

function draw_line_tri(t,d, color,linewidth, xm,xM,ym,yM)
% Plot, in the current axes, the line of equation
%   -x.sin(t) + y.cos(t) - d = 0
% between its two end points computed against the window
% [xm;xM] x [ym;yM]. Angles are counterclockwise (image convention
% reversed). The segment is sampled on 20 points and added to the
% existing plot with the given color and line width.
if (t == 0)
    % horizontal line: y = d
    x0 = xm; x1 = xM;
    y0 = d;  y1 = d;
elseif (t == pi/2)
    % vertical line: x = -d
    x0 = -d; x1 = -d;
    y0 = ym; y1 = yM;
else
    % oblique line: y as a function of x, and its inverse
    yAt = @(x) d/cos(t) + x*tan(t);
    xAt = @(y) y/tan(t)-d/sin(t);

    % first end: start from the left edge, fall back on the bottom/top
    % edge when the line leaves the window vertically
    x0 = xm;
    y0 = yAt(xm);
    if y0 < ym
        y0 = ym; x0 = xAt(ym);
    elseif y0 > yM
        y0 = yM; x0 = xAt(yM);
    end

    % second end: same treatment from the right edge
    x1 = xM;
    y1 = yAt(xM);
    if y1 > yM
        y1 = yM; x1 = xAt(yM);
    elseif y1 < ym
        y1 = ym; x1 = xAt(ym);
    end
end

% trigonometric convention
xs = linspace(x0, x1, 20);
ys = linspace(y0, y1, 20);

hold on;
plot(xs,ys,'color',color,'linewidth',linewidth);
hold off;
% Sample console output:
%
% 00001000101111110111010100001011011111010 ... 100 0
% 00001010101111110111010100001011011011010 ... 000 0
% 10001010100101110010010000000110011001010 ... 100 1
% ......................................... ... ... .
% 10001010100101110011010000000110011001010 ... 110 0
%
% H(Y) = 0.6930 bits
%
% ----------------------------------------- ... -----
% s  [1] = 0.038 0.037 0.012 0.021 0.000    ... 0.065
% -> max s(n) = 0.092 - nu(1) = 122
% cmi[1] = 0.068 0.086 0.000 0.000 0.000    ... 0.120
% ----------------------------------------- ... -----
%
% ...
%
% ----------------------------------------- ... -----
% s  [15] = 0.001 0.001 -0.000 -0.000 0.000 ... 0.003
% -> max s(n) = 0.013 - nu(15) = 7
% cmi[15] = 0.006 0.000 0.016 0.029 0.000   ... 0.010
%
% Press any key to continue...