vishal bakshi
september 5, 2020
jeremy howard urges students (and practitioners) not to overestimate the amount of data needed to train a high-performing model, given that we can use transfer learning. i want to test that idea by training an image classifier on each of the following numbers of images and comparing the results:
from utils import *
from fastai.vision.all import *
from fastai.vision.widgets import *
first, i'll check that i placed the images in the correct folders:
Image.open('bears2/grizzly/00000000.jpg').to_thumb(128,128)
Image.open('bears2/black/00000000.jpg').to_thumb(128,128)
Image.open('bears2/teddy/00000000.jpg').to_thumb(128,128)
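one thumbnail per folder only spot-checks the labels. since this experiment is all about how many images there are, a quick count per class folder is also useful. the sketch below is plain python and assumes the same bears2/<label>/<file> layout and a short list of image extensions; count_images_per_class is a hypothetical helper, not part of fastai.

```python
from collections import Counter
from pathlib import Path

# extensions treated as images (an assumption; fastai's get_image_files accepts more types)
IMAGE_EXTS = {'.jpg', '.jpeg', '.png'}

def count_images_per_class(root):
    """count image files per class, assuming a root/<label>/<image> layout."""
    counts = Counter()
    for p in Path(root).rglob('*'):
        if p.is_file() and p.suffix.lower() in IMAGE_EXTS:
            counts[p.parent.name] += 1  # the folder name is the label
    return dict(counts)
```

calling count_images_per_class('bears2') returns a dict mapping each folder name (grizzly, black, teddy) to the number of images inside it.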
bears = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
    get_y=parent_label,
    item_tfms=Resize(128))
before i move on, i want to understand the DataBlock a bit more. so far, it's an empty container (no data has been loaded into it yet) with two blocks (ImageBlock for the independent variable x and CategoryBlock for the dependent variable y), a function to get the items to store in the DataBlock (get_image_files), a function to split the items into a training and a validation set (RandomSplitter), a function to get the dependent variable y (parent_label), and a transform to apply to each item (Resize()).
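to make the splitter and the labeling function concrete, here is a plain-python sketch of what they do to a list of file paths. as far as i can tell from the fastai source, RandomSplitter shuffles the indices and puts the first int(valid_pct * n) of them in the validation set, and parent_label returns the name of the folder a file sits in. the function names below are hypothetical, and the sketch uses python's random module instead of pytorch's, so the actual split won't match fastai's.

```python
import random
from pathlib import Path

def parent_label_sketch(path):
    """like parent_label: the label is the name of the parent folder."""
    return Path(path).parent.name

def random_split_sketch(items, valid_pct=0.2, seed=42):
    """mimic the idea of RandomSplitter: shuffle indices, cut off int(valid_pct * n)."""
    rng = random.Random(seed)           # fixed seed, so the split is repeatable
    idxs = list(range(len(items)))
    rng.shuffle(idxs)
    cut = int(valid_pct * len(items))   # size of the validation set
    return idxs[cut:], idxs[:cut]       # (training indices, validation indices)

items = [f'bears2/{label}/{i:08d}.jpg'
         for label in ('grizzly', 'black', 'teddy') for i in range(5)]
train_idx, valid_idx = random_split_sketch(items)
print(len(train_idx), len(valid_idx), parent_label_sketch(items[0]))
```

with 15 images this puts 3 in the validation set. at the small end of this experiment the rounding matters: with 4 images, int(0.2 * 4) = 0, so the validation set would be empty.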
path = Path('bears2')
items = bears.get_items(path)                 # every image file under bears2/
train_idx, valid_idx = bears.splitter(items)  # RandomSplitter returns (training indices, validation indices)
print('# of total images:', len(items))
print('# of training images:', len(train_idx))
print('# of validation images:', len(valid_idx))
# parent_label uses the name of the folder an image sits in as its label
print('image labels:',
      bears.get_y(items[0]), ",",
      bears.get_y(items[10]), ",",
      bears.get_y(items[20]))
# item_tfms[0] is fastai's default ToTensor; item_tfms[1] is the Resize(128) passed to the DataBlock
resized = bears.item_tfms[1](bears.datasets(path)[0])  # (resized image, label) tuple
print('image after item_tfms is applied to it', resized)
plt.imshow(resized[0])
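as far as i can tell from the fastai source, Resize(128) with its default crop method takes the largest square it can from the image (centered on the validation set, at a random position on the training set) and then scales that square to 128x128. the sketch below only does the arithmetic for the centered case; center_crop_box is a hypothetical helper, not a fastai function.

```python
def center_crop_box(width, height, size=128):
    """return the crop box (left, top, right, bottom) of the largest centered
    square, plus the factor that scales that square to size x size."""
    side = min(width, height)        # largest square that fits inside the image
    left = (width - side) // 2       # equal margins left and right
    top = (height - side) // 2       # equal margins top and bottom
    scale = size / side              # factor applied when resizing to size x size
    return (left, top, left + side, top + side), scale
```

for a 400x300 image this gives the box (50, 0, 350, 300) and a scale of 128/300 (about 0.43), so 50 pixels are trimmed from the left and right edges before shrinking.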
dls = bears.dataloaders(path, bs=2)  # bs is fastai's name for the batch size
dls.valid.show_batch(max_n=5,nrows=1)
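with a batch size of 2 and only a few images, each epoch is just a handful of batches. the sketch below counts them, assuming fastai's defaults as i understand them: the training loader shuffles and drops the last incomplete batch, while the validation loader keeps it. batches_per_epoch is a hypothetical helper.

```python
import math

def batches_per_epoch(n_items, bs, drop_last):
    """number of batches a loader yields in one pass over n_items."""
    if drop_last:
        return n_items // bs          # the leftover partial batch is skipped
    return math.ceil(n_items / bs)    # the leftover partial batch is kept
```

for example, 12 training and 3 validation images give 6 training batches and 2 validation batches per epoch, and with 3 training images one image is left out of each epoch. also, since show_batch draws a single batch, max_n=5 can display at most 2 images here.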
learn = cnn_learner(dls, resnet18, metrics=error_rate)  # later fastai releases rename this to vision_learner
learn.fine_tune(4)
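the transfer learning happens in fine_tune. as i read the fastai 2.x source, fine_tune(4) first trains for one epoch with the pretrained resnet18 body frozen (so mostly the new head learns), then halves the base learning rate, unfreezes everything and trains four more epochs with learning rates spread from base_lr/lr_mult for the earliest layers up to base_lr for the head. the sketch below only lays out that schedule with what i believe are the default values; it doesn't train anything, and fine_tune_schedule is a hypothetical name.

```python
def fine_tune_schedule(epochs, base_lr=2e-3, freeze_epochs=1, lr_mult=100):
    """lay out the two phases of a fine_tune call (defaults assumed from fastai 2.x)."""
    # phase 1: pretrained body frozen, the new head trains at base_lr
    phases = [{'phase': 'frozen', 'epochs': freeze_epochs, 'head_lr': base_lr}]
    # phase 2: base_lr is halved, everything is unfrozen, and learning rates
    # range from base_lr / lr_mult (earliest layers) to base_lr (head)
    base_lr /= 2
    phases.append({'phase': 'unfrozen', 'epochs': epochs,
                   'lr_range': (base_lr / lr_mult, base_lr)})
    return phases
```

fine_tune_schedule(4) describes one frozen epoch at 2e-3 followed by four unfrozen epochs with learning rates from roughly 1e-5 to 1e-3, which is why the training output shows a one-row table followed by a four-row table.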
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()
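error_rate is the fraction of validation images the model gets wrong (1 - accuracy), and the confusion matrix counts actual classes (rows) against predicted classes (columns). here is a plain-python sketch of both, assuming the predictions are already class names (fastai's error_rate works on the model's raw outputs and takes the argmax first); both function names are hypothetical.

```python
from collections import Counter

def error_rate_sketch(preds, targets):
    """fraction of predictions that don't match the target (1 - accuracy)."""
    wrong = sum(p != t for p, t in zip(preds, targets))
    return wrong / len(targets)

def confusion_counts(preds, targets, labels):
    """rows = actual class, columns = predicted class, both in `labels` order."""
    counts = Counter(zip(targets, preds))
    return [[counts[(actual, pred)] for pred in labels] for actual in labels]
```

with a tiny validation set the error rate moves in big steps: with 3 validation images it can only be 0, 1/3, 2/3 or 1, so small differences between runs in this experiment may not mean much.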
interp.plot_top_losses(4, nrows=2)
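plot_top_losses shows the validation images with the highest loss, i.e. where the model put the least probability on the correct class, whether it was confidently wrong or right but unsure. the sketch below assumes the loss is cross-entropy (fastai's default for a CategoryBlock target, as far as i know) and that per-class probabilities are already available; top_losses here is a hypothetical helper, not fastai's Interpretation.top_losses.

```python
import math

def cross_entropy(probs, target_idx):
    """loss for one item: negative log of the probability given to the true class."""
    return -math.log(probs[target_idx])

def top_losses(all_probs, targets, k):
    """(index, loss) pairs for the k items with the highest loss, highest first."""
    losses = [cross_entropy(p, t) for p, t in zip(all_probs, targets)]
    order = sorted(range(len(losses)), key=lambda i: losses[i], reverse=True)
    return [(i, losses[i]) for i in order[:k]]
```

an item whose true class got probability 0.3 has a loss of about 1.2, while one whose true class got 0.9 has a loss of about 0.1, so the first shows up before the second.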