Source code for mdai.utils.keras_utils

from mdai.visualize import load_dicom_image
from keras.utils import Sequence, to_categorical

import numpy as np
from PIL import Image


class DataGenerator(Sequence):
    """Keras ``Sequence`` that yields batches of images and one-hot labels.

    Images are loaded with ``load_dicom_image``, converted to PIL images,
    resized to ``dim`` and stacked into a single array per batch. Labels are
    looked up from the first annotation of each image and one-hot encoded.
    """

    def __init__(
        self,
        dataset,
        batch_size=32,
        dim=(32, 32),
        n_channels=1,
        n_classes=10,
        shuffle=True,
        to_RGB=True,
        rescale=False,
    ):
        """Generates data for Keras fit_generator() function.

        Args:
            dataset: Object exposing ``image_ids``, ``imgs_anns_dict`` and
                ``classes_dict``; these are the only attributes read here.
            batch_size: Number of samples per batch.
            dim: Spatial size of each sample as ``(rows, cols)`` of the
                batch array, i.e. ``(height, width)``.
            n_channels: Size of the trailing channel axis of the batch array.
            n_classes: Number of classes used for one-hot encoding.
            shuffle: Whether to reshuffle the sample order on every epoch.
            to_RGB: Forwarded unchanged to ``load_dicom_image``.
            rescale: Forwarded unchanged to ``load_dicom_image``.
        """
        # Newer Keras versions expect Sequence subclasses to initialise the
        # base class; on older versions this is a harmless no-op.
        super().__init__()

        self.dim = dim
        self.batch_size = batch_size
        self.img_ids = dataset.image_ids
        self.imgs_anns_dict = dataset.imgs_anns_dict
        self.dataset = dataset
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.to_RGB = to_RGB
        self.rescale = rescale
        # Build (and optionally shuffle) the index order for the first epoch.
        self.on_epoch_end()

    def __len__(self):
        """Denotes the number of batches per epoch.

        Only full batches are counted; a trailing partial batch is dropped.
        """
        return len(self.img_ids) // self.batch_size

    def __getitem__(self, index):
        """Generate one batch of data.

        Args:
            index: Zero-based batch number within the current epoch.

        Returns:
            Tuple ``(X, y)`` as produced by ``__data_generation``.
        """
        # Positions (into self.img_ids) that make up this batch.
        start = index * self.batch_size
        indexes = self.indexes[start : start + self.batch_size]

        img_ids_temp = [self.img_ids[k] for k in indexes]

        return self.__data_generation(img_ids_temp)

    def on_epoch_end(self):
        """Updates indexes after each epoch, reshuffling if enabled."""
        self.indexes = np.arange(len(self.img_ids))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, img_ids_temp):
        """Generates data containing one batch of samples.

        Args:
            img_ids_temp: Image IDs belonging to the batch.

        Returns:
            Tuple ``(X, y)`` where ``X`` has shape
            ``(len(img_ids_temp), *dim, n_channels)`` and ``y`` is the
            one-hot encoded label array with ``n_classes`` columns.

        Raises:
            ValueError: If an image cannot be converted to a PIL image.
        """
        # Size the buffers from the IDs actually received so that no
        # uninitialised rows are ever returned.
        n_samples = len(img_ids_temp)
        X = np.empty((n_samples, *self.dim, self.n_channels))
        y = np.empty(n_samples, dtype=int)

        # PIL sizes are (width, height) while the batch array is laid out as
        # (rows, cols), so the two entries of dim have to be swapped.
        pil_size = (self.dim[1], self.dim[0])

        for i, ID in enumerate(img_ids_temp):
            image = load_dicom_image(ID, to_RGB=self.to_RGB, rescale=self.rescale)

            try:
                image = Image.fromarray(image)
            except Exception as err:
                # Continuing with the raw array would silently corrupt the
                # batch, so fail here with the actionable hint instead.
                raise ValueError(
                    "Pil.Image can't read image. Possible 12 or 16 bit image. Try rescale=True to "
                    + "scale to 8 bit."
                ) from err

            pixels = np.asarray(image.resize(pil_size))
            if pixels.ndim == 2:
                # Single-channel image: add the channel axis so it fits the
                # (rows, cols, n_channels) slot of the batch array.
                pixels = pixels[..., np.newaxis]
            X[i,] = pixels

            # The label comes from the first annotation of the image.
            ann = self.imgs_anns_dict[ID][0]
            y[i] = self.dataset.classes_dict[ann["labelId"]]["class_id"]

        return X, to_categorical(y, num_classes=self.n_classes)