Machine learning
XAI
Nov 22, 2020 · 7 minute read

1. Why even bother explaining machine learning models?

2. How to do that?

The most popular approaches are:

SHAP

Useful resources:

lime

We pretty much build a linear regression model on the black-box model's predictions, locally around the instance we want to explain (a short usage sketch follows at the end of this section). Useful resources:

In general it was extremely hard to find a proper explanation of what lime actually does. Ironically, the “explainer”, which is meant to transform a complex, nonlinear model into an easy-to-grasp form, is very complicated and non-intuitive itself. The pieces of the puzzle which I found particularly obscure are:

To sum up, despite my concerns, I like this method. It is a brilliant idea for explaining black-box models, but we should not be as yolo-optimistic as the authors of most articles on medium.com, towardsdatascience.com etc. are, and we should not treat lime as a magical tool that finally solves the black-box explainability problem. Actually, I am quite disappointed that the hype around data science leads to dishonest pseudo-papers and false promises (“with lime you can explain any black-box model, because it’s model agnostic” - well, really? It is indeed model agnostic, but the quality of the explanation may be very poor, even false and misleading, if you do not proceed carefully. Besides, in some cases it may not be possible to find a feasible linear approximation. Then again, there are some ‘24-hour courses’ on data science, which may suggest that after 24 hours… you are a competent data scientist? This is a subject for another discussion ;) )
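To make the idea above concrete, here is a minimal sketch of how lime is typically used on tabular data. It assumes a fitted classifier model and arrays X_train, X_test with matching feature_names and class_names - all hypothetical placeholders, not anything from a specific project:

from lime.lime_tabular import LimeTabularExplainer

# X_train: numpy array of training data, model: any fitted classifier (hypothetical)
explainer = LimeTabularExplainer(
    X_train,
    feature_names=feature_names,
    class_names=class_names,
    mode="classification",
)

# lime perturbs the instance, queries the black-box model and fits
# a weighted linear model on the (perturbation, prediction) pairs
exp = explainer.explain_instance(
    X_test[0],
    model.predict_proba,   # any function returning class probabilities
    num_features=5,
)
print(exp.as_list())       # (feature, weight) pairs of the local linear model

The weights of that local linear model are the “explanation” - which is also why the caveats above matter: if the local linear fit is poor, the explanation is poor too.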

eli5

TODO

CAM - Class Activation Mapping

In contrast to the previous methods, which are model agnostic, CAM works specifically for convolutional neural networks trained on images. Similarly to shap and lime, it highlights the areas of a given image which had the biggest impact on the prediction, e.g. when predicting a cat, it would highlight the cat’s mouth, ears and tail.

The concept of CAM is rather straightforward and intuitive: we compute a weighted average of the channels of the last convolutional layer. The weights are provided by the network itself: they come from the linear layer which follows the convolutional layer. As a result we can tell which channels of the last convolutional layer contributed the most to the prediction, and if a channel concentrates on a specific area of the image, which is often the case, we know which areas were the most important. The whole concept was presented in the article Learning Deep Features for Discriminative Localization.
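In the notation of that paper, the class activation map for class c is just this weighted sum of the feature maps of the last convolutional layer:

M_c(x, y) = \sum_k w_k^c \, f_k(x, y)

where f_k(x, y) is the activation of channel k at spatial location (x, y) and w_k^c is the weight of that channel in the final linear layer for class c - exactly the einsum in the code below.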

Unfortunately I haven’t come across any decent implementation¹ of CAM in pytorch (maybe there is one, but I couldn’t find it). I’ve seen examples of CAM in keras, but I don’t use it on a daily basis, and retraining a model in a different framework sounds like the last thing anyone would want to do. So here’s a short script I wrote with the help of Deep Learning for Coders with fastai and PyTorch:

from pathlib import Path

import matplotlib.pyplot as plt
import torch
from fastai.vision.all import *  # PILImage, TensorImage, first, Path.ls, ...

# `learn` (the trained Learner) and `loader` (its DataLoaders) are assumed
# to already exist from the training step.

class Hook:
    "Stores a detached copy of the output of the module it is registered on."
    def hook_func(self, m, i, o): self.stored = o.detach().clone()

hook_output = Hook()
hook = learn.model[0].register_forward_hook(hook_output.hook_func)  # 1

def cam(i):
    path = Path("data/fiat-126").ls()[i]  # 2
    img = PILImage.create(path)
    x, = first(loader.test_dl([img]))

    # forward pass only to populate hook_output.stored with the activations
    with torch.no_grad(): output = learn.model.eval()(x)
    act = hook_output.stored[0]
    # weighted average of the channels, weights taken from the final linear layer
    cam_map = torch.einsum('ck,kij->cij', learn.model[1][-1].weight, act)

    x_dec = TensorImage(loader.train.decode((x,))[0][0])

    _, ax = plt.subplots()
    x_dec.show(ctx=ax)
    ax.imshow(cam_map[1].detach(), alpha=0.5, extent=(0,224,224,0),
              interpolation='bilinear', cmap='magma')

cam(0)

Two lines of this code may look slightly obscure should you not know the context:

- # 1 registers a forward hook on the convolutional body of the network (learn.model[0]), so that after every forward pass the activations of the last convolutional layer end up in hook_output.stored,
- # 2 simply picks the i-th image from my dataset of Fiat 126 photos.

At first I was quite disappointed with how CAM works, or, actually, doesn’t work. But maybe it just told me that my neural network made no sense, which is what I am still investigating. Despite the failure, I still believe in this method, as conceptually it should work.


  1. After some time I finally found a proper implementation of Grad-CAM and I used it for this project. ↩︎