Source code for mdai.utils.keras_utils
from mdai.visualize import load_dicom_image
from keras.utils import Sequence, to_categorical
import numpy as np
from PIL import Image
class DataGenerator(Sequence):
    """Batch generator feeding DICOM images and class labels to Keras.

    Each batch is a tuple ``(X, y)`` where ``X`` has shape
    ``(batch_size, dim[0], dim[1], n_channels)`` and ``y`` is the one-hot
    encoding (``n_classes`` columns) of the class id of each image's first
    annotation.
    """

    def __init__(
        self,
        dataset,
        batch_size=32,
        dim=(32, 32),
        n_channels=1,
        n_classes=10,
        shuffle=True,
        to_RGB=True,
        rescale=False,
    ):
        """Generates data for Keras fit_generator() function.

        Args:
            dataset: object exposing ``image_ids``, ``imgs_anns_dict`` and
                ``classes_dict`` (the attributes read by this generator).
            batch_size: number of samples per batch.
            dim: spatial size of each sample as laid out in the batch array,
                i.e. ``(rows, cols)``.
            n_channels: size of the trailing channel axis of the batch array.
            n_classes: number of columns of the one-hot label matrix.
            shuffle: reshuffle the sample order at the end of every epoch.
            to_RGB: forwarded to ``load_dicom_image``.
            rescale: forwarded to ``load_dicom_image``; the error raised on
                unreadable images suggests enabling it for 12/16 bit data.
        """
        # Newer Keras dataset base classes expect their initializer to run;
        # on older versions this resolves to a harmless no-op.
        super().__init__()

        self.dim = dim
        self.batch_size = batch_size
        self.img_ids = dataset.image_ids
        self.imgs_anns_dict = dataset.imgs_anns_dict
        self.dataset = dataset
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.to_RGB = to_RGB
        self.rescale = rescale
        # Build the initial (optionally shuffled) index order.
        self.on_epoch_end()

    def __len__(self):
        """Denotes the number of batches per epoch.

        Only full batches are counted; trailing samples that do not fill a
        whole batch are left out of the epoch.
        """
        return int(np.floor(len(self.img_ids) / self.batch_size))

    def __getitem__(self, index):
        """Generate one batch of data.

        Args:
            index: position of the batch within the epoch.

        Returns:
            Tuple ``(X, y)`` as produced by ``__data_generation``.
        """
        # Positions (into ``self.img_ids``) belonging to this batch.
        start = index * self.batch_size
        indexes = self.indexes[start : start + self.batch_size]

        img_ids_temp = [self.img_ids[k] for k in indexes]

        X, y = self.__data_generation(img_ids_temp)
        return X, y

    def on_epoch_end(self):
        """Updates indexes after each epoch (reshuffling when enabled)."""
        self.indexes = np.arange(len(self.img_ids))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, img_ids_temp):
        """Generates data containing batch_size samples.

        Args:
            img_ids_temp: image ids of the samples to load.

        Returns:
            Tuple ``(X, y)``: the image batch and the one-hot label matrix.

        Raises:
            ValueError: if PIL cannot build an image from the loaded pixel
                array (typically 12 or 16 bit data loaded without rescaling).
        """
        X = np.empty((self.batch_size, *self.dim, self.n_channels))
        y = np.empty((self.batch_size), dtype=int)

        # PIL sizes are (width, height) while ``dim`` follows the array
        # layout (rows, cols), so the two entries must be swapped.
        target_size = (self.dim[1], self.dim[0])

        for i, ID in enumerate(img_ids_temp):
            image = load_dicom_image(ID, to_RGB=self.to_RGB, rescale=self.rescale)
            try:
                image = Image.fromarray(image)
            except Exception as err:
                # Fail loudly: carrying on with the raw array would call
                # ndarray.resize (in place, returns None) and corrupt the batch.
                raise ValueError(
                    "Pil.Image can't read image. Possible 12 or 16 bit image. "
                    "Try rescale=True to scale to 8 bit."
                ) from err

            pixels = np.asarray(image.resize(target_size))
            if pixels.ndim == 2:
                # Single-channel image: add the channel axis so it fits the
                # (rows, cols, n_channels) slot of the batch array.
                pixels = pixels[..., np.newaxis]
            X[i,] = pixels

            # Only the first annotation of the image determines its label.
            ann = self.imgs_anns_dict[ID][0]
            y[i] = self.dataset.classes_dict[ann["labelId"]]["class_id"]

        return X, to_categorical(y, num_classes=self.n_classes)