Genomics OOD with TensorFlow

"With deep learning, we're not just reading the genome, we're understanding it." - Gemini 2024

Genomics Out-of-Distribution Detection Project

Dataset Overview

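This project uses the Genomics OOD dataset (genomics_ood) from TensorFlow Datasets, a benchmark of bacterial DNA sequences for detecting out-of-distribution (OOD) samples in genomic data. The in-distribution data covers 10 bacterial classes, while the OOD validation and test splits come from bacterial classes that never appear in training. The dataset includes: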

  • Training sequences from in-distribution genomic data
  • Validation sequences (both in-distribution and OOD)
  • Test sequences (both in-distribution and OOD)
  • DNA sequences of 250 nucleotides, stored as strings over the alphabet A, C, G, T
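
A model cannot consume nucleotide strings directly, so each base is first mapped to an integer id. The pure-Python sketch below illustrates that step. The id assignment (0 reserved for padding, 1-5 for A, T, G, C, N, and any unexpected character treated as N) is an illustrative assumption, not something the dataset prescribes.

```python
# Hypothetical vocabulary: id 0 is reserved for padding, so bases start at 1.
VOCAB = {"A": 1, "T": 2, "G": 3, "C": 4, "N": 5}
UNKNOWN_ID = VOCAB["N"]  # treat unexpected characters as an ambiguous base


def encode_sequence(seq: str) -> list[int]:
    """Map a DNA string to a list of integer token ids (case-insensitive)."""
    return [VOCAB.get(base, UNKNOWN_ID) for base in seq.upper()]


if __name__ == "__main__":
    print(encode_sequence("ATGCN"))  # [1, 2, 3, 4, 5]
    print(encode_sequence("acgx"))   # [1, 4, 3, 5]
```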

Analysis Details

The project implements:

  • Sequence preprocessing with integer nucleotide encoding
  • A deep learning model built on a bidirectional LSTM
  • A custom callback that monitors the OOD validation split after each epoch
  • Separate evaluation on the in-distribution and OOD test sets
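
Evaluating the in-distribution and OOD test sets separately is usually summarized by a single detection metric. A common choice is AUROC: the model gives every sequence an OOD score, and AUROC measures how well those scores rank OOD sequences above in-distribution ones. The pure-Python sketch below computes it straight from the pairwise definition; the function name ood_auroc and the convention that a higher score means "more likely OOD" are assumptions of this example.

```python
def ood_auroc(id_scores, ood_scores):
    """AUROC with OOD samples as the positive class.

    Computed from the definition: the probability that a randomly chosen
    OOD sample scores higher than a randomly chosen in-distribution sample,
    counting ties as one half. O(n * m), so only suitable for small inputs.
    """
    if not id_scores or not ood_scores:
        raise ValueError("both score lists must be non-empty")
    wins = 0.0
    for o in ood_scores:
        for i in id_scores:
            if o > i:
                wins += 1.0
            elif o == i:
                wins += 0.5
    return wins / (len(id_scores) * len(ood_scores))


if __name__ == "__main__":
    # 5 of the 6 (OOD, in-distribution) pairs are ranked correctly
    print(ood_auroc([0.1, 0.2, 0.3], [0.25, 0.9]))
```

An AUROC of 1.0 means every OOD sequence outscores every in-distribution sequence; 0.5 is chance level.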

Step-by-Step Replication Guide

  1. Install required dependencies: pip install tensorflow tensorflow-datasets
  2. Load and preprocess the data:
    • Load the genomics_ood dataset with TensorFlow Datasets
    • Convert nucleotides to integer encodings
    • Pad/truncate sequences to uniform length
  3. Create and compile the model:
    • Embedding layer for nucleotide representation
    • Bidirectional LSTM for sequence processing
    • Dense layers for classification
  4. Train the model:
    • 10 epochs with Adam optimizer
    • Monitor in-distribution validation accuracy, plus an OOD metric on the OOD validation split via the custom callback
  5. Evaluate performance on test sets
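
The pad/truncate rule from step 2 can be sketched with NumPy: each encoded sequence is cut to a fixed length or right-padded with a reserved id, so that all sequences stack into one batch array. The padding id 0 and the helper names are illustrative assumptions; the default length of 250 matches the seq_len default in the script.

```python
import numpy as np


def pad_or_truncate(ids, seq_len, pad_id=0):
    """Cut a token-id list to seq_len, or right-pad it with pad_id."""
    ids = list(ids[:seq_len])
    return ids + [pad_id] * (seq_len - len(ids))


def make_batch(encoded_seqs, seq_len=250, pad_id=0):
    """Stack id lists into an int32 array of shape (batch, seq_len)."""
    return np.array(
        [pad_or_truncate(s, seq_len, pad_id) for s in encoded_seqs],
        dtype=np.int32,
    )


if __name__ == "__main__":
    batch = make_batch([[1, 2, 3], [4, 4, 4, 4, 4, 4]], seq_len=4)
    print(batch)
    # [[1 2 3 0]
    #  [4 4 4 4]]
```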
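import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

# genomics_ood: 10 in-distribution bacterial classes, 250-base sequences.
# Assumes in-distribution labels are 0-9; OOD splits use other label ids,
# so check the dataset's label feature if this assumption matters.
NUM_CLASSES = 10
SEQ_LEN = 250

def max_softmax_ood_scores(model, dataset):
    # OOD score = 1 - max softmax probability; higher means "more likely OOD"
    probs = model.predict(dataset.map(lambda seq, label: seq), verbose=0)
    return 1.0 - np.max(probs, axis=1)

def ood_detection_auroc(model, id_data, ood_data):
    # AUROC for separating in-distribution (0) from OOD (1) sequences.
    # Class accuracy is meaningless on OOD splits: their labels belong to
    # classes the model never saw during training.
    id_scores = max_softmax_ood_scores(model, id_data)
    ood_scores = max_softmax_ood_scores(model, ood_data)
    y_true = np.concatenate([np.zeros_like(id_scores), np.ones_like(ood_scores)])
    y_score = np.concatenate([id_scores, ood_scores])
    auc = tf.keras.metrics.AUC(num_thresholds=1000)
    auc.update_state(y_true, y_score)
    return float(auc.result())

class OODValCallback(tf.keras.callbacks.Callback):
    def __init__(self, val_id, val_ood):
        super().__init__()
        self.val_id = val_id
        self.val_ood = val_ood

    def on_epoch_end(self, epoch, logs=None):
        auroc = ood_detection_auroc(self.model, self.val_id, self.val_ood)
        print(f' - val_ood_auroc: {auroc:.4f}')

def load_data():
    dataset = tfds.load('genomics_ood', as_supervised=True)
    train = dataset['train']
    val = dataset['validation']
    val_ood = dataset['validation_ood']
    test = dataset['test']
    test_ood = dataset['test_ood']
    return train, val, val_ood, test, test_ood

def preprocess_data(dataset, seq_len=SEQ_LEN, shuffle=False):
    # Build the lookup table once, outside the map function.
    # Id 0 is reserved for padding; unexpected characters map to N (5).
    table = tf.lookup.StaticHashTable(
        initializer=tf.lookup.KeyValueTensorInitializer(
            keys=tf.constant(['A', 'T', 'G', 'C', 'N']),
            values=tf.constant([1, 2, 3, 4, 5])
        ),
        default_value=tf.constant(5)
    )

    def preprocess(seq, label):
        seq = tf.strings.unicode_split(seq, 'UTF-8')
        seq = table.lookup(seq)
        seq = seq[:seq_len]  # truncate
        seq = tf.pad(seq, [[0, seq_len - tf.shape(seq)[0]]])  # right-pad with 0
        seq.set_shape([seq_len])
        return seq, label

    if shuffle:
        dataset = dataset.shuffle(10_000)
    dataset = dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    return dataset.batch(32).prefetch(tf.data.AUTOTUNE)

def create_model(seq_len=SEQ_LEN, num_classes=NUM_CLASSES):
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(seq_len,), dtype='int32'),
        # 6 token ids: 0 = padding, 1-5 = A, T, G, C, N
        tf.keras.layers.Embedding(input_dim=6, output_dim=128),
        tf.keras.layers.Bidirectional(
            tf.keras.layers.LSTM(128, return_sequences=True)
        ),
        tf.keras.layers.GlobalMaxPooling1D(),
        tf.keras.layers.Dense(64, activation='relu'),
        # One softmax output per in-distribution class
        tf.keras.layers.Dense(num_classes, activation='softmax')
    ])
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )
    return model

def main():
    print('Loading data...')
    train, val, val_ood, test, test_ood = load_data()

    # Uncomment for a quick smoke test on a small subset
    # train = train.take(100)
    # val = val.take(50)
    # val_ood = val_ood.take(50)
    # test = test.take(50)
    # test_ood = test_ood.take(50)

    print('Preprocessing data...')
    train = preprocess_data(train, shuffle=True)
    val = preprocess_data(val)
    val_ood = preprocess_data(val_ood)
    test = preprocess_data(test)
    test_ood = preprocess_data(test_ood)

    ood_val_callback = OODValCallback(val, val_ood)

    print('Creating model...')
    model = create_model()
    model.summary()

    print('Training model...')
    model.fit(
        train,
        validation_data=val,
        epochs=10,
        callbacks=[ood_val_callback]
    )

    print('Evaluating model on in-distribution test set...')
    test_id_loss, test_id_acc = model.evaluate(test)
    print(f'In-distribution Test Accuracy: {test_id_acc * 100:.2f}%')

    print('Scoring out-of-distribution test set...')
    test_auroc = ood_detection_auroc(model, test, test_ood)
    print(f'OOD detection AUROC (test vs. test_ood): {test_auroc:.4f}')

if __name__ == '__main__':
    main()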
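
A classifier trained only on in-distribution classes can be turned into an OOD detector by scoring how unconfident it is on each sequence. The NumPy sketch below implements the maximum softmax probability (MSP) baseline: the OOD score is 1 minus the largest class probability, so flat, uncertain predictions score high. It takes raw logits as input; the function names are hypothetical, and MSP is only one simple baseline among many OOD scoring methods.

```python
import numpy as np


def softmax(logits):
    """Numerically stable softmax over the last axis."""
    z = logits - np.max(logits, axis=-1, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=-1, keepdims=True)


def msp_ood_score(logits):
    """OOD score = 1 - max softmax probability (higher = more likely OOD)."""
    probs = softmax(np.asarray(logits, dtype=np.float64))
    return 1.0 - np.max(probs, axis=-1)


if __name__ == "__main__":
    logits = np.array([
        [8.0, 0.0, 0.0],  # confident prediction -> low OOD score
        [1.0, 1.0, 1.0],  # uniform prediction -> score of 2/3
    ])
    print(msp_ood_score(logits))
```

If the model's last layer already applies softmax, skip the softmax call and take 1 minus the maximum predicted probability directly.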


Improvement Tips

  • Add data augmentation specific to genomic sequences, such as reverse-complement copies of training sequences
  • Experiment with different model architectures (e.g., Transformers, CNN-LSTM hybrids)
  • Add regularization (e.g., dropout or L2 weight decay) to reduce overfitting
  • Implement k-fold cross-validation for more robust evaluation
  • Try different sequence lengths and batch sizes
  • Add attention mechanisms to improve sequence feature extraction
  • Implement ensemble methods using multiple model architectures
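
For the augmentation tip, one widely used genomics-specific transform is the reverse complement: DNA is double-stranded, so the same region read from the opposite strand appears reversed, with A paired to T and C paired to G. The pure-Python sketch below assumes sequences are plain strings over A, C, G, T, N (N maps to itself). Whether a reverse-complement copy should keep the original class label is a modeling assumption worth checking for this dataset.

```python
# Complement of each base; N (ambiguous) stays N. Lowercase is handled too.
_COMPLEMENT = str.maketrans("ACGTNacgtn", "TGCANtgcan")


def reverse_complement(seq: str) -> str:
    """Return the reverse complement of a DNA string."""
    return seq.translate(_COMPLEMENT)[::-1]


def augment_with_reverse_complement(seqs, labels):
    """Return the original examples plus a reverse-complement copy of each."""
    seqs = list(seqs)
    labels = list(labels)
    aug_seqs = seqs + [reverse_complement(s) for s in seqs]
    aug_labels = labels + labels
    return aug_seqs, aug_labels


if __name__ == "__main__":
    print(reverse_complement("AACGTN"))  # NACGTT
```

For a training set as large as genomics_ood's, applying the transform randomly on the fly may be preferable to materializing every copy.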
→ This page was created with help from Gemini, Claude AI, and ChatGPT.