Author: Anna Hedström, PhD Candidate, TU Berlin and ATB Potsdam
Contact: anna.hedstroem@tu-berlin.de
Venue: INVICTA, Spring School
Date: March 21, 2024
Abstract:
In the first part of this tutorial, we will take an in-depth look at some of the most recent developments in XAI evaluation and, in particular, demonstrate how to perform XAI evaluation using the open-source library Quantus. We will answer questions such as:
In the second part of the tutorial, we will define and address the problem of meta-evaluation in XAI (i.e., the process of evaluating the evaluation method itself). This problem arises when we select and quantitatively compare explanation methods for a given model, dataset and task, since using multiple XAI metrics or evaluation techniques often leads to conflicting results. In this part, we will use the library MetaQuantus to characterise the performance of different XAI metrics and, moreover, to help identify and select a reliable metric for our chosen explainability context.
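To make "conflicting results" concrete, one simple way to quantify how much two metrics disagree is Kendall's rank correlation between the rankings they induce over the same explanation methods. The sketch below is our own illustration and not part of MetaQuantus; the method names and scores are hypothetical placeholders, and it assumes that a higher score is better for both metrics (negate a metric's scores first if lower is better).

```python
from itertools import combinations


def kendall_tau(scores_a, scores_b):
    """Kendall's tau-a between the method rankings induced by two metrics.

    scores_a, scores_b: dicts mapping explanation-method name -> score (higher = better).
    Returns +1 if both metrics order all methods identically, -1 if one ordering
    is the exact reverse of the other, and values near 0 if the rankings are unrelated.
    Tied pairs count as neither concordant nor discordant.
    """
    if set(scores_a) != set(scores_b):
        raise ValueError("Both metrics must score the same set of methods.")
    methods = sorted(scores_a)
    if len(methods) < 2:
        raise ValueError("Need at least two methods to compare rankings.")
    concordant = discordant = 0
    for m1, m2 in combinations(methods, 2):
        # Positive product: both metrics order this pair the same way.
        agreement = (scores_a[m1] - scores_a[m2]) * (scores_b[m1] - scores_b[m2])
        if agreement > 0:
            concordant += 1
        elif agreement < 0:
            discordant += 1
    n_pairs = len(methods) * (len(methods) - 1) // 2
    return (concordant - discordant) / n_pairs


# Hypothetical scores from two metrics for the same four explanation methods.
metric_a = {"Saliency": 0.41, "IntegratedGradients": 0.63, "GradientShap": 0.58, "Occlusion": 0.49}
metric_b = {"Saliency": 0.55, "IntegratedGradients": 0.47, "GradientShap": 0.61, "Occlusion": 0.38}
print(kendall_tau(metric_a, metric_b))  # 0.0: three method pairs agree, three disagree
```

With these placeholder scores, the two metrics agree on exactly as many method pairs as they disagree on, so knowing one metric's ranking tells us nothing about the other's. This is the kind of situation that motivates meta-evaluation.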
In the third part of the tutorial, we will investigate how different parameters influence the evaluation outcome, i.e., how the ranking of the different explanation methods changes.
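A minimal sketch of what "how explanation methods rank" means in practice: given one score per method for each parameter setting, sort the methods and compare the resulting orderings. The setting labels (loosely modelled on a choice of perturbation baseline) and all scores are hypothetical placeholders; in the tutorial, such scores would come from Quantus metrics.

```python
def rank_methods(scores, higher_is_better=True):
    """Order explanation-method names from best to worst under one evaluation setting."""
    return sorted(scores, key=scores.get, reverse=higher_is_better)


def rankings_per_setting(results, higher_is_better=True):
    """Map each parameter setting to the method ranking it produces."""
    return {setting: rank_methods(scores, higher_is_better) for setting, scores in results.items()}


# Hypothetical scores of one metric under two parameter settings (placeholders, not tutorial results).
results = {
    "baseline=black": {"Saliency": 0.30, "IntegratedGradients": 0.52, "Occlusion": 0.45},
    "baseline=uniform": {"Saliency": 0.48, "IntegratedGradients": 0.39, "Occlusion": 0.44},
}
rankings = rankings_per_setting(results)
for setting, ranking in rankings.items():
    print(f"{setting:<18} {' > '.join(ranking)}")

# The ranking is parameter-sensitive if the settings do not all agree.
ranking_is_stable = len({tuple(r) for r in rankings.values()}) == 1
print("Ranking stable across settings:", ranking_is_stable)
```

Here the two hypothetical settings produce opposite rankings, so a conclusion such as "method X is best" would depend entirely on the chosen parameter.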
First, we need to re-load a dataset and model.
Related Papers:
Identifying Reliable Estimators with MetaQuantus (https://openreview.net/pdf?id=j3FK00HyfU) by Hedström et al., 2023
Instructions: For this exercise, we work in a Google Colab environment. It is also possible to run the notebook in a simple Jupyter Notebook environment.
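If you are unsure which environment the notebook is running in, a small check like the following can help. It is a sketch that relies on the `google.colab` module being importable only inside Colab runtimes; it is not part of Quantus or MetaQuantus.

```python
def running_in_colab() -> bool:
    """Return True if the `google.colab` module can be imported, i.e. we are inside Colab."""
    try:
        import google.colab  # noqa: F401  (only available in Colab runtimes)
    except ImportError:
        return False
    return True


if running_in_colab():
    print("Running in Google Colab.")
else:
    print("Running outside Colab (e.g., a local Jupyter Notebook).")
```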
Quantus and MetaQuantus and (iii) run a real-world image classification example.

Quantus can be used with TensorFlow and with other tasks and data domains such as time-series, tabular or NLP data. We do not discuss global explanation methods.

Import the main libraries, Quantus and MetaQuantus, along with some supporting libraries.
!pip install captum torchvision medmnist datasets transformers --quiet
# Get the latest versions of quantus and metaquantus.
!pip install git+https://github.com/understandable-machine-intelligence-lab/Quantus.git \
    git+https://github.com/annahedstroem/MetaQuantus.git --quiet
Please restart the runtime session after running the above cell.
# Imports.
import quantus
import metaquantus
import glob
import gc
import tqdm
import os
import copy
import warnings
import torch
import torchvision
import captum
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
# Use the GPU if one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Notebook settings.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.simplefilter("ignore", category=FutureWarning)
sns.set()
List all package requirements.
# !pip freeze
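As a lighter-weight alternative to `pip freeze`, the sketch below reports only the versions of the packages this tutorial uses, via the standard-library `importlib.metadata`. The distribution names (in particular `metaquantus`) are our assumption about how the packages are registered; a missing entry is reported as "not found".

```python
from importlib.metadata import PackageNotFoundError, version


def package_versions(names):
    """Map each distribution name to its installed version, or None if it is not installed."""
    versions = {}
    for name in names:
        try:
            versions[name] = version(name)
        except PackageNotFoundError:
            versions[name] = None
    return versions


# Distribution names are assumptions; adjust them if a package is registered differently.
for name, ver in package_versions(
    ["quantus", "metaquantus", "torch", "torchvision", "captum", "numpy", "pandas"]
).items():
    print(f"{name:<12} {ver or 'not found'}")
```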
For the purpose of demonstration, in this exercise we rely on an image-classification task with PyTorch. That being said, Quantus and MetaQuantus also support other ML frameworks such as TensorFlow and other data domains, e.g., time-series, tabular data and, to some extent, NLP.
We have prepared a small subset of ImageNet images (to download the full dataset, please find instructions here: https://image-net.org/download.php). The following will load inputs x_batch, labels y_batch and segmentation masks s_batch.
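Before running any metric, it can be worth checking that the three arrays line up. The sketch below assumes a channel-first layout (N, C, H, W) for `x_batch`, integer class indices of shape (N,) for `y_batch` and single-channel masks of shape (N, 1, H, W) for `s_batch`. These shapes are our assumption for this PyTorch setting, so consult the Quantus documentation for the exact requirements of each metric. The helper `check_batches` is hypothetical, and the usage example runs on synthetic arrays rather than the tutorial data.

```python
import numpy as np


def check_batches(x_batch, y_batch, s_batch):
    """Hypothetical helper: validate batch shapes and return (N, (H, W)).

    Assumed layout: x_batch (N, C, H, W), y_batch (N,) integer labels,
    s_batch (N, 1, H, W) segmentation masks.
    """
    if x_batch.ndim != 4:
        raise ValueError(f"x_batch should have shape (N, C, H, W), got {x_batch.shape}")
    n, _, h, w = x_batch.shape
    if y_batch.shape != (n,):
        raise ValueError(f"y_batch should have shape ({n},), got {y_batch.shape}")
    if not np.issubdtype(y_batch.dtype, np.integer):
        raise TypeError(f"y_batch should hold integer class indices, got dtype {y_batch.dtype}")
    if s_batch.shape != (n, 1, h, w):
        raise ValueError(f"s_batch should have shape ({n}, 1, {h}, {w}), got {s_batch.shape}")
    return n, (h, w)


# Usage with synthetic arrays (two 224x224 RGB images, not the tutorial data).
x_demo = np.zeros((2, 3, 224, 224), dtype=np.float32)
y_demo = np.array([1, 207])
s_demo = np.zeros((2, 1, 224, 224), dtype=np.float32)
print(check_batches(x_demo, y_demo, s_demo))  # (2, (224, 224))
```

Once the real arrays are loaded, the same call, `check_batches(x_batch, y_batch, s_batch)`, would flag a mask that is missing its channel axis or a label array of the wrong length.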
!pip install gdown --quiet
!gdown https://drive.google.com/drive/folders/1ZmIkqvnt8_wXU3dLBWS9i88l-spS_FU1 --folder --quiet
!ls
assets sample_data
#@title 3.1.1 ImageNet class indices and names
%%capture
CLASSES = {0: 'tench, Tinca tinca',
1: 'goldfish, Carassius auratus',
2: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias',
3: 'tiger shark, Galeocerdo cuvieri',
4: 'hammerhead, hammerhead shark',
5: 'electric ray, crampfish, numbfish, torpedo',
6: 'stingray',
7: 'cock',
8: 'hen',
9: 'ostrich, Struthio camelus',
10: 'brambling, Fringilla montifringilla',
11: 'goldfinch, Carduelis carduelis',
12: 'house finch, linnet, Carpodacus mexicanus',
13: 'junco, snowbird',
14: 'indigo bunting, indigo finch, indigo bird, Passerina cyanea',
15: 'robin, American robin, Turdus migratorius',
16: 'bulbul',
17: 'jay',
18: 'magpie',
19: 'chickadee',
20: 'water ouzel, dipper',
21: 'kite',
22: 'bald eagle, American eagle, Haliaeetus leucocephalus',
23: 'vulture',
24: 'great grey owl, great gray owl, Strix nebulosa',
25: 'European fire salamander, Salamandra salamandra',
26: 'common newt, Triturus vulgaris',
27: 'eft',
28: 'spotted salamander, Ambystoma maculatum',
29: 'axolotl, mud puppy, Ambystoma mexicanum',
30: 'bullfrog, Rana catesbeiana',
31: 'tree frog, tree-frog',
32: 'tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui',
33: 'loggerhead, loggerhead turtle, Caretta caretta',
34: 'leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea',
35: 'mud turtle',
36: 'terrapin',
37: 'box turtle, box tortoise',
38: 'banded gecko',
39: 'common iguana, iguana, Iguana iguana',
40: 'American chameleon, anole, Anolis carolinensis',
41: 'whiptail, whiptail lizard',
42: 'agama',
43: 'frilled lizard, Chlamydosaurus kingi',
44: 'alligator lizard',
45: 'Gila monster, Heloderma suspectum',
46: 'green lizard, Lacerta viridis',
47: 'African chameleon, Chamaeleo chamaeleon',
48: 'Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis',
49: 'African crocodile, Nile crocodile, Crocodylus niloticus',
50: 'American alligator, Alligator mississipiensis',
51: 'triceratops',
52: 'thunder snake, worm snake, Carphophis amoenus',
53: 'ringneck snake, ring-necked snake, ring snake',
54: 'hognose snake, puff adder, sand viper',
55: 'green snake, grass snake',
56: 'king snake, kingsnake',
57: 'garter snake, grass snake',
58: 'water snake',
59: 'vine snake',
60: 'night snake, Hypsiglena torquata',
61: 'boa constrictor, Constrictor constrictor',
62: 'rock python, rock snake, Python sebae',
63: 'Indian cobra, Naja naja',
64: 'green mamba',
65: 'sea snake',
66: 'horned viper, cerastes, sand viper, horned asp, Cerastes cornutus',
67: 'diamondback, diamondback rattlesnake, Crotalus adamanteus',
68: 'sidewinder, horned rattlesnake, Crotalus cerastes',
69: 'trilobite',
70: 'harvestman, daddy longlegs, Phalangium opilio',
71: 'scorpion',
72: 'black and gold garden spider, Argiope aurantia',
73: 'barn spider, Araneus cavaticus',
74: 'garden spider, Aranea diademata',
75: 'black widow, Latrodectus mactans',
76: 'tarantula',
77: 'wolf spider, hunting spider',
78: 'tick',
79: 'centipede',
80: 'black grouse',
81: 'ptarmigan',
82: 'ruffed grouse, partridge, Bonasa umbellus',
83: 'prairie chicken, prairie grouse, prairie fowl',
84: 'peacock',
85: 'quail',
86: 'partridge',
87: 'African grey, African gray, Psittacus erithacus',
88: 'macaw',
89: 'sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita',
90: 'lorikeet',
91: 'coucal',
92: 'bee eater',
93: 'hornbill',
94: 'hummingbird',
95: 'jacamar',
96: 'toucan',
97: 'drake',
98: 'red-breasted merganser, Mergus serrator',
99: 'goose',
100: 'black swan, Cygnus atratus',
101: 'tusker',
102: 'echidna, spiny anteater, anteater',
103: 'platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus',
104: 'wallaby, brush kangaroo',
105: 'koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus',
106: 'wombat',
107: 'jellyfish',
108: 'sea anemone, anemone',
109: 'brain coral',
110: 'flatworm, platyhelminth',
111: 'nematode, nematode worm, roundworm',
112: 'conch',
113: 'snail',
114: 'slug',
115: 'sea slug, nudibranch',
116: 'chiton, coat-of-mail shell, sea cradle, polyplacophore',
117: 'chambered nautilus, pearly nautilus, nautilus',
118: 'Dungeness crab, Cancer magister',
119: 'rock crab, Cancer irroratus',
120: 'fiddler crab',
121: 'king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica',
122: 'American lobster, Northern lobster, Maine lobster, Homarus americanus',
123: 'spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish',
124: 'crayfish, crawfish, crawdad, crawdaddy',
125: 'hermit crab',
126: 'isopod',
127: 'white stork, Ciconia ciconia',
128: 'black stork, Ciconia nigra',
129: 'spoonbill',
130: 'flamingo',
131: 'little blue heron, Egretta caerulea',
132: 'American egret, great white heron, Egretta albus',
133: 'bittern',
134: 'crane',
135: 'limpkin, Aramus pictus',
136: 'European gallinule, Porphyrio porphyrio',
137: 'American coot, marsh hen, mud hen, water hen, Fulica americana',
138: 'bustard',
139: 'ruddy turnstone, Arenaria interpres',
140: 'red-backed sandpiper, dunlin, Erolia alpina',
141: 'redshank, Tringa totanus',
142: 'dowitcher',
143: 'oystercatcher, oyster catcher',
144: 'pelican',
145: 'king penguin, Aptenodytes patagonica',
146: 'albatross, mollymawk',
147: 'grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus',
148: 'killer whale, killer, orca, grampus, sea wolf, Orcinus orca',
149: 'dugong, Dugong dugon',
150: 'sea lion',
151: 'Chihuahua',
152: 'Japanese spaniel',
153: 'Maltese dog, Maltese terrier, Maltese',
154: 'Pekinese, Pekingese, Peke',
155: 'Shih-Tzu',
156: 'Blenheim spaniel',
157: 'papillon',
158: 'toy terrier',
159: 'Rhodesian ridgeback',
160: 'Afghan hound, Afghan',
161: 'basset, basset hound',
162: 'beagle',
163: 'bloodhound, sleuthhound',
164: 'bluetick',
165: 'black-and-tan coonhound',
166: 'Walker hound, Walker foxhound',
167: 'English foxhound',
168: 'redbone',
169: 'borzoi, Russian wolfhound',
170: 'Irish wolfhound',
171: 'Italian greyhound',
172: 'whippet',
173: 'Ibizan hound, Ibizan Podenco',
174: 'Norwegian elkhound, elkhound',
175: 'otterhound, otter hound',
176: 'Saluki, gazelle hound',
177: 'Scottish deerhound, deerhound',
178: 'Weimaraner',
179: 'Staffordshire bullterrier, Staffordshire bull terrier',
180: 'American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier',
181: 'Bedlington terrier',
182: 'Border terrier',
183: 'Kerry blue terrier',
184: 'Irish terrier',
185: 'Norfolk terrier',
186: 'Norwich terrier',
187: 'Yorkshire terrier',
188: 'wire-haired fox terrier',
189: 'Lakeland terrier',
190: 'Sealyham terrier, Sealyham',
191: 'Airedale, Airedale terrier',
192: 'cairn, cairn terrier',
193: 'Australian terrier',
194: 'Dandie Dinmont, Dandie Dinmont terrier',
195: 'Boston bull, Boston terrier',
196: 'miniature schnauzer',
197: 'giant schnauzer',
198: 'standard schnauzer',
199: 'Scotch terrier, Scottish terrier, Scottie',
200: 'Tibetan terrier, chrysanthemum dog',
201: 'silky terrier, Sydney silky',
202: 'soft-coated wheaten terrier',
203: 'West Highland white terrier',
204: 'Lhasa, Lhasa apso',
205: 'flat-coated retriever',
206: 'curly-coated retriever',
207: 'golden retriever',
208: 'Labrador retriever',
209: 'Chesapeake Bay retriever',
210: 'German short-haired pointer',
211: 'vizsla, Hungarian pointer',
212: 'English setter',
213: 'Irish setter, red setter',
214: 'Gordon setter',
215: 'Brittany spaniel',
216: 'clumber, clumber spaniel',
217: 'English springer, English springer spaniel',
218: 'Welsh springer spaniel',
219: 'cocker spaniel, English cocker spaniel, cocker',
220: 'Sussex spaniel',
221: 'Irish water spaniel',
222: 'kuvasz',
223: 'schipperke',
224: 'groenendael',
225: 'malinois',
226: 'briard',
227: 'kelpie',
228: 'komondor',
229: 'Old English sheepdog, bobtail',
230: 'Shetland sheepdog, Shetland sheep dog, Shetland',
231: 'collie',
232: 'Border collie',
233: 'Bouvier des Flandres, Bouviers des Flandres',
234: 'Rottweiler',
235: 'German shepherd, German shepherd dog, German police dog, alsatian',
236: 'Doberman, Doberman pinscher',
237: 'miniature pinscher',
238: 'Greater Swiss Mountain dog',
239: 'Bernese mountain dog',
240: 'Appenzeller',
241: 'EntleBucher',
242: 'boxer',
243: 'bull mastiff',
244: 'Tibetan mastiff',
245: 'French bulldog',
246: 'Great Dane',
247: 'Saint Bernard, St Bernard',
248: 'Eskimo dog, husky',
249: 'malamute, malemute, Alaskan malamute',
250: 'Siberian husky',
251: 'dalmatian, coach dog, carriage dog',
252: 'affenpinscher, monkey pinscher, monkey dog',
253: 'basenji',
254: 'pug, pug-dog',
255: 'Leonberg',
256: 'Newfoundland, Newfoundland dog',
257: 'Great Pyrenees',
258: 'Samoyed, Samoyede',
259: 'Pomeranian',
260: 'chow, chow chow',
261: 'keeshond',
262: 'Brabancon griffon',
263: 'Pembroke, Pembroke Welsh corgi',
264: 'Cardigan, Cardigan Welsh corgi',
265: 'toy poodle',
266: 'miniature poodle',
267: 'standard poodle',
268: 'Mexican hairless',
269: 'timber wolf, grey wolf, gray wolf, Canis lupus',
270: 'white wolf, Arctic wolf, Canis lupus tundrarum',
271: 'red wolf, maned wolf, Canis rufus, Canis niger',
272: 'coyote, prairie wolf, brush wolf, Canis latrans',
273: 'dingo, warrigal, warragal, Canis dingo',
274: 'dhole, Cuon alpinus',
275: 'African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus',
276: 'hyena, hyaena',
277: 'red fox, Vulpes vulpes',
278: 'kit fox, Vulpes macrotis',
279: 'Arctic fox, white fox, Alopex lagopus',
280: 'grey fox, gray fox, Urocyon cinereoargenteus',
281: 'tabby, tabby cat',
282: 'tiger cat',
283: 'Persian cat',
284: 'Siamese cat, Siamese',
285: 'Egyptian cat',
286: 'cougar, puma, catamount, mountain lion, painter, panther, Felis concolor',
287: 'lynx, catamount',
288: 'leopard, Panthera pardus',
289: 'snow leopard, ounce, Panthera uncia',
290: 'jaguar, panther, Panthera onca, Felis onca',
291: 'lion, king of beasts, Panthera leo',
292: 'tiger, Panthera tigris',
293: 'cheetah, chetah, Acinonyx jubatus',
294: 'brown bear, bruin, Ursus arctos',
295: 'American black bear, black bear, Ursus americanus, Euarctos americanus',
296: 'ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus',
297: 'sloth bear, Melursus ursinus, Ursus ursinus',
298: 'mongoose',
299: 'meerkat, mierkat',
300: 'tiger beetle',
301: 'ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle',
302: 'ground beetle, carabid beetle',
303: 'long-horned beetle, longicorn, longicorn beetle',
304: 'leaf beetle, chrysomelid',
305: 'dung beetle',
306: 'rhinoceros beetle',
307: 'weevil',
308: 'fly',
309: 'bee',
310: 'ant, emmet, pismire',
311: 'grasshopper, hopper',
312: 'cricket',
313: 'walking stick, walkingstick, stick insect',
314: 'cockroach, roach',
315: 'mantis, mantid',
316: 'cicada, cicala',
317: 'leafhopper',
318: 'lacewing, lacewing fly',
319: "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
320: 'damselfly',
321: 'admiral',
322: 'ringlet, ringlet butterfly',
323: 'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus',
324: 'cabbage butterfly',
325: 'sulphur butterfly, sulfur butterfly',
326: 'lycaenid, lycaenid butterfly',
327: 'starfish, sea star',
328: 'sea urchin',
329: 'sea cucumber, holothurian',
330: 'wood rabbit, cottontail, cottontail rabbit',
331: 'hare',
332: 'Angora, Angora rabbit',
333: 'hamster',
334: 'porcupine, hedgehog',
335: 'fox squirrel, eastern fox squirrel, Sciurus niger',
336: 'marmot',
337: 'beaver',
338: 'guinea pig, Cavia cobaya',
339: 'sorrel',
340: 'zebra',
341: 'hog, pig, grunter, squealer, Sus scrofa',
342: 'wild boar, boar, Sus scrofa',
343: 'warthog',
344: 'hippopotamus, hippo, river horse, Hippopotamus amphibius',
345: 'ox',
346: 'water buffalo, water ox, Asiatic buffalo, Bubalus bubalis',
347: 'bison',
348: 'ram, tup',
349: 'bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis',
350: 'ibex, Capra ibex',
351: 'hartebeest',
352: 'impala, Aepyceros melampus',
353: 'gazelle',
354: 'Arabian camel, dromedary, Camelus dromedarius',
355: 'llama',
356: 'weasel',
357: 'mink',
358: 'polecat, fitch, foulmart, foumart, Mustela putorius',
359: 'black-footed ferret, ferret, Mustela nigripes',
360: 'otter',
361: 'skunk, polecat, wood pussy',
362: 'badger',
363: 'armadillo',
364: 'three-toed sloth, ai, Bradypus tridactylus',
365: 'orangutan, orang, orangutang, Pongo pygmaeus',
366: 'gorilla, Gorilla gorilla',
367: 'chimpanzee, chimp, Pan troglodytes',
368: 'gibbon, Hylobates lar',
369: 'siamang, Hylobates syndactylus, Symphalangus syndactylus',
370: 'guenon, guenon monkey',
371: 'patas, hussar monkey, Erythrocebus patas',
372: 'baboon',
373: 'macaque',
374: 'langur',
375: 'colobus, colobus monkey',
376: 'proboscis monkey, Nasalis larvatus',
377: 'marmoset',
378: 'capuchin, ringtail, Cebus capucinus',
379: 'howler monkey, howler',
380: 'titi, titi monkey',
381: 'spider monkey, Ateles geoffroyi',
382: 'squirrel monkey, Saimiri sciureus',
383: 'Madagascar cat, ring-tailed lemur, Lemur catta',
384: 'indri, indris, Indri indri, Indri brevicaudatus',
385: 'Indian elephant, Elephas maximus',
386: 'African elephant, Loxodonta africana',
387: 'lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens',
388: 'giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca',
389: 'barracouta, snoek',
390: 'eel',
391: 'coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch',
392: 'rock beauty, Holocanthus tricolor',
393: 'anemone fish',
394: 'sturgeon',
395: 'gar, garfish, garpike, billfish, Lepisosteus osseus',
396: 'lionfish',
397: 'puffer, pufferfish, blowfish, globefish',
398: 'abacus',
399: 'abaya',
400: "academic gown, academic robe, judge's robe",
401: 'accordion, piano accordion, squeeze box',
402: 'acoustic guitar',
403: 'aircraft carrier, carrier, flattop, attack aircraft carrier',
404: 'airliner',
405: 'airship, dirigible',
406: 'altar',
407: 'ambulance',
408: 'amphibian, amphibious vehicle',
409: 'analog clock',
410: 'apiary, bee house',
411: 'apron',
412: 'ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin',
413: 'assault rifle, assault gun',
414: 'backpack, back pack, knapsack, packsack, rucksack, haversack',
415: 'bakery, bakeshop, bakehouse',
416: 'balance beam, beam',
417: 'balloon',
418: 'ballpoint, ballpoint pen, ballpen, Biro',
419: 'Band Aid',
420: 'banjo',
421: 'bannister, banister, balustrade, balusters, handrail',
422: 'barbell',
423: 'barber chair',
424: 'barbershop',
425: 'barn',
426: 'barometer',
427: 'barrel, cask',
428: 'barrow, garden cart, lawn cart, wheelbarrow',
429: 'baseball',
430: 'basketball',
431: 'bassinet',
432: 'bassoon',
433: 'bathing cap, swimming cap',
434: 'bath towel',
435: 'bathtub, bathing tub, bath, tub',
436: 'beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon',
437: 'beacon, lighthouse, beacon light, pharos',
438: 'beaker',
439: 'bearskin, busby, shako',
440: 'beer bottle',
441: 'beer glass',
442: 'bell cote, bell cot',
443: 'bib',
444: 'bicycle-built-for-two, tandem bicycle, tandem',
445: 'bikini, two-piece',
446: 'binder, ring-binder',
447: 'binoculars, field glasses, opera glasses',
448: 'birdhouse',
449: 'boathouse',
450: 'bobsled, bobsleigh, bob',
451: 'bolo tie, bolo, bola tie, bola',
452: 'bonnet, poke bonnet',
453: 'bookcase',
454: 'bookshop, bookstore, bookstall',
455: 'bottlecap',
456: 'bow',
457: 'bow tie, bow-tie, bowtie',
458: 'brass, memorial tablet, plaque',
459: 'brassiere, bra, bandeau',
460: 'breakwater, groin, groyne, mole, bulwark, seawall, jetty',
461: 'breastplate, aegis, egis',
462: 'broom',
463: 'bucket, pail',
464: 'buckle',
465: 'bulletproof vest',
466: 'bullet train, bullet',
467: 'butcher shop, meat market',
468: 'cab, hack, taxi, taxicab',
469: 'caldron, cauldron',
470: 'candle, taper, wax light',
471: 'cannon',
472: 'canoe',
473: 'can opener, tin opener',
474: 'cardigan',
475: 'car mirror',
476: 'carousel, carrousel, merry-go-round, roundabout, whirligig',
477: "carpenter's kit, tool kit",
478: 'carton',
479: 'car wheel',
480: 'cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM',
481: 'cassette',
482: 'cassette player',
483: 'castle',
484: 'catamaran',
485: 'CD player',
486: 'cello, violoncello',
487: 'cellular telephone, cellular phone, cellphone, cell, mobile phone',
488: 'chain',
489: 'chainlink fence',
490: 'chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour',
491: 'chain saw, chainsaw',
492: 'chest',
493: 'chiffonier, commode',
494: 'chime, bell, gong',
495: 'china cabinet, china closet',
496: 'Christmas stocking',
497: 'church, church building',
498: 'cinema, movie theater, movie theatre, movie house, picture palace',
499: 'cleaver, meat cleaver, chopper',
500: 'cliff dwelling',
501: 'cloak',
502: 'clog, geta, patten, sabot',
503: 'cocktail shaker',
504: 'coffee mug',
505: 'coffeepot',
506: 'coil, spiral, volute, whorl, helix',
507: 'combination lock',
508: 'computer keyboard, keypad',
509: 'confectionery, confectionary, candy store',
510: 'container ship, containership, container vessel',
511: 'convertible',
512: 'corkscrew, bottle screw',
513: 'cornet, horn, trumpet, trump',
514: 'cowboy boot',
515: 'cowboy hat, ten-gallon hat',
516: 'cradle',
517: 'crane',
518: 'crash helmet',
519: 'crate',
520: 'crib, cot',
521: 'Crock Pot',
522: 'croquet ball',
523: 'crutch',
524: 'cuirass',
525: 'dam, dike, dyke',
526: 'desk',
527: 'desktop computer',
528: 'dial telephone, dial phone',
529: 'diaper, nappy, napkin',
530: 'digital clock',
531: 'digital watch',
532: 'dining table, board',
533: 'dishrag, dishcloth',
534: 'dishwasher, dish washer, dishwashing machine',
535: 'disk brake, disc brake',
536: 'dock, dockage, docking facility',
537: 'dogsled, dog sled, dog sleigh',
538: 'dome',
539: 'doormat, welcome mat',
540: 'drilling platform, offshore rig',
541: 'drum, membranophone, tympan',
542: 'drumstick',
543: 'dumbbell',
544: 'Dutch oven',
545: 'electric fan, blower',
546: 'electric guitar',
547: 'electric locomotive',
548: 'entertainment center',
549: 'envelope',
550: 'espresso maker',
551: 'face powder',
552: 'feather boa, boa',
553: 'file, file cabinet, filing cabinet',
554: 'fireboat',
555: 'fire engine, fire truck',
556: 'fire screen, fireguard',
557: 'flagpole, flagstaff',
558: 'flute, transverse flute',
559: 'folding chair',
560: 'football helmet',
561: 'forklift',
562: 'fountain',
563: 'fountain pen',
564: 'four-poster',
565: 'freight car',
566: 'French horn, horn',
567: 'frying pan, frypan, skillet',
568: 'fur coat',
569: 'garbage truck, dustcart',
570: 'gasmask, respirator, gas helmet',
571: 'gas pump, gasoline pump, petrol pump, island dispenser',
572: 'goblet',
573: 'go-kart',
574: 'golf ball',
575: 'golfcart, golf cart',
576: 'gondola',
577: 'gong, tam-tam',
578: 'gown',
579: 'grand piano, grand',
580: 'greenhouse, nursery, glasshouse',
581: 'grille, radiator grille',
582: 'grocery store, grocery, food market, market',
583: 'guillotine',
584: 'hair slide',
585: 'hair spray',
586: 'half track',
587: 'hammer',
588: 'hamper',
589: 'hand blower, blow dryer, blow drier, hair dryer, hair drier',
590: 'hand-held computer, hand-held microcomputer',
591: 'handkerchief, hankie, hanky, hankey',
592: 'hard disc, hard disk, fixed disk',
593: 'harmonica, mouth organ, harp, mouth harp',
594: 'harp',
595: 'harvester, reaper',
596: 'hatchet',
597: 'holster',
598: 'home theater, home theatre',
599: 'honeycomb',
600: 'hook, claw',
601: 'hoopskirt, crinoline',
602: 'horizontal bar, high bar',
603: 'horse cart, horse-cart',
604: 'hourglass',
605: 'iPod',
606: 'iron, smoothing iron',
607: "jack-o'-lantern",
608: 'jean, blue jean, denim',
609: 'jeep, landrover',
610: 'jersey, T-shirt, tee shirt',
611: 'jigsaw puzzle',
612: 'jinrikisha, ricksha, rickshaw',
613: 'joystick',
614: 'kimono',
615: 'knee pad',
616: 'knot',
617: 'lab coat, laboratory coat',
618: 'ladle',
619: 'lampshade, lamp shade',
620: 'laptop, laptop computer',
621: 'lawn mower, mower',
622: 'lens cap, lens cover',
623: 'letter opener, paper knife, paperknife',
624: 'library',
625: 'lifeboat',
626: 'lighter, light, igniter, ignitor',
627: 'limousine, limo',
628: 'liner, ocean liner',
629: 'lipstick, lip rouge',
630: 'Loafer',
631: 'lotion',
632: 'loudspeaker, speaker, speaker unit, loudspeaker system, speaker system',
633: "loupe, jeweler's loupe",
634: 'lumbermill, sawmill',
635: 'magnetic compass',
636: 'mailbag, postbag',
637: 'mailbox, letter box',
638: 'maillot',
639: 'maillot, tank suit',
640: 'manhole cover',
641: 'maraca',
642: 'marimba, xylophone',
643: 'mask',
644: 'matchstick',
645: 'maypole',
646: 'maze, labyrinth',
647: 'measuring cup',
648: 'medicine chest, medicine cabinet',
649: 'megalith, megalithic structure',
650: 'microphone, mike',
651: 'microwave, microwave oven',
652: 'military uniform',
653: 'milk can',
654: 'minibus',
655: 'miniskirt, mini',
656: 'minivan',
657: 'missile',
658: 'mitten',
659: 'mixing bowl',
660: 'mobile home, manufactured home',
661: 'Model T',
662: 'modem',
663: 'monastery',
664: 'monitor',
665: 'moped',
666: 'mortar',
667: 'mortarboard',
668: 'mosque',
669: 'mosquito net',
670: 'motor scooter, scooter',
671: 'mountain bike, all-terrain bike, off-roader',
672: 'mountain tent',
673: 'mouse, computer mouse',
674: 'mousetrap',
675: 'moving van',
676: 'muzzle',
677: 'nail',
678: 'neck brace',
679: 'necklace',
680: 'nipple',
681: 'notebook, notebook computer',
682: 'obelisk',
683: 'oboe, hautboy, hautbois',
684: 'ocarina, sweet potato',
685: 'odometer, hodometer, mileometer, milometer',
686: 'oil filter',
687: 'organ, pipe organ',
688: 'oscilloscope, scope, cathode-ray oscilloscope, CRO',
689: 'overskirt',
690: 'oxcart',
691: 'oxygen mask',
692: 'packet',
693: 'paddle, boat paddle',
694: 'paddlewheel, paddle wheel',
695: 'padlock',
696: 'paintbrush',
697: "pajama, pyjama, pj's, jammies",
698: 'palace',
699: 'panpipe, pandean pipe, syrinx',
700: 'paper towel',
701: 'parachute, chute',
702: 'parallel bars, bars',
703: 'park bench',
704: 'parking meter',
705: 'passenger car, coach, carriage',
706: 'patio, terrace',
707: 'pay-phone, pay-station',
708: 'pedestal, plinth, footstall',
709: 'pencil box, pencil case',
710: 'pencil sharpener',
711: 'perfume, essence',
712: 'Petri dish',
713: 'photocopier',
714: 'pick, plectrum, plectron',
715: 'pickelhaube',
716: 'picket fence, paling',
717: 'pickup, pickup truck',
718: 'pier',
719: 'piggy bank, penny bank',
720: 'pill bottle',
721: 'pillow',
722: 'ping-pong ball',
723: 'pinwheel',
724: 'pirate, pirate ship',
725: 'pitcher, ewer',
726: "plane, carpenter's plane, woodworking plane",
727: 'planetarium',
728: 'plastic bag',
729: 'plate rack',
730: 'plow, plough',
731: "plunger, plumber's helper",
732: 'Polaroid camera, Polaroid Land camera',
733: 'pole',
734: 'police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria',
735: 'poncho',
736: 'pool table, billiard table, snooker table',
737: 'pop bottle, soda bottle',
738: 'pot, flowerpot',
739: "potter's wheel",
740: 'power drill',
741: 'prayer rug, prayer mat',
742: 'printer',
743: 'prison, prison house',
744: 'projectile, missile',
745: 'projector',
746: 'puck, hockey puck',
747: 'punching bag, punch bag, punching ball, punchball',
748: 'purse',
749: 'quill, quill pen',
750: 'quilt, comforter, comfort, puff',
751: 'racer, race car, racing car',
752: 'racket, racquet',
753: 'radiator',
754: 'radio, wireless',
755: 'radio telescope, radio reflector',
756: 'rain barrel',
757: 'recreational vehicle, RV, R.V.',
758: 'reel',
759: 'reflex camera',
760: 'refrigerator, icebox',
761: 'remote control, remote',
762: 'restaurant, eating house, eating place, eatery',
763: 'revolver, six-gun, six-shooter',
764: 'rifle',
765: 'rocking chair, rocker',
766: 'rotisserie',
767: 'rubber eraser, rubber, pencil eraser',
768: 'rugby ball',
769: 'rule, ruler',
770: 'running shoe',
771: 'safe',
772: 'safety pin',
773: 'saltshaker, salt shaker',
774: 'sandal',
775: 'sarong',
776: 'sax, saxophone',
777: 'scabbard',
778: 'scale, weighing machine',
779: 'school bus',
780: 'schooner',
781: 'scoreboard',
782: 'screen, CRT screen',
783: 'screw',
784: 'screwdriver',
785: 'seat belt, seatbelt',
786: 'sewing machine',
787: 'shield, buckler',
788: 'shoe shop, shoe-shop, shoe store',
789: 'shoji',
790: 'shopping basket',
791: 'shopping cart',
792: 'shovel',
793: 'shower cap',
794: 'shower curtain',
795: 'ski',
796: 'ski mask',
797: 'sleeping bag',
798: 'slide rule, slipstick',
799: 'sliding door',
800: 'slot, one-armed bandit',
801: 'snorkel',
802: 'snowmobile',
803: 'snowplow, snowplough',
804: 'soap dispenser',
805: 'soccer ball',
806: 'sock',
807: 'solar dish, solar collector, solar furnace',
808: 'sombrero',
809: 'soup bowl',
810: 'space bar',
811: 'space heater',
812: 'space shuttle',
813: 'spatula',
814: 'speedboat',
815: "spider web, spider's web",
816: 'spindle',
817: 'sports car, sport car',
818: 'spotlight, spot',
819: 'stage',
820: 'steam locomotive',
821: 'steel arch bridge',
822: 'steel drum',
823: 'stethoscope',
824: 'stole',
825: 'stone wall',
826: 'stopwatch, stop watch',
827: 'stove',
828: 'strainer',
829: 'streetcar, tram, tramcar, trolley, trolley car',
830: 'stretcher',
831: 'studio couch, day bed',
832: 'stupa, tope',
833: 'submarine, pigboat, sub, U-boat',
834: 'suit, suit of clothes',
835: 'sundial',
836: 'sunglass',
837: 'sunglasses, dark glasses, shades',
838: 'sunscreen, sunblock, sun blocker',
839: 'suspension bridge',
840: 'swab, swob, mop',
841: 'sweatshirt',
842: 'swimming trunks, bathing trunks',
843: 'swing',
844: 'switch, electric switch, electrical switch',
845: 'syringe',
846: 'table lamp',
847: 'tank, army tank, armored combat vehicle, armoured combat vehicle',
848: 'tape player',
849: 'teapot',
850: 'teddy, teddy bear',
851: 'television, television system',
852: 'tennis ball',
853: 'thatch, thatched roof',
854: 'theater curtain, theatre curtain',
855: 'thimble',
856: 'thresher, thrasher, threshing machine',
857: 'throne',
858: 'tile roof',
859: 'toaster',
860: 'tobacco shop, tobacconist shop, tobacconist',
861: 'toilet seat',
862: 'torch',
863: 'totem pole',
864: 'tow truck, tow car, wrecker',
865: 'toyshop',
866: 'tractor',
867: 'trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi',
868: 'tray',
869: 'trench coat',
870: 'tricycle, trike, velocipede',
871: 'trimaran',
872: 'tripod',
873: 'triumphal arch',
874: 'trolleybus, trolley coach, trackless trolley',
875: 'trombone',
876: 'tub, vat',
877: 'turnstile',
878: 'typewriter keyboard',
879: 'umbrella',
880: 'unicycle, monocycle',
881: 'upright, upright piano',
882: 'vacuum, vacuum cleaner',
883: 'vase',
884: 'vault',
885: 'velvet',
886: 'vending machine',
887: 'vestment',
888: 'viaduct',
889: 'violin, fiddle',
890: 'volleyball',
891: 'waffle iron',
892: 'wall clock',
893: 'wallet, billfold, notecase, pocketbook',
894: 'wardrobe, closet, press',
895: 'warplane, military plane',
896: 'washbasin, handbasin, washbowl, lavabo, wash-hand basin',
897: 'washer, automatic washer, washing machine',
898: 'water bottle',
899: 'water jug',
900: 'water tower',
901: 'whiskey jug',
902: 'whistle',
903: 'wig',
904: 'window screen',
905: 'window shade',
906: 'Windsor tie',
907: 'wine bottle',
908: 'wing',
909: 'wok',
910: 'wooden spoon',
911: 'wool, woolen, woollen',
912: 'worm fence, snake fence, snake-rail fence, Virginia fence',
913: 'wreck',
914: 'yawl',
915: 'yurt',
916: 'web site, website, internet site, site',
917: 'comic book',
918: 'crossword puzzle, crossword',
919: 'street sign',
920: 'traffic light, traffic signal, stoplight',
921: 'book jacket, dust cover, dust jacket, dust wrapper',
922: 'menu',
923: 'plate',
924: 'guacamole',
925: 'consomme',
926: 'hot pot, hotpot',
927: 'trifle',
928: 'ice cream, icecream',
929: 'ice lolly, lolly, lollipop, popsicle',
930: 'French loaf',
931: 'bagel, beigel',
932: 'pretzel',
933: 'cheeseburger',
934: 'hotdog, hot dog, red hot',
935: 'mashed potato',
936: 'head cabbage',
937: 'broccoli',
938: 'cauliflower',
939: 'zucchini, courgette',
940: 'spaghetti squash',
941: 'acorn squash',
942: 'butternut squash',
943: 'cucumber, cuke',
944: 'artichoke, globe artichoke',
945: 'bell pepper',
946: 'cardoon',
947: 'mushroom',
948: 'Granny Smith',
949: 'strawberry',
950: 'orange',
951: 'lemon',
952: 'fig',
953: 'pineapple, ananas',
954: 'banana',
955: 'jackfruit, jak, jack',
956: 'custard apple',
957: 'pomegranate',
958: 'hay',
959: 'carbonara',
960: 'chocolate sauce, chocolate syrup',
961: 'dough',
962: 'meat loaf, meatloaf',
963: 'pizza, pizza pie',
964: 'potpie',
965: 'burrito',
966: 'red wine',
967: 'espresso',
968: 'cup',
969: 'eggnog',
970: 'alp',
971: 'bubble',
972: 'cliff, drop, drop-off',
973: 'coral reef',
974: 'geyser',
975: 'lakeside, lakeshore',
976: 'promontory, headland, head, foreland',
977: 'sandbar, sand bar',
978: 'seashore, coast, seacoast, sea-coast',
979: 'valley, vale',
980: 'volcano',
981: 'ballplayer, baseball player',
982: 'groom, bridegroom',
983: 'scuba diver',
984: 'rapeseed',
985: 'daisy',
986: "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
987: 'corn',
988: 'acorn',
989: 'hip, rose hip, rosehip',
990: 'buckeye, horse chestnut, conker',
991: 'coral fungus',
992: 'agaric',
993: 'gyromitra',
994: 'stinkhorn, carrion fungus',
995: 'earthstar',
996: 'hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa',
997: 'bolete',
998: 'ear, spike, capitulum',
999: 'toilet tissue, toilet paper, bathroom tissue'
}
# Load a small ImageNet test subset: inputs, labels and segmentation masks.
nr_samples = 15
assets = np.load("assets/imagenet_test_set.npy", allow_pickle=True).item()
x_batch = assets["x_batch"][:nr_samples]
y_batch = assets["y_batch"][:nr_samples]
s_batch = assets["s_batch"].reshape(-1, 1, 224, 224)[:nr_samples]
# Plot some inputs!
nr_images = 5
fig, axes = plt.subplots(nrows=1, ncols=nr_images, figsize=(nr_images*3, int(nr_images*2)))
for i in range(nr_images):
    image = (np.moveaxis(quantus.denormalise(x_batch[i], mean=np.array([0.485, 0.456, 0.406]),
                                             std=np.array([0.229, 0.224, 0.225])), 0, -1) * 255).astype(np.uint8)
    axes[i].imshow(image, vmin=0.0, vmax=1.0, cmap="gray")
    axes[i].title.set_text(f"{CLASSES[y_batch[i]][:15]}")
    axes[i].axis("off")
plt.show()
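The display conversion above does three things: it denormalises the input, moves the channels last and scales the result to `uint8`. The sketch below reproduces these steps in plain NumPy. It assumes that `quantus.denormalise` inverts the usual per-channel standardisation, i.e. computes `x * std + mean` for a channel-first image. The helper names are hypothetical. The clipping step is an addition of this sketch: it keeps values just outside $[0, 1]$ from wrapping around when cast to `uint8`.

```python
import numpy as np

# ImageNet per-channel statistics, as passed to quantus.denormalise above.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])


def denormalise_chw(x: np.ndarray,
                    mean: np.ndarray = IMAGENET_MEAN,
                    std: np.ndarray = IMAGENET_STD) -> np.ndarray:
    """Undo per-channel standardisation of a (C, H, W) image: x * std + mean."""
    return x * std[:, None, None] + mean[:, None, None]


def to_displayable_uint8(x: np.ndarray) -> np.ndarray:
    """Denormalise, move channels last and convert to uint8 for plt.imshow."""
    image = np.moveaxis(denormalise_chw(x), 0, -1)
    # Clip to [0, 1] so out-of-range values do not wrap around when cast.
    return (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)


# Round trip on a synthetic image: normalise, then denormalise.
rng = np.random.default_rng(0)
original = rng.random((3, 224, 224))
normalised = (original - IMAGENET_MEAN[:, None, None]) / IMAGENET_STD[:, None, None]
restored = denormalise_chw(normalised)
print(np.allclose(restored, original))         # True
print(to_displayable_uint8(normalised).shape)  # (224, 224, 3)
```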
Since the focus of this tutorial is not on model training but on XAI evaluation, we load pre-trained models from the torchvision library.
# Pick your model!
models = torchvision.models.list_models(module=torchvision.models)
models
['alexnet', 'convnext_base', 'convnext_large', 'convnext_small', 'convnext_tiny', 'densenet121', 'densenet161', 'densenet169', 'densenet201', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3', 'efficientnet_b4', 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7', 'efficientnet_v2_l', 'efficientnet_v2_m', 'efficientnet_v2_s', 'googlenet', 'inception_v3', 'maxvit_t', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3', 'mobilenet_v2', 'mobilenet_v3_large', 'mobilenet_v3_small', 'regnet_x_16gf', 'regnet_x_1_6gf', 'regnet_x_32gf', 'regnet_x_3_2gf', 'regnet_x_400mf', 'regnet_x_800mf', 'regnet_x_8gf', 'regnet_y_128gf', 'regnet_y_16gf', 'regnet_y_1_6gf', 'regnet_y_32gf', 'regnet_y_3_2gf', 'regnet_y_400mf', 'regnet_y_800mf', 'regnet_y_8gf', 'resnet101', 'resnet152', 'resnet18', 'resnet34', 'resnet50', 'resnext101_32x8d', 'resnext101_64x4d', 'resnext50_32x4d', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0', 'squeezenet1_0', 'squeezenet1_1', 'swin_b', 'swin_s', 'swin_t', 'swin_v2_b', 'swin_v2_s', 'swin_v2_t', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn', 'vit_b_16', 'vit_b_32', 'vit_h_14', 'vit_l_16', 'vit_l_32', 'wide_resnet101_2', 'wide_resnet50_2']
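# Load pre-trained model of choice. `pretrained=True` is deprecated in recent torchvision versions; pass a weights enum instead.
model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)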
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth 100%|██████████| 44.7M/44.7M [00:00<00:00, 111MB/s]
To gain more insight into how the model made its prediction, we can apply different explanation methods. There are multiple ways to generate explanations for neural network models, e.g., using the captum, zennit and tf-explain libraries.
Quantus is compatible with PyTorch and offers 20+ XAI methods for that framework.
import quantus
# View the XAI methods available for PyTorch users.
quantus.AVAILABLE_XAI_METHODS_CAPTUM
['GradientShap', 'IntegratedGradients', 'DeepLift', 'DeepLiftShap', 'InputXGradient', 'Saliency', 'FeatureAblation', 'Deconvolution', 'FeaturePermutation', 'Lime', 'KernelShap', 'LRP', 'Gradient', 'Occlusion', 'LayerGradCam', 'GuidedGradCam', 'LayerConductance', 'LayerActivation', 'InternalInfluence', 'LayerGradientXActivation', 'Control Var. Sobel Filter', 'Control Var. Constant', 'Control Var. Random Uniform']
In this example, we rely on the quantus.explain functionality (a simple wrapper around captum); however, you can use whatever approach or library you like to create your explanations.
#@title 3.3.1 Plotting single method
%%capture
def plot_explanation(input_sample: np.ndarray,
                     pred_name: str,
                     explanation: np.ndarray,
                     img_size: int = 28,
                     normalise: bool = True,
                     denormalise: bool = False,
                     method: str = "Gradient"):
    """Plot an explanation for an input sample."""
    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(8, 5))
    if denormalise:
        # ImageNet settings.
        image = (np.moveaxis(quantus.denormalise(input_sample,
                                                 mean=np.array([0.485, 0.456, 0.406]),
                                                 std=np.array([0.229, 0.224, 0.225])), 0, -1) * 255).astype(np.uint8)
    else:
        image = np.moveaxis(input_sample, 0, 2)
    if normalise:
        explanation = quantus.normalise_by_max(explanation)
    axes[0].imshow(image, vmin=0.0, vmax=1.0)
    axes[0].title.set_text(f"Class {pred_name}")
    axes[0].axis("off")
    attr_ = axes[1].imshow(explanation.reshape(img_size, img_size), cmap="seismic")
    fig.colorbar(attr_, fraction=0.05, pad=0.05)
    axes[1].title.set_text(method)
    axes[1].axis("off")
    plt.show()

def get_pred_name(y_pred_id: int, dataset: str = "imagenet"):
    if dataset == "imagenet":
        return CLASSES[y_pred_id]
    return medmnist.INFO[DATA_FLAG]['label'][str(y_pred_id)]
# Get base explanations.
a_grad = quantus.explain(model, x_batch, y_batch)
# Plot a random explanation sample!
index = np.random.randint(0, len(x_batch))  # The upper bound is exclusive, so every sample can be drawn.
input_sample = x_batch[index]
pred_name = get_pred_name(y_batch[index])
explanation = a_grad[index]
plot_explanation(input_sample, pred_name, explanation, img_size=224, denormalise=True, normalise=False)
# Inspect documentation for the explanation method.
quantus.explain
quantus.functions.explanation_func.explain
def explain(model, inputs, targets, **kwargs) -> np.ndarray
Explain inputs given a model, targets and an explanation method. Expecting inputs to be shaped such as (batch_size, nr_channels, ...) or (batch_size, ..., nr_channels). Parameters ---------- model: torch.nn.Module, tf.keras.Model A model that is used for explanation. inputs: np.ndarray The inputs that ought to be explained. targets: np.ndarray The target lables that should be used in the explanation. kwargs: optional Keyword arguments. Pass as "explain_func_kwargs" dictionary when working with a metric class. Pass as regular kwargs when using the stnad-alone function. xai_lib: string, optional XAI library: captum, tf-explain or zennit. method: string, optional XAI method (used with captum and tf-explain libraries). attributor: string, optional XAI method (used with zennit). xai_lib_kwargs: dictionary, optional Keyword arguments to be passed to the attribution function. softmax: boolean, optional Indicated whether softmax activation in the last layer shall be removed. channel_first: boolean, optional Indicates if the image dimensions are channel first, or channel last. Inferred from the input shape if None. reduce_axes: tuple Indicates the indices of dimensions of the output explanation array to be summed. For example, an input array of shape (8, 28, 28, 3) with keepdims=True and reduce_axes = (-1,) will return an array of shape (8, 28, 28, -1). Passing "()" will keep the original dimensions. keepdims: boolean Indicated if the reduced axes shall be preserved (True) or removed (False). Returns ------- explanation: np.ndarray Returns np.ndarray of same shape as inputs.
# Get Saliency explanations, specify via 'method'.
kwargs = {"method": "Saliency", "xai_lib": "captum"}
a_sal = quantus.explain(model, x_batch, y_batch, **kwargs)
a_sal.shape
(15, 1, 224, 224)
# Prepare dictionary with explanation methods and hyperparameters.
xai_methods_with_kwargs = {
    #"Occlusion": {"window": (1, 28, 28)},
    "LayerGradCam": {"gc_layer": "list(model.named_modules())[61][1]", "interpolate": (224, 224)},
    "Saliency": {},
    "GradientShap": {},
    "IntegratedGradients": {"n_steps": 5},
}
NORMALISE = False
# Populate the dictionary with explanations.
explanations = {}
for method, kwargs in xai_methods_with_kwargs.items():
    a_batch = quantus.explain(model=model,
                              inputs=x_batch,
                              targets=y_batch,
                              **{**{"method": method, "xai_lib": "captum"}, **kwargs})
    if NORMALISE:
        # Normalise each sample over its (channel, height, width) axes.
        if a_batch.min() == 0:
            # Non-negative attributions (e.g., GradCAM): divide by the per-sample maximum.
            explanations[method] = a_batch / a_batch.max(axis=(1, 2, 3), keepdims=True)
        else:
            # Otherwise, min-max scale by hand to comparable values in [0, 1].
            a_min = a_batch.min(axis=(1, 2, 3), keepdims=True)
            a_max = a_batch.max(axis=(1, 2, 3), keepdims=True)
            explanations[method] = (a_batch - a_min) / (a_max - a_min)
    else:
        explanations[method] = a_batch
    print(f"{method} - {a_batch.shape}")
LayerGradCam - (15, 1, 224, 224)
Saliency - (15, 1, 224, 224)
GradientShap - (15, 1, 224, 224)
IntegratedGradients - (15, 1, 224, 224)
# Or define your own XAI method.
def your_own_random_explainer(model: torch.nn.Module,
                              inputs: np.ndarray,
                              targets: np.ndarray,
                              **kwargs) -> np.ndarray:
    """Dummy explanation: uniform random attributions of shape (batch, 1, height, width)."""
    size = kwargs.get("size", (len(inputs), 1) + tuple(np.shape(inputs)[2:]))
    return np.random.random(size=size)
#@title 3.3.2 Write own FusionGrad explainer
%%capture
import copy
import gc
import numpy as np
import torch
import quantus
from captum.attr import Saliency

def save_model_state(model):
    """Return a cloned copy of the model's state dict."""
    return {k: v.clone() for k, v in model.state_dict().items()}

def restore_model_state(model, state):
    model.load_state_dict(state)

def fusiongrad_explainer(model, inputs, targets, **kwargs) -> np.ndarray:
    """Implementation of FusionGrad by Bykov et al., 2022."""
    original_state = save_model_state(model)

    # PyTorch and general processing.
    device = kwargs.get("device", "cpu")
    img_size = kwargs.get("img_size", 224)
    nr_channels = kwargs.get("nr_channels", 3)

    # Post-processing of the attribution.
    take_abs = kwargs.get("abs", False)
    normalise = kwargs.get("normalise", False)
    normalise_func = kwargs.get("normalise_func", quantus.normalise_by_negative)

    # FusionGrad specific: weight noise ~ N(mean, std), input noise ~ N(sg_mean, sg_std).
    # Note: posterior_mean is read here but not used below.
    posterior_mean = kwargs.get("posterior_mean", copy.deepcopy(model.to(device).state_dict()))
    mean, std = kwargs.get("mean, std", (1.0, 0.75))
    sg_mean, sg_std = kwargs.get("sg_mean, sg_std", (0.0, 0.25))
    # m sampled models, n noisy inputs per sampled model.
    m, n = kwargs.get("m, n", (10, 10))
    original_parameters = model.state_dict()

    def _sample(model, original_parameters, std, noise_type="multiplicative"):
        """Return a copy of the model whose parameters are perturbed with Gaussian noise."""
        model_copy = copy.deepcopy(model)
        model_copy.load_state_dict(original_parameters)
        # Only perturb if std is non-zero (torch's Normal requires scale > 0).
        if std != 0.0:
            distribution = torch.distributions.normal.Normal(loc=torch.as_tensor(mean, dtype=torch.float),
                                                             scale=torch.as_tensor(std, dtype=torch.float))
            with torch.no_grad():
                for layer in model_copy.parameters():
                    if noise_type == "additive":
                        layer.add_(distribution.sample(layer.size()).to(layer.device))
                    elif noise_type == "multiplicative":
                        layer.mul_(distribution.sample(layer.size()).to(layer.device))
                    else:
                        print("Set NoiseGrad attribute 'noise_type' to either 'additive' or 'multiplicative' (str).")
        return model_copy

    # Set model in evaluation mode.
    model.to(device)
    model.eval()

    if not isinstance(inputs, torch.Tensor):
        inputs = torch.Tensor(inputs).reshape(-1, nr_channels, img_size, img_size).to(device)
    if not isinstance(targets, torch.Tensor):
        targets = torch.as_tensor(targets).long().to(device)
    assert len(np.shape(inputs)) == 4, \
        "Inputs should be shaped (nr_samples, nr_channels, img_size, img_size) e.g., (1, 3, 224, 224)."

    # One (batch, img_size, img_size) attribution per (model sample, input sample) pair.
    attrs = torch.zeros((m, n, inputs.shape[0], img_size, img_size))
    for i in range(m):
        # Sample a model.
        model_copy = _sample(model=model, original_parameters=original_parameters, std=std, noise_type="multiplicative")
        for j in range(n):
            # Add noise to the inputs.
            inputs_noisy = inputs + torch.randn_like(inputs) * sg_std + sg_mean
            # Compute the Saliency explanation and sum over channels.
            attrs[i, j] = Saliency(model_copy).attribute(inputs_noisy, targets, abs=take_abs).sum(axis=1).detach().cpu()

    # Average over model and input samples, then restore the channel axis: (batch, 1, img_size, img_size).
    attrs = torch.unsqueeze(attrs.mean(axis=(0, 1)), 1)
    gc.collect()
    torch.cuda.empty_cache()

    # Add normalisation.
    if normalise:
        attrs = normalise_func(attrs)

    # Restore the original model parameters before returning.
    restore_model_state(model, original_state)

    if isinstance(attrs, torch.Tensor):
        return attrs.detach().cpu().numpy()
    return attrs
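Stripped of the PyTorch plumbing, the explainer above is a double Monte Carlo average. Gradients are averaged over m models with multiplicatively perturbed weights and, for each model, over n noisy copies of the input. The sketch below shows that averaging in NumPy. It assumes a hypothetical toy model $f(x) = \sum_k w_k x_k^2$, whose input gradient $2\, w \odot x$ is known in closed form, so no autograd or captum is needed. It returns raw gradients, corresponding to `abs=False`.

```python
import numpy as np


def toy_input_gradient(w: np.ndarray, x: np.ndarray) -> np.ndarray:
    """Input gradient of the hypothetical toy model f(x) = sum(w * x**2)."""
    return 2.0 * w * x


def fusiongrad_sketch(w, x, m=10, n=10, mean=1.0, std=0.75,
                      sg_mean=0.0, sg_std=0.25, seed=0):
    """Average input gradients over m weight samples and n noisy inputs per sample."""
    rng = np.random.default_rng(seed)
    grads = np.zeros((m, n) + x.shape)
    for i in range(m):
        # Multiplicative Gaussian noise on the weights, N(mean, std).
        w_sample = w * rng.normal(mean, std, size=w.shape)
        for j in range(n):
            # Additive Gaussian noise on the input, N(sg_mean, sg_std).
            x_noisy = x + rng.normal(sg_mean, sg_std, size=x.shape)
            grads[i, j] = toy_input_gradient(w_sample, x_noisy)
    # Mean over both noise sources, like attrs.mean(axis=(0, 1)) above.
    return grads.mean(axis=(0, 1))


w = np.array([0.5, -1.0, 2.0])
x = np.array([1.0, 2.0, -0.5])
# Without any noise, the estimate reduces to the plain gradient 2 * w * x.
print(fusiongrad_sketch(w, x, std=0.0, sg_std=0.0).tolist())  # [1.0, -4.0, -2.0]
```

With the default noise levels, the output is a smoothed estimate of the same gradient whose variance shrinks as m and n grow.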
# Compare with your own explainer function.
explanations["FusionGrad"] = fusiongrad_explainer(model=model.cpu(), inputs=x_batch, targets=y_batch, **{"m, n": (10, 10),
"sg_mean, sg_std": (0.0, 0.25),
"mean, std": (1.0, 0.8),
"posterior_mean": copy.deepcopy(model.to(device).state_dict())})
explanations["Random"] = your_own_random_explainer(model=model, inputs=x_batch, targets=y_batch, **{"size": (len(x_batch), 1, 224, 224)})
#@title 3.3.3 Plotting multiple methods
%%capture
import warnings

def plot_explanation_methods(explanations: dict,
                             x_batch: np.ndarray,
                             y_batch: np.ndarray,
                             method_names: list,
                             colours: list,
                             indices: list = [1, 10, 6],
                             img_size: int = 28) -> None:
    # Plotting configs.
    plt.rcParams['ytick.left'] = False
    plt.rcParams['ytick.labelleft'] = False
    plt.rcParams['xtick.bottom'] = False
    plt.rcParams['xtick.labelbottom'] = False

    # Plot explanations: one figure per sample index, input first, then one column per method.
    ncols = 1 + len(explanations)
    for index in indices:
        fig, axes = plt.subplots(nrows=1, ncols=ncols, figsize=(15, int(ncols)*3))
        for i in range(ncols):
            if i == 0:
                # Wrap long class names over several lines.
                pred_name = get_pred_name(y_pred_id=y_batch[index])
                if len(pred_name) > 30:
                    pred_name = pred_name[:15] + "\n" + pred_name[15:30] + "\n" + pred_name[30:]
                elif len(pred_name) > 15:
                    pred_name = pred_name[:15] + "\n" + pred_name[15:]
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    image = (np.moveaxis(quantus.denormalise(x_batch[index],
                                                             mean=np.array([0.485, 0.456, 0.406]),
                                                             std=np.array([0.229, 0.224, 0.225])), 0, -1) * 255).astype(np.uint8)
                axes[0].imshow(image, vmin=0.0, vmax=1.0)
                axes[0].set_title(f"{pred_name.title()}", fontsize=12)
                axes[0].axis("off")
            else:
                axes[i].imshow(explanations[method_names[i-1]][index].reshape(img_size, img_size), cmap="seismic", vmin=-1.0, vmax=1.0)
                axes[i].set_title(f"{method_names[i-1]}", fontsize=12)
                # Frame configs: a coloured border per method.
                axes[i].xaxis.set_visible(False)
                axes[i].yaxis.set_visible(False)
                for side in ["top", "bottom", "left", "right"]:
                    axes[i].spines[side].set_color(colours[i-1])
                    axes[i].spines[side].set_linewidth(5)
        plt.show()
# Plot explanation methods!
method_names = list(explanations.keys())
# One random hex colour per explanation method.
colours = ['#' + ''.join(np.random.choice(list('0123456789ABCDEF'), size=6)) for _ in range(len(explanations))]
plot_explanation_methods(explanations, x_batch, y_batch, method_names, colours, img_size=224)
Clearly, comparing explanations visually is hard: from the plots alone, it is not possible to tell which method explains the model best.
Quantus is an open-source Python framework for evaluating the quality of neural network explanations.
More details can be found in the official GitHub repository, the Getting Started Guide or the API documentation.
import quantus
We answer the following research question:
In the following section, we demonstrate how to use Quantus to evaluate the different explanation methods with respect to various explanation qualities and their underlying metrics. In the following, we briefly describe each of the categories. The direction of the arrow indicates whether higher or lower values are considered better (exceptions exist within each category, so please read the docstring of each individual metric carefully before using it and/or interpreting its scores). For a more complete description of the different properties, please see the official GitHub repository.
# In each category of explanation quality, let's view the available metrics.
for k, v in quantus.AVAILABLE_METRICS.items():
    print(k)
    for i in v:
        print(f"\t• {i}")
Faithfulness
	• Faithfulness Correlation
	• Faithfulness Estimate
	• Pixel-Flipping
	• Region Segmentation
	• Monotonicity-Arya
	• Monotonicity-Nguyen
	• Selectivity
	• SensitivityN
	• IROF
	• ROAD
	• Infidelity
	• Sufficiency
Robustness
	• Continuity Test
	• Local Lipschitz Estimate
	• Max-Sensitivity
	• Avg-Sensitivity
	• Consistency
	• Relative Input Stability
	• Relative Output Stability
	• Relative Representation Stability
Localisation
	• Pointing Game
	• Top-K Intersection
	• Relevance Mass Accuracy
	• Relevance Rank Accuracy
	• Attribution Localisation
	• AUC
	• Focus
Complexity
	• Sparseness
	• Complexity
	• Effective Complexity
Randomisation
	• MPRT
	• Smooth MPRT
	• Efficient MPRT
	• Random Logit
Axiomatic
	• Completeness
	• NonSensitivity
	• InputInvariance
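We select a single metric from the complexity category of explanation quality. Sparseness (Chalasani et al., 2020) is quantified using the Gini index applied to the vector $\hat{\mathbf{e}}$ of absolute attribution values. This vector has length $D$ and is sorted in non-decreasing order ($\hat{\mathbf{e}}_{1} \leq \dots \leq \hat{\mathbf{e}}_{D}$). Scores range from 0, when relevance is spread evenly over all features, towards 1, when relevance is concentrated on very few features, so higher values indicate sparser explanations.
$$\Psi_\text{SP}(\hat{\mathbf{e}}) = \frac{\sum_{i=1}^{D}(2 i-D-1) \hat{\mathbf{e}}_{i}}{D \sum_{i=1}^{D} \hat{\mathbf{e}}_{i}}$$
# For comparison, first inspect the default parameters of a Faithfulness metric, Pixel-Flipping.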
quantus.PixelFlipping().get_params
{'abs': False, 'normalise': True, 'return_aggregate': False, 'aggregate_func': <function mean at 0x7870f2d29e30>, 'normalise_func': <function quantus.functions.normalise_func.normalise_by_max(a: numpy.ndarray, normalise_axes: Optional[Sequence[int]] = None) -> numpy.ndarray>, 'normalise_func_kwargs': {}, '_disable_warnings': False, '_display_progressbar': False, 'a_axes': None, 'features_in_step': 1, 'return_auc_per_sample': False, 'perturb_func': functools.partial(<function baseline_replacement_by_indices at 0x786f805155a0>, perturb_baseline='black')}
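The defaults printed above belong to Pixel-Flipping, a faithfulness metric. Features are removed in order of decreasing attribution, `features_in_step` at a time, and replaced by a baseline while the model output is tracked. For a faithful explanation, the output should drop quickly at the start of this curve. The sketch below shows that loop in NumPy. It assumes a hypothetical toy linear model and uses a baseline of 0 in place of Quantus's `'black'` baseline. It illustrates the idea only and is not the library's implementation.

```python
import numpy as np


def pixel_flipping_curve(model_fn, x, a, features_in_step=1, baseline=0.0):
    """Model outputs while features are replaced in order of decreasing attribution."""
    x_flat = x.ravel().copy()
    order = np.argsort(-a.ravel(), kind="stable")  # most relevant feature first
    outputs = [model_fn(x_flat)]
    for start in range(0, order.size, features_in_step):
        x_flat[order[start:start + features_in_step]] = baseline
        outputs.append(model_fn(x_flat))
    return np.array(outputs)


# Hypothetical toy linear model f(x) = w . x with gradient x input attributions.
w = np.array([3.0, -1.0, 2.0, 0.5])
x = np.ones(4)


def model_fn(v):
    return float(w @ v)


curve = pixel_flipping_curve(model_fn, x, a=w * x)
print(curve.tolist())
```

Here the output falls from 4.5 to 1.5 as soon as the most relevant feature (attribution 3.0) is removed.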
# Let's try initialising one Complexity metric, called Sparseness.
quantus.Sparseness().get_params
{'abs': True, 'normalise': True, 'return_aggregate': False, 'aggregate_func': <function mean at 0x7870f2d29e30>, 'normalise_func': <function quantus.functions.normalise_func.normalise_by_max(a: numpy.ndarray, normalise_axes: Optional[Sequence[int]] = None) -> numpy.ndarray>, 'normalise_func_kwargs': {}, '_disable_warnings': False, '_display_progressbar': False, 'a_axes': None}
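To make the Gini-index formula above concrete, the following NumPy sketch takes absolute values, sorts them in ascending order and applies $\Psi_\text{SP}$ directly. The function name is hypothetical. Quantus's `Sparseness` may differ in details such as numerical safeguards. Note that an all-zero attribution would divide by zero in this sketch.

```python
import numpy as np


def gini_sparseness(attribution: np.ndarray) -> float:
    """Gini index of |attribution|, following the Sparseness formula above."""
    e = np.sort(np.abs(attribution).ravel())  # ascending: e_1 <= ... <= e_D
    D = e.size
    i = np.arange(1, D + 1)
    return float(np.sum((2 * i - D - 1) * e) / (D * np.sum(e)))


uniform = np.ones((1, 4, 4))   # relevance spread evenly
one_hot = np.zeros((1, 4, 4))  # all relevance on a single pixel
one_hot[0, 0, 0] = 1.0
print(gini_sparseness(uniform))  # 0.0
print(gini_sparseness(one_hot))  # 0.9375, i.e. (D - 1) / D with D = 16
```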
As a starter, we evaluate the Saliency explanations (Mørch et al., 1995; Baehrens et al., 2010).
# Alternative 1. Evaluate the Saliency explanations in a one-liner, by calling the initialised metric.
quantus.Sparseness()(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=explanations["Saliency"])
[0.4388029590304715, 0.47652868406704785, 0.41440676591897163, 0.46697679949295834, 0.4414132058788879, 0.43754174893746944, 0.45429935647536707, 0.49777907790547427, 0.5283665277255211, 0.4428323014199678, 0.4146280311647153, 0.5192187669945634, 0.49564300585363763, 0.4010246933427686, 0.45522917726154716]
# Change some hyperparameters, get an aggregate score over several test samples.
quantus.Sparseness(return_aggregate=True, disable_warnings=True)(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=explanations["Saliency"])
[0.4589794067646247]
Next, we evaluate the GradientShap explanations (Lundberg and Lee, 2017).
# Change the explanation method to evaluate Sparseness on GradientShap.
quantus.Sparseness(return_aggregate=True, disable_warnings=True)(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=explanations["GradientShap"])
[0.595142746343291]
Then, we evaluate the FusionGrad explanations (Bykov et al., 2021).
# Change the explanation method to evaluate Sparseness on FusionGrad.
quantus.Sparseness(return_aggregate=True, disable_warnings=True)(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=explanations["FusionGrad"])
[0.48891676021172176]
# Score all methods iteratively.
for method, attr in explanations.items():
    metric = quantus.Sparseness(return_aggregate=False, disable_warnings=True)
    scores = metric(model=model, x_batch=x_batch, y_batch=y_batch, a_batch=attr)
    print(f" {method} - {np.mean(scores):.2f} ({np.std(scores):.2f})")
 LayerGradCam - 0.31 (0.06)
 Saliency - 0.46 (0.04)
 GradientShap - 0.60 (0.04)
 IntegratedGradients - 0.59 (0.04)
 FusionGrad - 0.49 (0.01)
 Random - 0.33 (0.00)
To structure the analysis a bit further, you can leverage the built-in functionality of quantus.evaluate().
In the following, we use Quantus to quantitatively assess the different explanation methods on various evaluation criteria.
# Initialise the Quantus evaluation metrics.
metrics = {
"Robustness": quantus.MaxSensitivity(
nr_samples=10,
lower_bound=0.2,
norm_numerator=quantus.norm_func.fro_norm,
norm_denominator=quantus.norm_func.fro_norm,
perturb_func=quantus.perturb_func.uniform_noise,
similarity_func=quantus.similarity_func.difference,
abs=False,
normalise=True,
normalise_func=quantus.normalise_by_max,
aggregate_func=np.mean,
return_aggregate=True,
disable_warnings=True,
),
"Faithfulness": quantus.FaithfulnessCorrelation(
nr_runs=10,
subset_size=224,
perturb_baseline="black",
perturb_func=quantus.baseline_replacement_by_indices,
similarity_func=quantus.similarity_func.correlation_pearson,
abs=True,
normalise=True,
normalise_func=quantus.normalise_by_max,
aggregate_func=np.mean,
return_aggregate=True,
disable_warnings=True,
),
"Localisation": quantus.RelevanceRankAccuracy(
abs=False,
normalise=True,
normalise_func=quantus.normalise_by_max,
aggregate_func=np.mean,
return_aggregate=True,
disable_warnings=True,
),
"Complexity": quantus.Sparseness(
abs=True,
normalise=True,
normalise_func=quantus.normalise_by_max,
aggregate_func=np.mean,
return_aggregate=True,
disable_warnings=True,
),
"Sensitivity": quantus.EfficientMPRT(
similarity_func=quantus.similarity_func.correlation_spearman,
abs=False,
normalise=True,
normalise_func=quantus.normalise_by_max,
aggregate_func=np.mean,
return_aggregate=True,
disable_warnings=True,
),
}
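As configured above, `FaithfulnessCorrelation` replaces a random subset of `subset_size` features with a baseline, `nr_runs` times. Roughly, it then correlates the drop in model output with the attribution mass of each removed subset. The sketch below computes a simplified single-sample version in NumPy. It assumes a hypothetical toy linear model $f(x) = w^\top x$, a baseline of 0 and gradient × input attributions $a = w \odot x$. For such a model the drop equals the removed attribution mass exactly, so the Pearson correlation is 1. The sketch omits the normalisation and absolute-value options configured above and is not Quantus's implementation.

```python
import numpy as np


def faithfulness_correlation_sketch(model_fn, x, a, nr_runs=10, subset_size=224,
                                    baseline=0.0, seed=0):
    """Pearson correlation between removed attribution mass and output drop."""
    rng = np.random.default_rng(seed)
    x_flat, a_flat = x.ravel(), a.ravel()
    y = model_fn(x_flat)
    drops, masses = [], []
    for _ in range(nr_runs):
        # Replace a random subset of features (drawn without repetition) by the baseline.
        subset = rng.choice(x_flat.size, size=subset_size, replace=False)
        x_perturbed = x_flat.copy()
        x_perturbed[subset] = baseline
        drops.append(y - model_fn(x_perturbed))
        masses.append(a_flat[subset].sum())
    return float(np.corrcoef(masses, drops)[0, 1])


# Hypothetical toy linear model f(x) = w . x with gradient x input attributions.
rng = np.random.default_rng(1)
w = rng.normal(size=(3, 16, 16))
x = rng.normal(size=(3, 16, 16))


def model_fn(v):
    return float(w.ravel() @ v)


score = faithfulness_correlation_sketch(model_fn, x, a=w * x, nr_runs=20, subset_size=32)
print(round(score, 6))  # 1.0
```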
quantus.evaluate
quantus.evaluation.evaluate
def evaluate(metrics: Dict, xai_methods: Union[Dict[str, Callable], Dict[str, Dict], Dict[str, np.ndarray]], model: ModelInterface, x_batch: np.ndarray, y_batch: np.ndarray, s_batch: Union[np.ndarray, None]=None, agg_func: Callable=lambda x: x, explain_func_kwargs: Optional[dict]=None, call_kwargs: Union[Dict, Dict[str, Dict]]=None, return_as_df: Optional[bool]=None, verbose: Optional[bool]=None, progress: Optional[bool]=None, *args, **kwargs) -> Optional[dict]
Evaluate different explanation methods using specified metrics. Parameters ---------- metrics : dict A dictionary of initialized evaluation metrics. See quantus.AVAILABLE_METRICS. Example: {'Robustness': quantus.MaxSensitivity(), 'Faithfulness': quantus.PixelFlipping()} xai_methods : dict A dictionary specifying the explanation methods to evaluate, which can be structured in three ways: 1) Dict[str, Dict] for built-in Quantus methods (using quantus.explain): Example: xai_methods = { 'IntegratedGradients': { 'n_steps': 10, 'xai_lib': 'captum' }, 'Saliency': { 'xai_lib': 'captum' } } - See quantus.AVAILABLE_XAI_METHODS_CAPTUM for supported captum methods. - See quantus.AVAILABLE_XAI_METHODS_TF for supported tensorflow methods. - See https://github.com/chr5tphr/zennit for supported zennit methods. - Read more about the explanation function arguments here: <https://quantus.readthedocs.io/en/latest/docs_api/quantus.functions.explanation_func.html#quantus.functions.explanation_func.explain> 2) Dict[str, Callable] for custom methods: Example: xai_methods = { 'custom_own_xai_method': custom_explain_function } or ai_methods = {"InputXGradient": { "explain_func": quantus.explain, "explain_func_kwargs": {}, }} - Here, you can provide your own callable that mirrors the input and outputs of the quantus.explain() method. 3) Dict[str, np.ndarray] for pre-calculated attributions: Example: xai_methods = { 'LIME': precomputed_numpy_lime_attributions, 'GradientShap': precomputed_numpy_shap_attributions } - Note that some Quantus metrics, e.g., quantus.MaxSensitivity() within the robustness category, includes "re-explaning" the input and output pair as a part of the evaluation metric logic. If you include such metrics in the quantus.evaluate(), this option will not be possible. It is also possible to pass a combination of the above. 
>>> xai_methods = { >>> 'IntegratedGradients': { >>> 'n_steps': 10, >>> 'xai_lib': 'captum' >>> }, >>> 'Saliency': { >>> 'xai_lib': 'captum' >>> }, >>> 'custom_own_xai_method': custom_explain_function, >>> 'LIME': precomputed_numpy_lime_attributions, >>> 'GradientShap': precomputed_numpy_shap_attributions >>> } model: Union[torch.nn.Module, tf.keras.Model] A torch or tensorflow model that is subject to explanation. x_batch: np.ndarray A np.ndarray containing the input data to be explained. y_batch: np.ndarray A np.ndarray containing the output labels corresponding to x_batch. s_batch: np.ndarray, optional A np.ndarray containing segmentation masks that match the input. agg_func: Callable Indicates how to aggregate scores, e.g., pass np.mean. explain_func_kwargs: dict, optional Keyword arguments to be passed to explain_func on call. Pass None if using Dict[str, Dict] type for xai_methods. call_kwargs: Dict[str, Dict] Keyword arguments for the call of the metrics. Keys are names for argument sets, and values are argument dictionaries. verbose: optional, bool Indicates whether to print evaluation progress. progress: optional, bool Deprecated. Indicates whether to print evaluation progress. Use verbose instead. return_as_df: optional, bool Indicates whether to return the results as a pd.DataFrame. Only works if call_kwargs is not passed. args: optional Deprecated arguments for the call. kwargs: optional Deprecated keyword arguments for the call of the metrics. Returns ------- results: dict A dictionary with the evaluation results.
# Re-define XAI methods to score.
xai_methods_with_kwargs ={
#"Occlusion": {"window": (1, 28, 28)},
"LayerGradCam": {"gc_layer": "list(model.named_modules())[61][1]", "interpolate": (224, 224),},
"Saliency": {},
"GradientShap": {},
"IntegratedGradients": {"n_steps": 5},
"FusionGrad": fusiongrad_explainer,
"Random": your_own_random_explainer,
}
# Run full quantification analysis!
load_obj = True  # Set to False to recompute the scores instead of loading the stored results.
if not load_obj:
    # Evaluate XAI methods.
    results = quantus.evaluate(metrics=metrics,
                               xai_methods=xai_methods_with_kwargs,
                               model=model.cpu(),
                               x_batch=x_batch,
                               y_batch=y_batch,
                               s_batch=s_batch,
                               agg_func=np.mean,
                               explain_func_kwargs=None,
                               call_kwargs=None,
                               return_as_df=False,
                               verbose=True)
else:
    # Retrieve stored results.
    results = {
"LayerGradCam": {
"Robustness": 6.385364405314127,
"Faithfulness": 0.012830701450853242,
"Localisation": 0.5945537119866623,
"Complexity": 0.534340942620482,
"Sensitivity": -0.011923145576686991,
},
"Saliency": {
"Robustness": 0.8709124883015951,
"Faithfulness": 0.11204023012353007,
"Localisation": 0.5945537119866623,
"Complexity": 0.534340942620482,
"Sensitivity": 0.10401733999005232,
},
"GradientShap": {
"Robustness": 1.8177134116490683,
"Faithfulness": 0.12337100760315674,
"Localisation": 0.5945537119866623,
"Complexity": 0.534340942620482,
"Sensitivity": 0.03112477971924486,
},
"IntegratedGradients": {
"Robustness": 1.1781942049662273,
"Faithfulness": 0.1820758061666603,
"Localisation": 0.5945537119866623,
"Complexity": 0.534340942620482,
"Sensitivity": 0.017489182288717948,
},
"FusionGrad": {
"Robustness": 1.395110293229421,
"Faithfulness": 0.1478673954202166,
"Localisation": 0.5614189583462068,
"Complexity": 0.4727508018004431,
"Sensitivity": -0.005249520387875365,
},
"Random": {
"Robustness": 0.7098723664216617,
"Faithfulness": 0.13276692371798132,
"Localisation": 0.5101552561594804,
"Complexity": 0.3331452921782692,
"Sensitivity": -4.37913867605053e-06,
},
}
results
{'LayerGradCam': {'Robustness': 6.385364405314127, 'Faithfulness': 0.012830701450853242, 'Localisation': 0.5945537119866623, 'Complexity': 0.534340942620482, 'Sensitivity': -0.011923145576686991}, 'Saliency': {'Robustness': 0.8709124883015951, 'Faithfulness': 0.11204023012353007, 'Localisation': 0.5945537119866623, 'Complexity': 0.534340942620482, 'Sensitivity': 0.10401733999005232}, 'GradientShap': {'Robustness': 1.8177134116490683, 'Faithfulness': 0.12337100760315674, 'Localisation': 0.5945537119866623, 'Complexity': 0.534340942620482, 'Sensitivity': 0.03112477971924486}, 'IntegratedGradients': {'Robustness': 1.1781942049662273, 'Faithfulness': 0.1820758061666603, 'Localisation': 0.5945537119866623, 'Complexity': 0.534340942620482, 'Sensitivity': 0.017489182288717948}, 'FusionGrad': {'Robustness': 1.395110293229421, 'Faithfulness': 0.1478673954202166, 'Localisation': 0.5614189583462068, 'Complexity': 0.4727508018004431, 'Sensitivity': -0.005249520387875365}, 'Random': {'Robustness': 0.7098723664216617, 'Faithfulness': 0.13276692371798132, 'Localisation': 0.5101552561594804, 'Complexity': 0.3331452921782692, 'Sensitivity': -4.37913867605053e-06}}
Postprocessing of scores, computing how the different explanation methods rank across criteria.
# Postprocessing of scores: to get how the different explanation methods rank across criteria.
results_agg = {}
for method in xai_methods_with_kwargs:
    results_agg[method] = {}
    for metric in metrics:
        results_agg[method][metric] = np.mean(results[method][metric])
df = pd.DataFrame.from_dict(results_agg)
df = df.T.abs()
df
| | Robustness | Faithfulness | Localisation | Complexity | Sensitivity |
|---|---|---|---|---|---|
| LayerGradCam | 6.385364 | 0.012831 | 0.594554 | 0.534341 | 0.011923 |
| Saliency | 0.870912 | 0.112040 | 0.594554 | 0.534341 | 0.104017 |
| GradientShap | 1.817713 | 0.123371 | 0.594554 | 0.534341 | 0.031125 |
| IntegratedGradients | 1.178194 | 0.182076 | 0.594554 | 0.534341 | 0.017489 |
| FusionGrad | 1.395110 | 0.147867 | 0.561419 | 0.472751 | 0.005250 |
| Random | 0.709872 | 0.132767 | 0.510155 | 0.333145 | 0.000004 |
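To compare the different XAI methods, we normalise the scores of each metric to $[0, 1]$ by dividing by the column maximum (for Robustness, where lower is better, we instead divide the column minimum by each score). We then rank the normalised scores from lowest to highest, so that the highest rank corresponds to the best performance.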
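The normalise-then-rank step can be sketched in plain NumPy. This is a minimal illustration rather than the notebook's pandas code: the helper name normalise_and_rank is hypothetical, it assumes strictly positive scores (the table above already holds absolute values), and it resolves ties by averaging the ranks they span, which matches the default behaviour of pandas' DataFrame.rank. The example uses three rows of the table above.

```python
import numpy as np


def normalise_and_rank(scores, lower_is_better):
    """Normalise each metric column to (0, 1] and rank the methods per column.

    scores: array of shape (n_methods, n_metrics) with strictly positive values.
    lower_is_better: one bool per metric column.
    Returns (normalised, ranks); the highest rank marks the best method and
    tied values share the average of the ranks they span.
    """
    scores = np.asarray(scores, dtype=float)
    normalised = np.empty_like(scores)
    for j, invert in enumerate(lower_is_better):
        col = scores[:, j]
        # Lower-is-better metrics: the smallest score maps to 1.0.
        normalised[:, j] = col.min() / col if invert else col / col.max()
    ranks = np.empty_like(normalised)
    for j in range(normalised.shape[1]):
        col = normalised[:, j]
        for i, value in enumerate(col):
            smaller = np.sum(col < value)  # methods ranked strictly below
            ties = np.sum(col == value)    # methods sharing this value (incl. itself)
            ranks[i, j] = smaller + (ties + 1) / 2
    return normalised, ranks


# Robustness (lower is better), Faithfulness and Localisation for
# LayerGradCam, Saliency and Random, copied from the table above.
scores = [[6.385364, 0.012831, 0.594554],
          [0.870912, 0.112040, 0.594554],
          [0.709872, 0.132767, 0.510155]]
normalised, ranks = normalise_and_rank(scores, lower_is_better=[True, False, False])
print(ranks)
```

Random receives the highest Robustness rank here simply because it has the lowest Robustness score, and LayerGradCam and Saliency share rank 2.5 on Localisation because their scores are identical.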
# Take inverse ranking for Robustness, since lower is better.
df_normalised = df.loc[:, ~df.columns.isin(['Robustness'])].apply(lambda x: x / x.max())
df_normalised["Robustness"] = df["Robustness"].min()/df["Robustness"].values
df_normalised_rank = df_normalised.rank()
df_normalised_rank
| Method | Faithfulness | Localisation | Complexity | Sensitivity | Robustness |
|---|---|---|---|---|---|
| LayerGradCam | 1.0 | 4.5 | 4.5 | 3.0 | 1.0 |
| Saliency | 2.0 | 4.5 | 4.5 | 6.0 | 5.0 |
| GradientShap | 3.0 | 4.5 | 4.5 | 5.0 | 2.0 |
| IntegratedGradients | 6.0 | 4.5 | 4.5 | 4.0 | 4.0 |
| FusionGrad | 5.0 | 2.0 | 2.0 | 2.0 | 3.0 |
| Random | 4.0 | 1.0 | 1.0 | 1.0 | 6.0 |
#@title 4.2.1 Create spider (radar) plot
%%capture
# Plotting specifics.
from matplotlib.patches import Circle, RegularPolygon
from matplotlib.path import Path
from matplotlib.projections.polar import PolarAxes
from matplotlib.projections import register_projection
from matplotlib.spines import Spine
from matplotlib.transforms import Affine2D
# Plotting configs.
sns.set(font_scale=1.5)
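# The seaborn styles were renamed to 'seaborn-v0_8-*' in Matplotlib 3.6.
plt.style.use('seaborn-v0_8-white' if 'seaborn-v0_8-white' in plt.style.available else 'seaborn-white')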
plt.rcParams['ytick.labelleft'] = True
plt.rcParams['xtick.labelbottom'] = True
include_titles = True
include_legend = True
# Source code: https://matplotlib.org/stable/gallery/specialty_plots/radar_chart.html.
def spyder_plot(num_vars, frame='circle'):
    """Create a radar chart with `num_vars` axes.
    This function creates a RadarAxes projection and registers it.
    Parameters
    ----------
    num_vars : int
        Number of variables for radar chart.
    frame : {'circle' | 'polygon'}
        Shape of frame surrounding axes.
    """
    # Calculate evenly-spaced axis angles.
    theta = np.linspace(0, 2*np.pi, num_vars, endpoint=False)
    class RadarAxes(PolarAxes):
        name = 'radar'
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Rotate plot such that the first axis is at the top.
            self.set_theta_zero_location('N')
        def fill(self, *args, closed=True, **kwargs):
            """Override fill so that line is closed by default."""
            return super().fill(closed=closed, *args, **kwargs)
        def plot(self, *args, **kwargs):
            """Override plot so that line is closed by default."""
            lines = super().plot(*args, **kwargs)
            for line in lines:
                self._close_line(line)
            return lines
        def _close_line(self, line):
            x, y = line.get_data()
            # FIXME: markers at x[0], y[0] get doubled-up
            if x[0] != x[-1]:
                x = np.concatenate((x, [x[0]]))
                y = np.concatenate((y, [y[0]]))
                line.set_data(x, y)
        def set_varlabels(self, labels, angles=None):
            self.set_thetagrids(angles=np.degrees(theta), labels=labels)
        def _gen_axes_patch(self):
            # The Axes patch must be centered at (0.5, 0.5) and of radius 0.5
            # in axes coordinates.
            if frame == 'circle':
                return Circle((0.5, 0.5), 0.5)
            elif frame == 'polygon':
                return RegularPolygon((0.5, 0.5), num_vars,
                                      radius=.5, edgecolor="k")
            else:
                raise ValueError("unknown value for 'frame': %s" % frame)
        def draw(self, renderer):
            """Draw. If frame is polygon, make gridlines polygon-shaped."""
            if frame == 'polygon':
                gridlines = self.yaxis.get_gridlines()
                for gl in gridlines:
                    gl.get_path()._interpolation_steps = num_vars
            super().draw(renderer)
        def _gen_axes_spines(self):
            if frame == 'circle':
                return super()._gen_axes_spines()
            elif frame == 'polygon':
                # spine_type must be 'left'/'right'/'top'/'bottom'/'circle'.
                spine = Spine(axes=self,
                              spine_type='circle',
                              path=Path.unit_regular_polygon(num_vars))
                # unit_regular_polygon gives a polygon of radius 1 centered at
                # (0, 0) but we want a polygon of radius 0.5 centered at (0.5,
                # 0.5) in axes coordinates.
                spine.set_transform(Affine2D().scale(.5).translate(.5, .5)
                                    + self.transAxes)
                return {'polar': spine}
            else:
                raise ValueError("unknown value for 'frame': %s" % frame)
    register_projection(RadarAxes)
    return theta
XAI evaluation can help researchers identify appropriate explanation methods for a specific task. These performance measures can also help validate models and their predictions, as well as the insights inferred from explanations.
# Make spider graph!
data = [df_normalised_rank.columns.values, df_normalised_rank.to_numpy()]
theta = spyder_plot(len(data[0]), frame='polygon')
spoke_labels = data.pop(0)
fig, ax = plt.subplots(figsize=(9, 9), subplot_kw=dict(projection='radar'))
fig.subplots_adjust(top=0.85, bottom=0.05)
for i, (d, method) in enumerate(zip(data[0], xai_methods_with_kwargs)):
    ax.plot(theta, d, label=method, color=colours[i], linewidth=5.0)
    ax.fill(theta, d, color=colours[i], alpha=0.15)
# Set labels, taken from the columns of df_normalised_rank so that they match the data order.
if include_titles:
    ax.set_varlabels(labels=spoke_labels)
else:
    ax.set_varlabels(labels=[])
ax.set_rgrids(np.arange(0, df_normalised_rank.values.max() + 0.5), labels=[])
# Set a title.
ax.set_title("Quantus: Summary of Quantification", position=(0.5, 1.1), ha='center', fontsize=15)
# Put a legend to the right of the current axis.
if include_legend:
    ax.legend(loc='upper left', bbox_to_anchor=(1, 0.5))
plt.tight_layout()
There is no clear winner: no single method ranks highest across all criteria, and many explanations score poorly in absolute terms.
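To make "no clear winner" concrete, the ranks in df_normalised_rank can be tallied per criterion. The sketch below copies those ranks by hand into a NumPy array, so that it runs on its own, and lists the top-ranked method(s) per criterion alongside each method's mean rank. The mean rank is only an illustrative aggregate: it weighs all five criteria equally, which is itself a modelling choice.

```python
import numpy as np

criteria = ["Faithfulness", "Localisation", "Complexity", "Sensitivity", "Robustness"]
methods = ["LayerGradCam", "Saliency", "GradientShap", "IntegratedGradients", "FusionGrad", "Random"]
# Ranks copied from the df_normalised_rank output above (higher = better).
ranks = np.array([
    [1.0, 4.5, 4.5, 3.0, 1.0],
    [2.0, 4.5, 4.5, 6.0, 5.0],
    [3.0, 4.5, 4.5, 5.0, 2.0],
    [6.0, 4.5, 4.5, 4.0, 4.0],
    [5.0, 2.0, 2.0, 2.0, 3.0],
    [4.0, 1.0, 1.0, 1.0, 6.0],
])

# Every method holding the highest rank for a criterion (ties included).
winners = {
    criterion: [m for m, r in zip(methods, ranks[:, j]) if r == ranks[:, j].max()]
    for j, criterion in enumerate(criteria)
}
# Mean rank per method across the five criteria.
mean_rank = {m: float(r) for m, r in zip(methods, ranks.mean(axis=1))}

for criterion, best in winners.items():
    print(f"{criterion}: {best}")
print({m: round(r, 2) for m, r in mean_rank.items()})
```

Faithfulness is topped by IntegratedGradients, Sensitivity by Saliency and Robustness by the Random baseline, while Localisation and Complexity end in four-way ties. Random's mean rank (2.6) is close to that of LayerGradCam and FusionGrad (2.8 each).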
Limitations. XAI evaluation is limited by the absence of a reliable ground truth: the metrics can only assess properties that a valid explanation must possess, and cannot provide a complete validation. Moreover, although XAI evaluation is a rapidly evolving field, the metrics offered by the Quantus library have known limitations; for example, many rely on perturbing the input, which may create out-of-distribution inputs. Evaluating explanation methods through quantitative analysis also does not guarantee their theoretical soundness or statistical validity. Therefore, when using Quantus to select an XAI method, it is essential to supplement the results with theoretical considerations.
MetaQuantus is an open-source development tool for XAI researchers and Machine Learning (ML) practitioners to verify and benchmark newly constructed metrics (i.e., "quality estimators"). It includes pre-built tests, such as ModelPerturbationTest and InputPerturbationTest, that can be applied to various metrics. More details can be found in the official GitHub repository.
import metaquantus
We answer the following research question:
# Re-load data back.
assets = np.load("assets/imagenet_test_set.npy", allow_pickle=True).item()
x_batch = assets["x_batch"][:nr_samples]
y_batch = assets["y_batch"][:nr_samples]
s_batch = assets["s_batch"].reshape(-1, 1, 224, 224)[:nr_samples]
# Load pre-trained model of choice.
model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
# Define a set of estimators using Quantus, a broader set of Localisation metrics.
estimators_localisation = {
"Localisation": {
"Pointing-Game": {
"init":
quantus.PointingGame(
abs=False,
normalise=True,
normalise_func=quantus.normalise_func.normalise_by_max,
normalise_func_kwargs={},
return_aggregate=False,
aggregate_func=np.mean,
disable_warnings=True,
), "score_direction": "higher"},
"Top-K Intersection": {
"init":
quantus.TopKIntersection(
k=10,
abs=False,
normalise=True,
normalise_func=quantus.normalise_func.normalise_by_max,
normalise_func_kwargs={},
return_aggregate=False,
aggregate_func=np.mean,
disable_warnings=True,
), "score_direction": "higher"},
"Relevance Rank Accuracy": {
"init":
quantus.RelevanceRankAccuracy(
abs=False,
normalise=True,
normalise_func=quantus.normalise_func.normalise_by_max,
normalise_func_kwargs={},
return_aggregate=False,
aggregate_func=np.mean,
disable_warnings=True,
), "score_direction": "higher"},
"Relevance Mass Accuracy": {
"init":
quantus.RelevanceMassAccuracy(
abs=False,
normalise=True,
normalise_func=quantus.normalise_func.normalise_by_max,
normalise_func_kwargs={},
return_aggregate=False,
aggregate_func=np.mean,
disable_warnings=True,
),
"score_direction": "higher",
},
}
}
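We define the Model Perturbation Test and the Input Perturbation Test to evaluate the estimators, each in two variants: a Resilience test, which applies minor noise that should leave the estimator's scores largely unchanged, and an Adversary test, which applies disruptive noise that the scores should react to.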
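The two kinds of perturbation can be illustrated on a toy linear model. This is a minimal sketch under stated assumptions, not MetaQuantus's implementation: the helper names are hypothetical, model noise follows the "multiplicative" setting in the test suite below (each weight is multiplied by a draw from a normal distribution with the given mean and std), and input noise is assumed here to be additive zero-mean Gaussian noise whose standard deviation equals the test's noise value. MetaQuantus then checks how the estimator's scores respond to such perturbations; the sketch only shows how strongly each setting moves the model output.

```python
import numpy as np

rng = np.random.default_rng(0)


def perturb_weights_multiplicative(weights, mean, std, rng):
    """Multiply every weight by noise drawn from N(mean, std)."""
    return weights * rng.normal(loc=mean, scale=std, size=weights.shape)


def perturb_input_additive(x, std, rng):
    """Add zero-mean Gaussian noise with standard deviation `std` (assumed input-noise model)."""
    return x + rng.normal(loc=0.0, scale=std, size=x.shape)


# Toy linear "model" with 10 input features and 3 classes, standing in for a network.
w = rng.normal(size=(10, 3))
x = rng.normal(size=10)
logits = x @ w

# Resilience (minor) vs adversary (disruptive) settings, using the test-suite values.
model_res = x @ perturb_weights_multiplicative(w, 1.0, 0.001, rng)
model_adv = x @ perturb_weights_multiplicative(w, 1.0, 0.5, rng)
input_res = perturb_input_additive(x, 0.001, rng) @ w
input_adv = perturb_input_additive(x, 3.0, rng) @ w

for name, out in [("model resilience", model_res), ("model adversary", model_adv),
                  ("input resilience", input_res), ("input adversary", input_adv)]:
    print(f"{name}: max |logit change| = {np.abs(out - logits).max():.4f}")
```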
# Define test suite.
test_suite = {
"Model Resilience Test": metaquantus.ModelPerturbationTest(
**{
"noise_type": "multiplicative",
"mean": 1.0,
"std": 0.001,
"type": "Resilience",
}
),
"Model Adversary Test": metaquantus.ModelPerturbationTest(
**{
"noise_type": "multiplicative",
"mean": 1.0,
"std": 0.5,
"type": "Adversary",
}
),
"Input Resilience Test": metaquantus.InputPerturbationTest(
**{
"noise": 0.001,
"type": "Resilience",
}
),
"Input Adversary Test": metaquantus.InputPerturbationTest(
**{
"noise": 3.0,
"type": "Adversary",
}
),
}
import glob
load_obj = True
if not load_obj:
    # Set configs.
    iters = 5
    K = 10
    # Define the meta-evaluation exercise.
    meta_evaluator = metaquantus.MetaEvaluation(
        test_suite=test_suite,
        xai_methods={"Saliency": {}, "GradientShap": {}},
        iterations=iters,
        nr_perturbations=K,
        write_to_file=False,
    )
    # Collect the settings for the dataset.
    dataset_settings = {}
    dataset_settings["ImageNet"] = {
        "x_batch": x_batch,
        "y_batch": y_batch,
        "s_batch": s_batch,
        "models": {
            "ResNet18": model,
        },
        "gc_layers": {
            "ResNet18": "list(model.named_modules())[61][1]",
        },
        "estimator_kwargs": {
            "num_classes": 1000,
            "img_size": 224,
            "features": 224 * 4,
            "percentage": 0.1,
            "nr_channels": 3,
            "patch_size": 224 * 2,
            "perturb_baseline": "uniform",
        },
    }
    # Benchmark localisation metrics, using the initialised meta-evaluator.
    benchmark = metaquantus.MetaEvaluationBenchmarking(
        master=meta_evaluator,
        estimators=estimators_localisation,
        experimental_settings=dataset_settings,
        write_to_file=False,
        keep_results=True,
        channel_first=True,
        softmax=False,
        device=device,
    )()
else:
    # Download precomputed benchmarking results instead of recomputing them.
    !gdown https://drive.google.com/drive/folders/1uOjuMtbNkvLPqXrbozVjsDNeCxZ6peqs --folder --quiet
    !ls
    # Load the benchmarking data for each model.
    dataset_name = "ImageNet"
    f_loc = [f.split("transformer_data/")[1] for f in glob.glob("transformer_data/*") if "localisation" in f]
    batches = len(f_loc)
    benchmarks_loc = {}
    # The loop variable is named model_name so that the ResNet18 stored in model is not overwritten.
    for model_name in ["ResNet18"]:
        benchmarks_loc[model_name] = {}
        for batch, f in enumerate(f_loc):
            benchmarks_loc[model_name][batch] = metaquantus.load_obj("transformer_data/", fname=f"{f}", use_json=True)[dataset_name][model_name]
assets sample_data transformer_data
#@title 5.3.1 Plotting functionality
%%capture
from typing import Dict
def plot_multiple_models_estimator_area(
    benchmarks: Dict,
    estimators: Dict,
    dataset_name: str,
    colours: Dict,
    save: bool,
    path: str,
    average_over: list = ["Model", "Input"],
    **kwargs
) -> None:
    """
    Plot the outcome of the benchmarking exercise for different models.
    Parameters
    ----------
    benchmarks: dict of dicts
        The benchmarking data, keys are the model names.
    estimators: dict
        The estimators used in the experiment.
    dataset_name: str
        The name of the dataset.
    colours: dict
        Dictionary of colours, based on the metrics.
    save: boolean
        Indicates if plots should be saved.
    path: str
        The path for saving the plot.
    average_over: list
        A list of spaces to average over.
    kwargs: dict
        A dict with plotting kwargs.
    Returns
    -------
    None
    """
    n_rows = kwargs.get("n_rows", 2)
    n_cols = kwargs.get("n_cols", 5)
    batches = kwargs.get("batches", 5)
    figsize = kwargs.get("figsize", (20, 8))
    fig, axs = plt.subplots(n_rows, n_cols, sharex=True, figsize=figsize)
    models = list(benchmarks.keys())
    metrics = list(estimators.values())[0]
    estimator_category = list(estimators.keys())[0]
    # Map each criterion to its key in the stored consistency results.
    score_keys = {
        "IAC_NR": "intra_scores_res",
        "IAC_AR": "intra_scores_adv",
        "IEC_NR": "inter_scores_res",
        "IEC_AR": "inter_scores_adv",
    }
    for mx1, model_name in enumerate(models):
        for ex1, estimator_name in enumerate(metrics):
            mc_scores = []
            for px, perturbation_type in enumerate(["Input", "Model"]):
                scores = {criterion: [] for criterion in score_keys}
                for batch in range(batches):
                    # Collect scores.
                    consistency = benchmarks[model_name][batch][estimator_category][estimator_name][
                        "results_consistency_scores"
                    ][perturbation_type]
                    for criterion, key in score_keys.items():
                        scores[criterion].append(np.array(consistency[key]))
                scores = {k: np.array(v).flatten() for k, v in scores.items()}
                # Set values for m* and the actual values by the estimator.
                X_gt = [-1, 0, 1, 0]
                Y_gt = [0, 1, 0, -1]
                X_area = [-scores["IAC_AR"].mean(), 0, scores["IEC_AR"].mean(), 0]
                Y_area = [0, scores["IAC_NR"].mean(), 0, -scores["IEC_NR"].mean()]
                # Set the spaces to average the MC value over.
                if perturbation_type in average_over:
                    mc_score = np.mean(
                        [
                            scores["IAC_NR"].mean(),
                            scores["IEC_NR"].mean(),
                            scores["IAC_AR"].mean(),
                            scores["IEC_AR"].mean(),
                        ]
                    )
                    mc_scores.append(mc_score)
                if perturbation_type == "Input":
                    axs[ex1].fill(
                        X_area,
                        Y_area,
                        color=colours[estimator_name],
                        alpha=0.75,
                        label=perturbation_type,
                        edgecolor="black",
                    )
                else:
                    axs[ex1].fill(
                        X_area,
                        Y_area,
                        color=colours[estimator_name],
                        alpha=0.5,
                        label=perturbation_type,
                        hatch="/",
                        edgecolor="black",
                    )
                # Plot m* once both perturbation types have been drawn.
                if px == 1:
                    axs[ex1].fill(
                        X_gt, Y_gt, color="black", alpha=0.075, label="m*"
                    )
                    # Annotate the labels.
                    axs[ex1].annotate("${IAC}_{AR}$", (-1, 0), fontsize=12)
                    axs[ex1].annotate("${IAC}_{NR}$", (-0.2, 0.8), fontsize=12)
                    axs[ex1].annotate("${IEC}_{AR}$", (0.7, 0), fontsize=12)
                    axs[ex1].annotate("${IEC}_{NR}$", (-0.2, -0.8), fontsize=12)
                    # Labels.
                    axs[ex1].set_xticklabels(
                        ["", "1", "0.5", "0", "0.5", "1"], fontsize=14
                    )
                    axs[ex1].set_yticklabels(
                        ["", "1", "", "0.5", "", "0", "", "0.5", "", "1", ""], fontsize=14
                    )
                    if ex1 == 0:
                        axs[ex1].set_ylabel(model_name, fontsize=14)
                    if estimator_name == "Model Parameter Randomisation Test":
                        estimator_name = "Model Parameter Random."
                    # Title and grids.
                    axs[ex1].set_title(
                        f"{estimator_name} ({np.array(mc_scores).flatten().mean():.4f})",
                        fontsize=15,
                    )
                    axs[ex1].grid()
                    axs[ex1].legend(loc="upper left")
    plt.grid()
    plt.tight_layout()
    if save:
        plt.savefig(path + "plots/" + f"full_area_graph_{estimator_category}_{dataset_name}_multiple_models.png", dpi=500)
    plt.show()
# Plotting settings.
estimators = {"Localisation": ["Pointing-Game", "Top-K Intersection", "Relevance Mass Accuracy", "Relevance Rank Accuracy"],}
colours = {
'Pointing-Game': "#b66a50",
'Top-K Intersection': "#9dbcd4",
'Relevance Mass Accuracy': "#7f7053",
'Relevance Rank Accuracy': "#8fb67b",
}
kwargs = {"n_rows": 1, "n_cols": len(list(estimators.values())[0]), "batches": batches}
kwargs["figsize"] = (kwargs["n_cols"]*4, kwargs["n_rows"]*4)
# Plot!
plot_multiple_models_estimator_area(benchmarks=benchmarks_loc, estimators=estimators, dataset_name=dataset_name, colours=colours, save=False, path="", **kwargs);
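Each panel title above reports the meta-consistency (MC) score, which plot_multiple_models_estimator_area computes as the mean of the four criteria IAC_NR, IAC_AR, IEC_NR and IEC_AR (intra- and inter-consistency under noise resilience and adversary reactivity, following the MetaQuantus paper), averaged over the perturbation spaces in average_over. The shaded quadrilateral has vertices (-IAC_AR, 0), (0, IAC_NR), (IEC_AR, 0) and (0, -IEC_NR), whereas the ideal m* diamond has all four values equal to 1. The sketch below recomputes the MC score for a single perturbation space and, as an extra illustrative summary that the helper does not compute, the ratio of the quadrilateral's area to that of m*. The function names and input values are hypothetical.

```python
import numpy as np


def meta_consistency(iac_nr, iac_ar, iec_nr, iec_ar):
    """MC score for one perturbation space: the mean of the four criteria."""
    return float(np.mean([iac_nr, iec_nr, iac_ar, iec_ar]))


def area_vs_ideal(iac_nr, iac_ar, iec_nr, iec_ar):
    """Shoelace area of the plotted quadrilateral divided by the area of the m* diamond."""
    xs = np.array([-iac_ar, 0.0, iec_ar, 0.0])
    ys = np.array([0.0, iac_nr, 0.0, -iec_nr])
    area = 0.5 * abs(np.dot(xs, np.roll(ys, -1)) - np.dot(ys, np.roll(xs, -1)))
    return float(area / 2.0)  # m* has vertices at (+-1, 0) and (0, +-1), i.e. area 2


# Hypothetical criterion values for one estimator.
print(meta_consistency(iac_nr=0.9, iac_ar=0.4, iec_nr=0.8, iec_ar=0.6))
print(area_vs_ideal(iac_nr=0.9, iac_ar=0.4, iec_nr=0.8, iec_ar=0.6))
```

An estimator that meets all four criteria perfectly fills the whole m* diamond, giving an MC score and an area ratio of 1.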
We will investigate how much different parameters influence the evaluation outcome, i.e., how the different explanation methods rank.
We use Faithfulness Correlation by Bhatt et al., 2020 for this example.
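Faithfulness Correlation repeatedly replaces a random subset of features with a baseline value, records the resulting drop in the model output, and correlates these drops with the summed attributions of the perturbed subsets. Below is a minimal NumPy sketch of that idea, not Quantus's implementation: the function name is hypothetical, inputs are flat feature vectors, the baseline is a single scalar (Quantus instead resolves strategies such as "mean", "uniform", "black" or "white"), and the Pearson correlation serves as similarity function, as in the metric's default parameters listed further down.

```python
import numpy as np


def faithfulness_correlation(predict, x, attribution, subset_size, nr_runs, baseline=0.0, seed=0):
    """Pearson correlation between output drops and summed attributions of random subsets."""
    rng = np.random.default_rng(seed)
    original = predict(x)
    drops, attr_sums = [], []
    for _ in range(nr_runs):
        # Pick a random subset of features and replace them with the baseline.
        idx = rng.choice(x.size, size=subset_size, replace=False)
        x_perturbed = x.copy()
        x_perturbed[idx] = baseline
        drops.append(original - predict(x_perturbed))
        attr_sums.append(attribution[idx].sum())
    return float(np.corrcoef(attr_sums, drops)[0, 1])


rng = np.random.default_rng(42)
w = rng.normal(size=100)  # weights of a toy linear model
x = rng.normal(size=100)  # one input with 100 features


def predict(v):
    """Toy linear model output."""
    return float(v @ w)


gradient_x_input = w * x  # exact attribution for a linear model
random_attribution = rng.normal(size=100)

print(faithfulness_correlation(predict, x, gradient_x_input, subset_size=10, nr_runs=50))
print(faithfulness_correlation(predict, x, random_attribution, subset_size=10, nr_runs=50))
```

For a linear model with gradient-times-input attributions and a zero baseline, each output drop equals the summed attribution of the removed features, so the score is 1 up to floating-point error, while random attributions score far lower. Parameters such as subset_size, nr_runs and the baseline all enter this loop, which is why they can affect the outcome.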
First, we load a new dataset (DermaMNIST) and train a model on it.
import quantus
We answer the following research question:
# Load vision medical dataset.
import medmnist
train_set, test_set = medmnist.DermaMNIST(split="train", download=True), medmnist.DermaMNIST(split="test", download=True)
print(f"\n{train_set}")
Downloading https://zenodo.org/records/10519652/files/dermamnist.npz?download=1 to /root/.medmnist/dermamnist.npz
100%|██████████| 19725078/19725078 [00:45<00:00, 436630.96it/s]
Using downloaded and verified file: /root/.medmnist/dermamnist.npz Dataset DermaMNIST of size 28 (dermamnist) Number of datapoints: 7007 Root location: /root/.medmnist Split: train Task: multi-class Number of channels: 3 Meaning of labels: {'0': 'actinic keratoses and intraepithelial carcinoma', '1': 'basal cell carcinoma', '2': 'benign keratosis-like lesions', '3': 'dermatofibroma', '4': 'melanoma', '5': 'melanocytic nevi', '6': 'vascular lesions'} Number of samples: {'train': 7007, 'val': 1003, 'test': 2005} Description: The DermaMNIST is based on the HAM10000, a large collection of multi-source dermatoscopic images of common pigmented skin lesions. The dataset consists of 10,015 dermatoscopic images categorized as 7 different diseases, formulized as a multi-class classification task. We split the images into training, validation and test set with a ratio of 7:1:2. The source images of 3×600×450 are resized into 3×28×28. License: CC BY-NC 4.0
# Visualise some data.
train_set.montage(length=5)
# Inspect shapes.
x, y = train_set[0]
np.shape(x), np.shape(y)
((28, 28, 3), (1,))
# @title 6.2.1 Supporting code
%%capture
import tqdm
import torchvision.transforms as transforms
import torch.utils.data as data
# Define a simple black-box CNN model.
class BlackBoxModel(torch.nn.Module):
    def __init__(self, in_channels, num_classes):
        super(BlackBoxModel, self).__init__()
        self.layer1 = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, 16, kernel_size=3),
            torch.nn.BatchNorm2d(16),
            torch.nn.ReLU())
        self.layer2 = torch.nn.Sequential(
            torch.nn.Conv2d(16, 16, kernel_size=3),
            torch.nn.BatchNorm2d(16),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer3 = torch.nn.Sequential(
            torch.nn.Conv2d(16, 64, kernel_size=3),
            torch.nn.BatchNorm2d(64),
            torch.nn.ReLU())
        self.layer4 = torch.nn.Sequential(
            torch.nn.Conv2d(64, 64, kernel_size=3),
            torch.nn.BatchNorm2d(64),
            torch.nn.ReLU())
        self.layer5 = torch.nn.Sequential(
            torch.nn.Conv2d(64, 64, kernel_size=3, padding=1),
            torch.nn.BatchNorm2d(64),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2))
        # For 28x28 inputs, the feature map is 64 x 4 x 4 at this point.
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(64 * 4 * 4, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, num_classes))
    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
def train_model(model, optimizer, criterion):
    # Classic training with torch: forward pass, backward pass and optimiser step.
    for epoch in range(EPOCHS):
        model.train()
        for inputs, targets in tqdm.tqdm(train_loader):
            optimizer.zero_grad()
            outputs = model(inputs)
            if TASK == 'multi-label, binary-class':
                targets = targets.to(torch.float32)
            else:
                targets = targets.squeeze().long()
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
    return model
def evaluate_model(split: str) -> None:
    model.eval()
    y_true = torch.tensor([])
    y_score = torch.tensor([])
    data_loader = train_loader_at_eval if split == 'train' else test_loader
    with torch.no_grad():
        for inputs, targets in data_loader:
            outputs = model(inputs)
            if TASK == 'multi-label, binary-class':
                targets = targets.to(torch.float32)
                outputs = outputs.softmax(dim=-1)
            else:
                targets = targets.squeeze().long()
                outputs = outputs.softmax(dim=-1)
                targets = targets.float().resize_(len(targets), 1)
            y_true = torch.cat((y_true, targets), 0)
            y_score = torch.cat((y_score, outputs), 0)
    y_true = y_true.numpy()
    y_score = y_score.detach().numpy()
    evaluator = medmnist.Evaluator(DATA_FLAG, split)
    metrics = evaluator.evaluate(y_score)
    print('%s Model performance AUC: %.3f ACC: %.3f' % (split, *metrics))
# Hyperparams for dataset.
DATA_FLAG = 'dermamnist'
EPOCHS = 10
BATCH_SIZE = 64
TASK = medmnist.INFO[DATA_FLAG]['task']
data_class = getattr(medmnist, medmnist.INFO[DATA_FLAG]['python_class'])
# Preprocessing.
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[.5], std=[.5])
])
# Load the data info.
train_dataset = data_class(split='train', transform=transform, download=True)
test_dataset = data_class(split='test', transform=transform, download=True)
pil_dataset = data_class(split='train', download=True)
# Load into dataloader.
train_loader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
train_loader_at_eval = data.DataLoader(dataset=train_dataset, batch_size=2*BATCH_SIZE, shuffle=False)
test_loader = data.DataLoader(dataset=test_dataset, batch_size=2*BATCH_SIZE, shuffle=False)
# Hyperparams for model.
LR = 0.001
N_CHANNELS = medmnist.INFO[DATA_FLAG]['n_channels']
N_CLASSES = len(medmnist.INFO[DATA_FLAG]['label'])
# Load model and optimiser.
network = BlackBoxModel(in_channels=N_CHANNELS, num_classes=N_CLASSES)
# Define optimizer and loss function.
optimizer = torch.optim.SGD(network.parameters(), lr=LR, momentum=0.9)
if TASK == "multi-label, binary-class":
    criterion = torch.nn.BCEWithLogitsLoss()
else:
    criterion = torch.nn.CrossEntropyLoss()
# Train the model.
model = train_model(network, optimizer, criterion)
# Evaluate the model.
evaluate_model('train')
evaluate_model('test')
Using downloaded and verified file: /root/.medmnist/dermamnist.npz Using downloaded and verified file: /root/.medmnist/dermamnist.npz Using downloaded and verified file: /root/.medmnist/dermamnist.npz
100%|██████████| 110/110 [00:04<00:00, 22.75it/s] 100%|██████████| 110/110 [00:04<00:00, 23.07it/s] 100%|██████████| 110/110 [00:05<00:00, 21.04it/s] 100%|██████████| 110/110 [00:04<00:00, 23.28it/s] 100%|██████████| 110/110 [00:05<00:00, 21.10it/s] 100%|██████████| 110/110 [00:04<00:00, 23.17it/s] 100%|██████████| 110/110 [00:04<00:00, 22.44it/s] 100%|██████████| 110/110 [00:05<00:00, 21.68it/s] 100%|██████████| 110/110 [00:04<00:00, 23.06it/s] 100%|██████████| 110/110 [00:05<00:00, 21.25it/s]
train Model performance AUC: 0.911 ACC: 0.740 test Model performance AUC: 0.892 ACC: 0.721
# Prepare a test batch.
nr_samples = 100
x_batch = []
y_batch = []
for i in range(nr_samples):
    x, y = test_dataset[i]
    x_batch.append(np.array(x))
    y_batch.append(np.array(y))
x_batch = np.array(x_batch)
y_batch = np.array(y_batch).reshape(-1)
print(x_batch.shape, y_batch.shape)
(100, 3, 28, 28) (100,)
# Let's list the default parameters of the metric.
quantus.FaithfulnessCorrelation().get_params
{'abs': False, 'normalise': True, 'return_aggregate': True, 'aggregate_func': <function mean at 0x7870f2d29e30>, 'normalise_func': <function quantus.functions.normalise_func.normalise_by_max(a: numpy.ndarray, normalise_axes: Optional[Sequence[int]] = None) -> numpy.ndarray>, 'normalise_func_kwargs': {}, '_disable_warnings': False, '_display_progressbar': False, 'a_axes': None, 'similarity_func': <function quantus.functions.similarity_func.correlation_pearson(a: <built-in function array>, b: <built-in function array>, **kwargs) -> float>, 'nr_runs': 100, 'subset_size': 224, 'perturb_func': functools.partial(<function baseline_replacement_by_indices at 0x786f805155a0>, perturb_baseline='black')}
# Define some parameter settings to evaluate.
baseline_strategies = ["mean", "uniform", "black", "white"]
subset_sizes = np.array([1, 10, 25])
sim_funcs = {"pearson": quantus.correlation_pearson, "spearman": quantus.correlation_spearman}
result = {
"Faithfulness score": [],
"Method": [],
"Similarity function": [],
"Baseline strategy": [],
"Subset size": [],
}
xai_methods_with_kwargs = {"Saliency": {}, "IntegratedGradients": {"n_steps": 10}, "GradientShap": {}}
# Move the model to the available device once.
model = model.to(device)
# Score explanations!
for b in baseline_strategies:
    for s in subset_sizes:
        for method, kwargs in xai_methods_with_kwargs.items():
            for sim, sim_func in sim_funcs.items():
                metric = quantus.FaithfulnessCorrelation(abs=False,
                                                         normalise=True,
                                                         return_aggregate=True,
                                                         disable_warnings=True,
                                                         aggregate_func=np.mean,
                                                         normalise_func=quantus.normalise_by_negative,
                                                         nr_runs=10,
                                                         perturb_baseline=b,
                                                         perturb_func=quantus.baseline_replacement_by_indices,
                                                         similarity_func=sim_func,
                                                         subset_size=s)
                # Pass the method name to quantus.explain; otherwise it falls back to a vanilla Gradient explanation.
                score = metric(model=model,
                               x_batch=x_batch,
                               y_batch=y_batch,
                               a_batch=None,
                               explain_func=quantus.explain,
                               explain_func_kwargs={"method": method, **kwargs},
                               device=device)
                result["Method"].append(method)
                result["Baseline strategy"].append(b.capitalize())
                result["Subset size"].append(s)
                result["Faithfulness score"].append(score[0])
                result["Similarity function"].append(sim)
df = pd.DataFrame(result)
df.head()
| | Faithfulness score | Method | Similarity function | Baseline strategy | Subset size |
|---|---|---|---|---|---|
| 0 | -0.035094 | Saliency | pearson | Mean | 1 |
| 1 | 0.003210 | Saliency | spearman | Mean | 1 |
| 2 | -0.098253 | IntegratedGradients | pearson | Mean | 1 |
| 3 | -0.022625 | IntegratedGradients | spearman | Mean | 1 |
| 4 | -0.142600 | GradientShap | pearson | Mean | 1 |
# Group by the ranking.
df["Rank"] = df.groupby(['Baseline strategy', 'Subset size', 'Similarity function'])["Faithfulness score"].rank()
# Smaller adjustments.
df = df.loc[:, ~df.columns.str.contains('^Unnamed')]
df.columns = map(lambda x: str(x).capitalize(), df.columns)
df.head(10)
| | Faithfulness score | Method | Similarity function | Baseline strategy | Subset size | Rank |
|---|---|---|---|---|---|---|
| 0 | -0.035094 | Saliency | pearson | Mean | 1 | 3.0 |
| 1 | 0.003210 | Saliency | spearman | Mean | 1 | 3.0 |
| 2 | -0.098253 | IntegratedGradients | pearson | Mean | 1 | 2.0 |
| 3 | -0.022625 | IntegratedGradients | spearman | Mean | 1 | 1.0 |
| 4 | -0.142600 | GradientShap | pearson | Mean | 1 | 1.0 |
| 5 | -0.002215 | GradientShap | spearman | Mean | 1 | 2.0 |
| 6 | -0.115930 | Saliency | pearson | Mean | 10 | 1.0 |
| 7 | -0.121939 | Saliency | spearman | Mean | 10 | 1.0 |
| 8 | -0.103789 | IntegratedGradients | pearson | Mean | 10 | 2.0 |
| 9 | -0.063273 | IntegratedGradients | spearman | Mean | 10 | 3.0 |
How does this relate back to our (naive) intuition that XAI methods should rank consistently?
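Under that intuition, every method lands on the same rank in 100% of the settings, which is what the hypothetical Method A, B and C rows in the code below encode. The tally itself, done there with groupby and value_counts(normalize=True), can be sketched in plain Python; the helper name and the rankings below are hypothetical and are not results from this notebook.

```python
from collections import Counter


def rank_frequencies(rankings):
    """Percentage of settings in which each method obtains each rank.

    rankings: list with one dict per experimental setting, mapping method -> rank.
    """
    counts = {}
    for setting in rankings:
        for method, rank in setting.items():
            counts.setdefault(method, Counter())[rank] += 1
    n = len(rankings)
    return {m: {r: 100 * c / n for r, c in sorted(cnt.items())} for m, cnt in counts.items()}


# Hypothetical rankings from a metric that always ranks the same way ...
ideal = [{"A": 1, "B": 2, "C": 3}] * 4
# ... and from one whose ranking changes with its parameters.
observed = [
    {"X": 1, "Y": 2, "Z": 3},
    {"X": 3, "Y": 1, "Z": 2},
    {"X": 1, "Y": 3, "Z": 2},
    {"X": 2, "Y": 1, "Z": 3},
]
print(rank_frequencies(ideal))
print(rank_frequencies(observed))
```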
# Group by method and compute how often (in %) each method obtains each rank.
df_view = df.groupby(["Method"])["Rank"].value_counts(normalize=True).mul(100).reset_index(name='Percentage').round(2)
# Hypothetical 'Method A', 'Method B' and 'Method C' represent the intuition:
# each method obtains the same rank in 100% of the settings.
df_view_ordered = pd.DataFrame({
    'Method': ['Method A', 'Method B', 'Method C'],
    'Rank': [1.0, 2.0, 3.0],
    'Percentage': [100, 100, 100]
})
# Append the real methods, in the order of xai_methods_with_kwargs.
for method in xai_methods_with_kwargs:
    df_view_ordered = pd.concat([df_view_ordered, df_view.loc[df_view["Method"] == method]], ignore_index=True)
df_view_ordered
| | Method | Rank | Percentage |
|---|---|---|---|
| 0 | Method A | 1.0 | 100.00 |
| 1 | Method B | 2.0 | 100.00 |
| 2 | Method C | 3.0 | 100.00 |
| 3 | Saliency | 1.0 | 41.67 |
| 4 | Saliency | 3.0 | 37.50 |
| 5 | Saliency | 2.0 | 20.83 |
| 6 | IntegratedGradients | 2.0 | 41.67 |
| 7 | IntegratedGradients | 3.0 | 33.33 |
| 8 | IntegratedGradients | 1.0 | 25.00 |
| 9 | GradientShap | 2.0 | 37.50 |
| 10 | GradientShap | 1.0 | 33.33 |
| 11 | GradientShap | 3.0 | 29.17 |
# Ensure default matplotlib settings are reset, affecting label visibility.
plt.rcdefaults()
fig, ax = plt.subplots(figsize=(6.5,5))
# Plot results!
ax = sns.histplot(x='Method', hue='Rank', weights='Percentage', multiple='stack', data=df_view_ordered, shrink=0.6, palette="colorblind", legend=False)
ax.spines["right"].set_visible(False)
ax.spines['top'].set_visible(False)
ax.tick_params(axis='both', which='major', labelsize=16)
ax.set_ylabel('Frequency of rank', fontsize=15)
ax.set_xlabel(' Expectation vs Reality ', fontsize=15)
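# Tick labels follow the row order of df_view_ordered: A, B, C, Saliency, IntegratedGradients, GradientShap.
ax.set_xticklabels(["A", "B", "C", "SAL", "IG", "GS"], fontsize=12)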
ax.yaxis.set_major_formatter(matplotlib.ticker.PercentFormatter())
plt.legend(loc='upper center', bbox_to_anchor=(0.5, 1.1), ncol=3, fancybox=True, shadow=False, labels=['1st', '2nd', '3rd'])
plt.axvline(x=2.5, ymax=0.95, color='black', linestyle='-')
plt.tight_layout()
Contrary to the intuition that rankings stay consistent across metric parameterisations, the rankings differ substantially between experimental settings: for example, Saliency obtains rank 1 in 41.67% of the settings and rank 3 in 37.50%.
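One way to quantify how much two settings disagree is Kendall's tau between their rankings: it is 1 for identical orderings and -1 for reversed ones. The sketch below is a minimal plain-Python version (tau-a, assuming no ties); the function name is hypothetical, and scipy.stats.kendalltau offers a tie-aware alternative. As input it uses the two rankings for the Mean baseline with subset size 1 from the table above, which differ only in the similarity function (Pearson vs Spearman).

```python
from itertools import combinations


def kendall_tau(rank_a, rank_b):
    """Kendall's tau-a between two rankings of the same methods (assumes no ties)."""
    methods = list(rank_a)
    concordant = discordant = 0
    for m1, m2 in combinations(methods, 2):
        # A pair is concordant if both rankings order the two methods the same way.
        s = (rank_a[m1] - rank_a[m2]) * (rank_b[m1] - rank_b[m2])
        if s > 0:
            concordant += 1
        elif s < 0:
            discordant += 1
    n_pairs = len(methods) * (len(methods) - 1) / 2
    return (concordant - discordant) / n_pairs


# Ranks for the Mean baseline and subset size 1, copied from the table above.
pearson = {"Saliency": 3, "IntegratedGradients": 2, "GradientShap": 1}
spearman = {"Saliency": 3, "IntegratedGradients": 1, "GradientShap": 2}
print(kendall_tau(pearson, spearman))
```

Merely switching the similarity function flips the order of IntegratedGradients and GradientShap, which lowers tau from 1 to 1/3.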
import torch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
import quantus
We answer the following research question:
from datasets import load_dataset
BATCH_SIZE = 10
dataset = load_dataset("sst2", split="test")
x_batch = dataset['sentence'][BATCH_SIZE:BATCH_SIZE+BATCH_SIZE]
x_batch
["it 's also heavy-handed and devotes too much time to bigoted views .", 'it helps that lil bow wow ... tones down his pint-sized gangsta act to play someone who resembles a real kid .', 'watching the film is like reading a times portrait of grief that keeps shifting focus to the journalist who wrote it .', "moore 's performance impresses almost as much as her work with haynes in 1995 's safe .", 'reinforces the talents of screenwriter charlie kaufman , creator of adaptation and being john malkovich .', 'now trimmed by about 20 minutes , this lavish three-year-old production has enough grandeur and scale to satisfy as grown-up escapism .', 'a journey through memory , a celebration of living , and a sobering rumination on fatality , classism , and ignorance .', 'a remarkable 179-minute meditation on the nature of revolution .', 'waydowntown is by no means a perfect film , but its boasts a huge charm factor and smacks of originality .', "it 's just incredibly dull ."]
from transformers import AutoModelForSequenceClassification, AutoTokenizer, set_seed
set_seed(42)
# Load the model and tokenizer.
MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME).to(device)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Load an example.
REFERENCE_TEXT = "The quick brown fox jumps over the lazy dog"
tokenizer(REFERENCE_TEXT, return_tensors="pt")
{'input_ids': tensor([[ 101, 1996, 4248, 2829, 4419, 14523, 2058, 1996, 13971, 3899, 102]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
# @title 7.2.1 Supporting code
%%capture
from torch.utils.data import DataLoader
def preprocess_function(dataset):
    # Tokenize the text.
    return tokenizer(
        dataset["sentence"],
        padding=True,
        truncation=True,
        # max_length=100,
        return_tensors="pt",
    )
# Apply the tokenization to the entire dataset and convert format to PyTorch tensors.
processed_dataset = dataset.map(preprocess_function, batched=True)
processed_dataset.set_format(
    type="torch",
    columns=[
        "input_ids",
        "attention_mask",
        "label",
    ],
)
# Save in data loader.
data_loader = DataLoader(processed_dataset, batch_size=BATCH_SIZE)
for b_ix, batch in enumerate(data_loader):
    # x_batch holds samples BATCH_SIZE to 2*BATCH_SIZE - 1, i.e., the second batch (b_ix == 1).
    if b_ix < 1:
        continue
    # Extract input tensors from the current batch.
    inputs = {
        k: v.to(torch.long).to(device)
        for k, v in batch.items()
        if k in ["input_ids", "attention_mask"]
    }
    # Perform model inference.
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=1)
    print(predictions)
    break
y_batch = predictions.cpu().numpy()
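The explanations below use Layer Integrated Gradients from captum. Integrated Gradients integrates the gradient of the model output along a straight path from a baseline to the input and multiplies by the input difference, so that the attributions sum approximately to f(x) - f(baseline) (the completeness property). The sketch below illustrates this on a toy function with an analytic gradient; it is not captum's implementation (captum uses its own quadrature rule), and the helper name is hypothetical.

```python
import numpy as np


def integrated_gradients(grad_fn, x, baseline, n_steps=50):
    """Midpoint Riemann-sum approximation of Integrated Gradients.

    grad_fn: function returning the gradient of the model output w.r.t. its input.
    """
    alphas = (np.arange(n_steps) + 0.5) / n_steps  # midpoints of n_steps intervals on [0, 1]
    total = np.zeros_like(x, dtype=float)
    for alpha in alphas:
        # Gradient at a point on the straight path from the baseline to x.
        total += grad_fn(baseline + alpha * (x - baseline))
    return (x - baseline) * total / n_steps


def f(x):
    """Toy model output: sum of sines."""
    return float(np.sum(np.sin(x)))


def grad_f(x):
    """Analytic gradient of f."""
    return np.cos(x)


x = np.array([0.5, -1.0, 2.0])
baseline = np.zeros_like(x)
attributions = integrated_gradients(grad_f, x, baseline)
print(attributions)
print(attributions.sum(), f(x) - f(baseline))  # completeness: the two values nearly match
```

In explain_with_layer_ig the path is taken in the output space of the embedding layer rather than over the discrete token ids, with pad tokens as the baseline, and the attributions are then summed over the embedding dimension to give one score per token.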
# @title 7.3.1 Supporting code
%%capture
import numpy as np
import torch
from captum.attr import LayerIntegratedGradients
from IPython.display import display, HTML

def construct_input_ref_pair(text, tokenizer, ref_token_id):
    """Tokenize a text and build a same-length baseline made of padding tokens."""
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    ref_input_ids = torch.full_like(input_ids, ref_token_id)
    return input_ids, ref_input_ids, attention_mask

def explain_with_layer_ig(model, inputs, targets, **kwargs):
    """Explain with Layer Integrated Gradients."""
    model.eval()
    model.zero_grad()
    layer = kwargs.get("layer")
    tokenizer = kwargs.get("tokenizer")
    ref_token_id = tokenizer.pad_token_id

    def predict(input_ids, attention_mask=None):
        # Forward function: returns the highest logit per sample. Note that
        # `targets` is not used; the max logit is explained instead.
        outputs = model(input_ids.to(torch.long), attention_mask=attention_mask)
        return outputs.logits.max(1).values

    # Resolve the layer passed as a string, e.g. "model.distilbert.embeddings".
    lig = LayerIntegratedGradients(predict, eval(layer))

    explanations = np.empty(len(inputs), dtype=object)
    for i, text in enumerate(inputs):
        # Construct the input ids and the all-[PAD] baseline ids.
        input_ids, ref_input_ids, attention_mask = construct_input_ref_pair(
            text=text, tokenizer=tokenizer, ref_token_id=ref_token_id
        )
        # Explain with IG.
        explanation = lig.attribute(
            inputs=input_ids.to(device),
            baselines=ref_input_ids.to(device),
            additional_forward_args=(attention_mask.to(device),),
            return_convergence_delta=False,
        )
        # Sum over the embedding dimension to get one score per token.
        if len(explanation.shape) > 2:
            explanation = explanation.sum(dim=2)
        explanations[i] = explanation.squeeze().detach().cpu().numpy()
    return explanations
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

def colorize_words(text, scores, normalise: bool = False):
    """
    Returns an HTML string with tokens colorized based on their explanation scores.
    """
    if normalise:
        # Min-max scale the scores to [0, 1] for the colormap.
        scores = (scores - np.min(scores)) / (np.max(scores) - np.min(scores))
    cmap = plt.get_cmap("Reds")
    # Get tokens including [CLS] and [SEP], so they line up one-to-one
    # with the attribution scores.
    tokens = tokenizer.convert_ids_to_tokens(tokenizer(text, truncation=True)["input_ids"])
    html_string = "<div style='font-family: Arial;'>"
    for token, score in zip(tokens, scores):
        color = cmap(score)
        hex_color = matplotlib.colors.rgb2hex(color[:3])
        html_string += f"<span style='background-color: {hex_color};'>{token} </span>"
    html_string += "</div>"
    return html_string

def get_label(pred):
    """Map the predicted class index to a sentiment label."""
    if pred == 0:
        return "Negative"
    return "Positive"
# Generate explanations.
explanations = explain_with_layer_ig(
    model=model,
    inputs=x_batch,
    targets=y_batch,
    **{
        "attention_mask": None,
        "tokenizer": tokenizer,
        "layer": "model.distilbert.embeddings",
    },
)

# Plot!
for text, label, explanation in zip(x_batch, y_batch, explanations):
    html_string = colorize_words(text, explanation)
    print(f"\nSentiment Prediction = {get_label(label)}")
    display(HTML(html_string))
Sentiment Prediction = Negative
Sentiment Prediction = Negative
Sentiment Prediction = Positive
Sentiment Prediction = Positive
Sentiment Prediction = Positive
Sentiment Prediction = Positive
Sentiment Prediction = Negative
Sentiment Prediction = Positive
Sentiment Prediction = Positive
Sentiment Prediction = Negative
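Integrated Gradients attributes the difference f(x) - f(x') to the input features. For each feature i it integrates the gradient along the straight line from a baseline x' to the input x: IG_i = (x_i - x'_i) * ∫₀¹ ∂f/∂x_i(x' + α(x - x')) dα. LayerIntegratedGradients applies the same idea to the output of a chosen layer (here, the embeddings), with the all-[PAD] input serving as the baseline. The sketch below is a toy, framework-free illustration on a hypothetical two-feature function with a hand-written gradient, not Captum's implementation. It approximates the integral with a midpoint Riemann sum and checks the completeness property: the attributions should sum to f(x) - f(x').

```python
def f(x):
    """Toy model (hypothetical): f(x) = x0^2 + 3 * x0 * x1."""
    return x[0] ** 2 + 3 * x[0] * x[1]


def grad_f(x):
    """Analytic gradient of the toy model."""
    return [2 * x[0] + 3 * x[1], 3 * x[0]]


def integrated_gradients(x, baseline, steps=50):
    """Approximate IG with a midpoint Riemann sum over `steps` points on the path."""
    total = [0.0] * len(x)
    for k in range(steps):
        alpha = (k + 0.5) / steps
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        total = [t + g for t, g in zip(total, grad_f(point))]
    # Average gradient along the path, scaled by the distance from the baseline.
    return [(xi - b) * t / steps for xi, b, t in zip(x, baseline, total)]


x, baseline = [1.0, 2.0], [0.0, 0.0]
attributions = integrated_gradients(x, baseline)
print(attributions)                           # approximately [4.0, 3.0]
print(sum(attributions), f(x) - f(baseline))  # both approximately 7.0
```

In explain_with_layer_ig, each token receives one such attribution per embedding dimension. Summing over that dimension collapses them into a single score per token.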
For example, we evaluate the complexity of the LayerIntegratedGradients explanations computed at the embedding layer model.distilbert.embeddings, using the Sparseness metric.
# Pick a metric.
metric = quantus.Sparseness()
scores = []
for expl, y in zip(explanations, y_batch):
    # Sparseness is computed from the attributions; x_batch is a zero-valued
    # placeholder with the same shape as the explanation.
    score = metric(
        model,
        x_batch=np.expand_dims(np.zeros_like(expl), axis=0),
        y_batch=[y.item()],
        a_batch=np.expand_dims(expl, axis=0),
    )
    scores.append(score[0])
print(f"Complexity score of LayerIntegratedGradient explanation: {np.mean(scores):.2f} ({np.std(scores):.2f})")
Complexity score of LayerIntegratedGradient explanation: 0.47 (0.07)
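To our understanding, Sparseness scores an explanation with the Gini index of its absolute attribution values, following Chalasani et al. (2020). With the values sorted in ascending order as a_(1), ..., a_(d), the index is G = Σₖ (2k - d - 1) a_(k) / (d Σₖ a_(k)). It is 0 when all tokens receive equal attribution and approaches 1 when a single token carries all of it, so a higher value indicates a less complex explanation. The pure-Python sketch below implements only this formula. It approximates what quantus.Sparseness computes and is not the library's source code; the library may, for instance, normalise the attributions or add a small constant to avoid division by zero.

```python
def gini_sparseness(attributions):
    """Gini index of the absolute attributions (sorted ascending)."""
    a = sorted(abs(v) for v in attributions)
    d = len(a)
    total = sum(a)
    if total == 0:
        return 0.0  # Assumption: treat an all-zero explanation as non-sparse.
    weighted = sum((2 * k - d - 1) * v for k, v in enumerate(a, start=1))
    return weighted / (d * total)


print(gini_sparseness([1.0, 1.0, 1.0, 1.0]))  # 0.0 (uniform attribution)
print(gini_sparseness([0.0, 0.0, 0.0, 1.0]))  # 0.75 (all mass on one of four tokens)
print(gini_sparseness([0.1, -0.2, 0.3, 0.9]))
```

For a one-hot explanation over d tokens, the index equals (d - 1) / d. Its maximum therefore grows with sentence length, which is worth keeping in mind when comparing scores across sentences of different lengths.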
[Identifying Reliable Estimators with MetaQuantus](https://arxiv.org/abs/2302.07265) by Hedström et al., 2023