Training pipeline
from collections import OrderedDict
import random

import matplotlib.pyplot as plt
from fastai.vision.all import *
from fastai.callback.wandb import WandbCallback

path = Path('images')
files = get_image_files(path)

def img2rgb(img):
    # some images are grayscale or RGBA; force everything to 3-channel RGB
    return img.convert('RGB')

datasets = Datasets(
    files,
    tfms=[[PILImage.create, img2rgb], [parent_label, Categorize]],
    splits=RandomSplitter()(files)
)

dataloaders = datasets.dataloaders(
    after_item=[ToTensor, RandomResizedCrop(224)],
    after_batch=[IntToFloatTensor],
    bs=128,
)
dataloaders.show_batch()
model = nn.Sequential(OrderedDict([
    # xresnet50 backbone, dropping its final pool/flatten/dropout/linear layers
    ('body', create_body(xresnet50, cut=-4)),
    # custom head: 2048 features in, 151 output classes
    ('head', create_head(2048, 151, lin_ftrs=[2048]))
]))
learner = Learner(
    dataloaders,
    model,
    opt_func=Adam,
    loss_func=CrossEntropyLossFlat(),
    metrics=[accuracy, top_k_accuracy],
    # WandbCallback assumes wandb.init() has already been called
    cbs=[SaveModelCallback, WandbCallback]
)
The actual training ran for 400 epochs; only a single cycle is shown here for brevity.
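A minimal sketch of one such cycle, assuming fastai's one-cycle schedule; the learning rate is a placeholder rather than the value used in the real run:

learner.fit_one_cycle(1, lr_max=3e-3)  # placeholder lr_max; the full run trained for 400 epochs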
from experiments.utils import *

class Hook():
    "Store a module's output during the forward pass"
    def __init__(self, m):
        self.hook = m.register_forward_hook(self.hook_func)
    def hook_func(self, m, i, o): self.stored = o.detach().clone()
    def __enter__(self, *args): return self
    def __exit__(self, *args): self.hook.remove()

class HookBwd():
    "Store the gradient flowing back through a module's output during the backward pass"
    def __init__(self, m):
        # register_backward_hook is deprecated in recent PyTorch; register_full_backward_hook
        # accepts the same (module, grad_input, grad_output) signature
        self.hook = m.register_backward_hook(self.hook_func)
    def hook_func(self, m, gi, go): self.stored = go[0].detach().clone()
    def __enter__(self, *args): return self
    def __exit__(self, *args): self.hook.remove()
@patch
def show_activations(self: Learner, img):
    "Overlay a Grad-CAM heatmap for the image's labelled class on top of the image"
    img = Path(img)
    pipeline = Pipeline([PILImage.create, ToTensor, Resize(224, method='squish'), IntToFloatTensor])
    x = pipeline(img).unsqueeze(0).to(self.dls.device)  # match the model's device
    name = img.parent.name.split('-', 1)[1]
    # look up the class index in the vocab; the folder number is 1-based, so it is not the index itself
    y = self.dls.vocab.o2i[img.parent.name]
    with HookBwd(self.model.body) as hookg:
        with Hook(self.model.body) as hook:
            output = self.model.eval()(x)
            act = hook.stored
        output[0, y].backward()
        grad = hookg.stored
    # Grad-CAM: weight each activation map by its average gradient, then sum over channels
    w = grad[0].mean(dim=[1, 2], keepdim=True)
    cam_map = (w * act[0]).sum(0)
    _, ax = plt.subplots()
    x[0].show(ctx=ax)
    ax.title.set_text(name)
    ax.imshow(cam_map.detach().cpu(), alpha=0.80, extent=(0, 224, 224, 0),
              interpolation='sinc', cmap='magma_r')
for _ in range(3):
    learner.show_activations(random.choice(files))
learner.show_activations('images/059-Arcanine/Image_3.jpg')
learner.show_activations('images/030-Nidorina/Image_8.jpg')
learner.show_activations('images/025-Pikachu/Image_11.jpg')
Classification results look good. Many images are classified correctly, including variants like pencil drawings, 2D art, and crochet.
Two labels might be concerning (see the spot-check sketch after this list):
- the 097-Hypno folder actually contains a picture of Jimmy Neutron
- the 147-Dratini folder mixes images of Dratini, Dragonair, and Dragonite
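As a quick spot-check of those two folders, a loop like the one below works; the slice size and the printout format are illustrative rather than part of the original run:

for f in get_image_files(path/'147-Dratini')[:8]:
    pred, _, probs = learner.predict(f)
    print(f.name, '->', pred, f'({probs.max().item():.2f})')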
learner.show_results()
interp = ClassificationInterpretation.from_learner(learner)
interp.plot_top_losses(4, figsize=(9,6))
The wandb dependencies need to be removed from the learner explicitly:

learner.remove_cbs([SaveModelCallback, WandbCallback])  # drop the tracking callbacks added at construction
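If the point of stripping these callbacks is to pickle the learner afterwards, the export call might look like the following; the filename is a placeholder:

learner.export('pokemon-xresnet50.pkl')  # hypothetical filename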