Prototype Selection for Interpretable Classification

Interpretability

Many machine learning applications require that the decisions taken by a model can be explained easily. For example, when a person requests a bank loan, the bank should be able to explain why a particular decision (rejection or acceptance) was made. The goal of interpretability is to bring trust, privacy, fairness, and robustness to machine learning models, in contrast to black-box models whose decisions are hard to comprehend.

Prototype Selection for Interpretable Classification (PS)

Given a training data set $V=\{v_1,\cdots,v_n\} \subset \mathbb{R}^m$ of $n$ data samples, where each sample is an $m$-dimensional feature vector, each sample in the training set is associated with a corresponding class label $y_1,\cdots,y_n \in \{1,\cdots, L\}$. The PS scheme returns a set $\mathscr{P}_l$ for each class $l \in \{1,\cdots, L\}$. The returned output is a condensed prototype set $\mathscr{P} =\{\mathscr{P}_1,\cdots,\mathscr{P}_L\} \subseteq V$ that can be seen as having an interpretable meaning. According to J. Bien and R. Tibshirani, the prototype selection in PS is based on three important notions:

  • Prototypes should cover as many training samples of the same class $l$ as possible,
  • Prototypes should cover as few training samples from other classes as possible, and
  • The number of prototypes needed to cover the data samples of a particular class should be as small as possible (also called sparsity).

The prototypes selected in PS are actual data points, as this adds more interpretable meaning to the model. The PS scheme is initially formulated as a set cover integer program: for a given radius of an epsilon ball (centred at a chosen prototype), PS outputs the minimum number of balls required to form a cover while preserving the properties of prototypes listed above. The problem is then transformed into a prize-collecting variant and solved using two approaches, namely

  • Greedy Approach (recommended for large datasets; see the sketch below)
  • Randomized Rounding Algorithm
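
The following is a minimal sketch of the greedy variant, assuming a fixed ball radius `eps` and a simplified per-step gain of newly covered same-class points minus covered other-class points minus a sparsity penalty `lam`; the function and variable names are illustrative, not taken from this repository or the original paper.

```python
import numpy as np

def greedy_prototype_selection(V, y, eps, lam=1.0):
    """Greedily pick prototype indices from V (n x m array), labels y (n,)."""
    n = len(V)
    # Pairwise distances; within[i, j] is True when v_j lies inside the
    # epsilon ball centred at candidate prototype v_i.
    D = np.linalg.norm(V[:, None, :] - V[None, :, :], axis=-1)
    within = D <= eps
    covered = np.zeros(n, dtype=bool)   # same-class points already covered
    prototypes = []
    while True:
        best_gain, best_i = 0.0, None
        for i in range(n):
            if i in prototypes:
                continue
            same = within[i] & (y == y[i])    # same-class points in the ball
            other = within[i] & (y != y[i])   # other-class points in the ball
            # Gain: newly covered same-class points, penalised by covered
            # other-class points and a per-prototype sparsity cost.
            gain = np.sum(same & ~covered) - np.sum(other) - lam
            if gain > best_gain:
                best_gain, best_i = gain, i
        if best_i is None:   # no remaining candidate improves the objective
            break
        prototypes.append(best_i)
        covered |= within[best_i] & (y == y[best_i])
    return prototypes
```

On data such as `sklearn.datasets.make_moons`, the returned indices are the centres of the epsilon balls, i.e. the selected prototypes.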

The figure above shows a visualisation of the PS scheme on synthetic data (sklearn moons data) for the chosen value of epsilon. Filled circles represent data points, and covers are represented by unfilled circles centred at the data points (denoted by X) chosen as prototypes.

For further reading and mathematical understanding, please refer to J. Bien and R. Tibshirani.

Generalised Learning Vector Quantisation (GLVQ)

In order to understand GLVQ, the prototype set is considered as $W = \{w_1,\cdots,w_l\}$ for each class in the dataset, where $l\in \{1,\cdots,L\}$ and $L$ is the number of classes. For comparison of samples with the prototypes, a distance matrix is represented by $D_{i,j} = d(v_i,w_j)$, where $d$ is a differentiable dissimilarity measure. Next to this, a classifier (relative distance difference) function for GLVQ is defined as

$$\mu(v) = \frac{d(v,w^+) - d(v,w^-)}{d(v,w^+) + d(v,w^-)}$$

where $w^+$ is the best matching prototype of the correct class, $c(w^+)=c(v)$, and $w^-$ is the closest prototype belonging to a wrong or incorrect class, $c(w^-)\neq c(v)$. For correct classification, the distance $d(v, w^+)$ to the prototype of the correct class should be smaller than the distance $d(v, w^-)$ to the prototype of an incorrect class. In this case, the classifier function outputs negative values, and hence the cost function $E_{GLVQ}$, a sum of $f(\mu(v_i))$ over all training samples with $f$ a monotonically increasing function, is an approximation of the overall classification error.
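
Below is a minimal training sketch under two simplifying assumptions: squared Euclidean distance and the identity as the monotonic function $f$, so each update follows the gradient of $\mu$ directly. The function name `fit_glvq` and its hyperparameters are illustrative, not taken from this repository.

```python
import numpy as np

def fit_glvq(V, y, W, c, lr=0.05, epochs=50):
    """V: (n, m) data, y: (n,) labels, W: (k, m) float prototypes, c: (k,) labels."""
    for _ in range(epochs):
        for v, label in zip(V, y):
            d = np.sum((W - v) ** 2, axis=1)   # squared Euclidean distances
            plus = np.where(c == label)[0]     # prototypes of the correct class
            minus = np.where(c != label)[0]    # prototypes of other classes
            i = plus[np.argmin(d[plus])]       # index of w+
            j = minus[np.argmin(d[minus])]     # index of w-
            dp, dm = d[i], d[j]
            denom = (dp + dm) ** 2 + 1e-12     # guard against division by zero
            # Gradient step on mu = (dp - dm) / (dp + dm):
            # pull w+ towards v, push w- away from v.
            W[i] += lr * (4 * dm / denom) * (v - W[i])
            W[j] -= lr * (4 * dp / denom) * (v - W[j])
    return W
```

In the GLVQ-PS combination, a natural initialisation for `W` is the prototype set returned by the PS scheme.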

For further reading and understanding, please refer to A. Sato and K. Yamada.

Experiments:

Iris data

The image below shows the selected prototypes (X) that represent the data samples (filled circles).

  • Unlike k-Nearest Neighbour, which stores the whole training set for prediction, the prototype selection scheme only requires the condensed form of the training samples (the prototypes) to be stored, saving a large amount of memory.
  • For prediction, it only uses the distances to the selected prototypes, saving the time required to compare against every training sample (see the sketch below).
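
As a sketch of this prediction step, assuming plain nearest-prototype classification (names are illustrative):

```python
import numpy as np

def predict(X, W, c):
    """X: (q, m) queries, W: (k, m) prototypes, c: (k,) prototype labels."""
    # Distances from every query to the k prototypes only, not to all n
    # training samples: a (q, k) matrix.
    D = np.linalg.norm(X[:, None, :] - W[None, :, :], axis=-1)
    return c[np.argmin(D, axis=1)]
```

Both memory and per-query distance computations now scale with the number of prototypes $k$ rather than the training-set size $n$.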

Comparison of GLVQ and GLVQ-PS