Source code for deep_lincs.models.base_network

class BaseNetwork:
    """A wrapper to train and evaluate a Keras model on a Dataset.

    Parameters
    ----------
    dataset : ``Dataset``
            An instance of a ``Dataset`` intended to train and evaluate a model.

    target : ``str``
            Valid metadata field or "self". Defines the classification task or
            whether the model is an autoencoder.

    test_sizes : tuple, (optional, default ( ``0.2`` , ``0.2`` ))
            Sizes of the test splits used to divide the dataset into training,
            validation, and testing sets. Unpacked into
            ``dataset.train_val_test_split``.

    Attributes
    ----------
    target : ``str``
            Target task of model.

    train : ``Dataset``
            Dataset used to train the model.

    val : ``Dataset``
            Dataset used during training as validation.

    test : ``Dataset``
            Dataset used to evaluate the model.

    model : ``tensorflow.keras.Model``
            Compiled and trained model. ``None`` until a model has been created.
    """

    def __init__(self, dataset, target, test_sizes=(0.2, 0.2)):
        self.target = target
        self.train, self.val, self.test = dataset.train_val_test_split(*test_sizes)
        self.model = None
        # Flipped to True by prepare_tf_datasets(); fit() refuses to run before that.
        self._dataset_preprocessed = False

    def _require_model(self, action):
        """Returns the model, or fails clearly if none has been created yet

        Parameters
        ----------
        action : ``str``
                Gerund describing what the caller is about to do (e.g.
                ``"training"``); only used to build the error message.

        Returns
        -------
                The object stored in ``self.model``.

        Raises
        ------
        ValueError
                If ``self.model`` is ``None``.
        """
        if self.model is None:
            raise ValueError(
                f"Model has not been created. "
                f"Run the {self.__class__.__name__}.compile_model() method before {action}."
            )
        return self.model

    def prepare_tf_datasets(self, batch_size, batch_normalize=None):
        """Defines how to prepare a prefetch dataset for training and model evaluation

        Each of the train, validation and test splits is called with
        ``(target, batch_size, batch_normalize)`` and the results are stored as
        ``train_dset``, ``val_dset`` and ``test_dset``.

        Parameters
        ----------
        batch_size : ``int``
                Batch size during training and for model evaluation.

        batch_normalize : ``str`` (default: ``None``)
                Normalization applied to each batch during training and evaluation.
                Can be one of ``"z_score"`` or ``"standard_scale"``. Default is ``None``.

        Returns
        -------
                ``None``

        >>> model.prepare_tf_datasets(batch_size=128)
        """
        self._batch_size = batch_size
        self.train_dset, self.val_dset, self.test_dset = [
            split(self.target, batch_size, batch_normalize)
            for split in (self.train, self.val, self.test)
        ]
        self._dataset_preprocessed = True

    def compile_model(self):
        """Placeholder for model construction; does nothing in this base class

        Presumably subclasses override this to build a Keras model and assign
        it to ``self.model`` (the error raised by ``fit`` points users here).

        Returns
        -------
                ``None``
        """

    def fit(self, epochs=5, shuffle=True, **kwargs):
        """Trains model on training dataset

        Parameters
        ----------
        epochs : ``int``
                Number of training epochs

        shuffle : ``bool`` (default: ``True``)
                Whether to shuffle batches during training.

        kwargs : (optional)
                Additional keyword arguments for ``tensorflow.keras.model.fit``.
                This is where ``tensorflow.keras.callbacks`` should be supplied, such
                as Tensorboard or EarlyStopping.

        Returns
        -------
                ``None``

        Raises
        ------
        ValueError
                If ``prepare_tf_datasets`` has not been run, or if no model has
                been created yet.
        """
        if not self._dataset_preprocessed:
            raise ValueError(
                f"Data has not been prepared for training. "
                f"Run {self.__class__.__name__}.prepare_tf_datasets()."
            )
        model = self._require_model("training")

        # A training split smaller than one batch would floor-divide to zero
        # steps; always run at least one step per epoch.
        steps_per_epoch = max(1, len(self.train) // self._batch_size)

        # NOTE(review): no validation_steps is passed; if val_dset repeats
        # indefinitely one is needed -- confirm against Dataset.__call__.
        model.fit(
            self.train_dset,
            epochs=epochs,
            shuffle=shuffle,
            steps_per_epoch=steps_per_epoch,
            validation_data=self.val_dset,
            **kwargs,
        )

    def evaluate(self, inputs=None):
        """Evaluates model

        Parameters
        ----------
        inputs : ``tensorflow.data.dataset``, (optional: default ``None``)
                If no tf.dataset is provided, the model is evaluated on internal
                test dataset.

        Returns
        -------
                ``list`` of evaluation metrics.

        Raises
        ------
        ValueError
                If no model has been created yet.
        """
        model = self._require_model("evaluating")
        data = self.test_dset if inputs is None else inputs
        return model.evaluate(data)

    def predict(self, inputs=None):
        """Feeds inputs forward through the network

        Parameters
        ----------
        inputs : ``tensorflow.data.dataset`` or ``array`` or ``dataframe``, (optional: default ``None``)
                Inputs fed through the network. If not provided, the model uses the
                internal testing data to make a prediction.

        Returns
        -------
                ``array`` of final activations.

        Raises
        ------
        ValueError
                If no model has been created yet.
        """
        model = self._require_model("predicting")
        data = self.test_dset if inputs is None else inputs
        return model.predict(data)

    def save(self, file_name):
        """Saves model to disk by delegating to the Keras model's ``save``

        Parameters
        ----------
        file_name : ``str``
                Name of output file. The on-disk format is decided by Keras
                (originally documented here as hdf5 -- presumably selected by
                the file extension; confirm for the Keras version in use).

        Returns
        -------
                ``None``

        Raises
        ------
        ValueError
                If no model has been created yet.
        """
        self._require_model("saving").save(file_name)

    def summary(self):
        """Prints verbose summary of model

        Returns
        -------
                Whatever the underlying model's ``summary()`` returns
                (``None`` for Keras models).

        Raises
        ------
        ValueError
                If no model has been created yet.
        """
        return self._require_model("summarizing").summary()