"With deep learning, we're not just reading the genome, we're understanding it." — Gemini, 2024
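This project uses the Genomics OOD dataset (`genomics_ood`) from TensorFlow Datasets (TFDS), which contains DNA sequences for detecting out-of-distribution (OOD) samples in genomic data. The dataset provides five splits: `train`, `validation`, `validation_ood`, `test`, and `test_ood`.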
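The project implements an integer encoding of nucleotide sequences, a bidirectional-LSTM sequence classifier, a callback that reports metrics on the OOD validation split after every epoch, and separate evaluation on the in-distribution and OOD test sets.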
pip install tensorflow tensorflow-datasets
import tensorflow as tf
import tensorflow_datasets as tfds

# Length (in nucleotides) every sequence is truncated/padded to.
# NOTE(review): TFDS documents genomics_ood sequences as 250 bp — confirm.
SEQ_LEN = 250

# Number of in-distribution classes the classifier predicts.
# NOTE(review): per the TFDS catalog, train/validation/test contain 10
# bacteria classes, while the *_ood splits carry label ids outside that
# range — confirm against the installed dataset version.
NUM_CLASSES = 10

# Token ids: A, T, G, C, N -> 0..4. Characters outside this set map to the
# 'N' (unknown base) id rather than -1, which is not a valid Embedding index.
# Padding gets its own id so it cannot be confused with 'A'.
NUCLEOTIDES = ('A', 'T', 'G', 'C', 'N')
UNKNOWN_ID = NUCLEOTIDES.index('N')
PAD_ID = len(NUCLEOTIDES)
VOCAB_SIZE = len(NUCLEOTIDES) + 1  # nucleotides + padding token


def max_softmax_confidence(model, dataset):
    """Return each example's maximum softmax probability under ``model``.

    The maximum class probability is a standard baseline OOD score:
    in-distribution inputs should get confident predictions and OOD inputs
    less confident ones. Labels yielded by ``dataset`` are ignored.
    """
    probs = model.predict(dataset, verbose=0)
    return tf.reduce_max(probs, axis=-1)


def ood_auroc(id_confidence, ood_confidence):
    """Return the AUROC for separating OOD from in-distribution examples.

    OOD examples are the positive class and ``1 - confidence`` is the score,
    so 0.5 is chance level and 1.0 is perfect separation.
    """
    scores = tf.concat([1.0 - id_confidence, 1.0 - ood_confidence], axis=0)
    # Guard against float rounding pushing scores marginally outside [0, 1],
    # which the AUC metric rejects.
    scores = tf.clip_by_value(scores, 0.0, 1.0)
    labels = tf.concat(
        [tf.zeros_like(id_confidence), tf.ones_like(ood_confidence)], axis=0
    )
    auc = tf.keras.metrics.AUC(curve='ROC', num_thresholds=1000)
    auc.update_state(labels, scores)
    return float(auc.result())


class OODValCallback(tf.keras.callbacks.Callback):
    """Report how the model responds to out-of-distribution data each epoch.

    Args:
        val_ood: Batched OOD validation dataset yielding ``(seq, label)``.
        val_id: Optional batched in-distribution validation dataset. When
            given, the OOD-detection AUROC against it is also reported.
    """

    def __init__(self, val_ood, val_id=None):
        super().__init__()
        self.val_ood = val_ood
        self.val_id = val_id

    def on_epoch_end(self, epoch, logs=None):
        # OOD labels are not among the trained classes, so accuracy/loss on
        # this split is undefined; report confidence-based metrics instead.
        ood_conf = max_softmax_confidence(self.model, self.val_ood)
        mean_conf = float(tf.reduce_mean(ood_conf))
        message = f' - val_ood_mean_confidence: {mean_conf:.4f}'
        if self.val_id is not None:
            id_conf = max_softmax_confidence(self.model, self.val_id)
            message += f', val_ood_auroc: {ood_auroc(id_conf, ood_conf):.4f}'
        print(message)


def load_data():
    """Load the genomics_ood splits as ``(seq, label)`` datasets.

    Returns:
        Tuple ``(train, val, val_ood, test, test_ood)`` of unbatched datasets.
    """
    # shuffle_files only randomizes shard read order; it is harmless for the
    # evaluation splits and helps decorrelate training batches.
    dataset = tfds.load('genomics_ood', as_supervised=True, shuffle_files=True)
    train = dataset['train']
    val = dataset['validation']
    val_ood = dataset['validation_ood']
    test = dataset['test']
    test_ood = dataset['test_ood']
    return train, val, val_ood, test, test_ood


def preprocess_data(dataset, seq_len=SEQ_LEN, batch_size=32, shuffle_buffer=0):
    """Encode, truncate/pad, batch and prefetch a ``(seq, label)`` dataset.

    Args:
        dataset: ``tf.data.Dataset`` of ``(seq, label)`` pairs where ``seq``
            is a scalar string of nucleotide characters.
        seq_len: Length every encoded sequence is truncated/padded to.
        batch_size: Number of examples per batch.
        shuffle_buffer: If > 0, shuffle examples with a buffer of this size
            before batching (intended for the training split only).

    Returns:
        A batched, prefetched dataset of ``(int32[batch, seq_len], label)``.
    """
    # Build the lookup table once, outside the traced map function.
    table = tf.lookup.StaticHashTable(
        initializer=tf.lookup.KeyValueTensorInitializer(
            keys=tf.constant(NUCLEOTIDES),
            values=tf.range(len(NUCLEOTIDES), dtype=tf.int32),
        ),
        default_value=tf.constant(UNKNOWN_ID, dtype=tf.int32),
    )

    def preprocess(seq, label):
        tokens = table.lookup(tf.strings.unicode_split(seq, 'UTF-8'))
        tokens = tokens[:seq_len]  # truncate
        tokens = tf.pad(  # pad with a dedicated id, not 'A'
            tokens,
            [[0, seq_len - tf.shape(tokens)[0]]],
            constant_values=PAD_ID,
        )
        # Give the tensor a static shape so it matches the model's Input.
        tokens = tf.ensure_shape(tokens, [seq_len])
        return tokens, label

    if shuffle_buffer:
        dataset = dataset.shuffle(shuffle_buffer)
    return (
        dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )


def create_model(input_shape=(SEQ_LEN,), num_classes=NUM_CLASSES):
    """Build and compile the BiLSTM classifier over encoded nucleotide ids.

    Args:
        input_shape: Shape of one encoded sequence, excluding the batch axis.
        num_classes: Number of in-distribution classes to predict.

    Returns:
        A compiled ``tf.keras.Sequential`` model emitting class probabilities.
    """
    model = tf.keras.Sequential([
        # An explicit Input builds the model so summary() works before fit().
        tf.keras.Input(shape=input_shape, dtype='int32'),
        tf.keras.layers.Embedding(input_dim=VOCAB_SIZE, output_dim=128),
        tf.keras.layers.Bidirectional(
            tf.keras.layers.LSTM(128, return_sequences=True)
        ),
        tf.keras.layers.GlobalMaxPooling1D(),
        tf.keras.layers.Dense(64, activation='relu'),
        # Multi-class head: labels are integer class ids, not binary targets.
        tf.keras.layers.Dense(num_classes, activation='softmax'),
    ])
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model


def main():
    """Train the classifier, then report ID test accuracy and OOD AUROC."""
    print('Loading data...')
    train, val, val_ood, test, test_ood = load_data()

    # uncomment for fast code testing
    # train = train.take(100)
    # val = val.take(50)
    # val_ood = val_ood.take(50)
    # test = test.take(50)
    # test_ood = test_ood.take(50)

    print('Preprocessing data...')
    train = preprocess_data(train, shuffle_buffer=10_000)
    val = preprocess_data(val)
    val_ood = preprocess_data(val_ood)
    test = preprocess_data(test)
    test_ood = preprocess_data(test_ood)

    ood_val_callback = OODValCallback(val_ood, val_id=val)

    print('Creating model...')
    model = create_model()
    model.summary()

    print('Training model...')
    model.fit(
        train,
        validation_data=val,
        epochs=10,
        callbacks=[ood_val_callback],
    )

    print('Evaluating model on in-distribution test set...')
    _, test_id_acc = model.evaluate(test)
    print(f'In-distribution Test Accuracy: {test_id_acc * 100:.2f}%')

    print('Evaluating out-of-distribution detection on the OOD test set...')
    # Accuracy is undefined for OOD labels; measure how well low confidence
    # separates OOD sequences from in-distribution ones instead.
    id_conf = max_softmax_confidence(model, test)
    ood_conf = max_softmax_confidence(model, test_ood)
    print(f'In-distribution mean confidence: {float(tf.reduce_mean(id_conf)):.4f}')
    print(f'Out-of-distribution mean confidence: {float(tf.reduce_mean(ood_conf)):.4f}')
    print(f'OOD detection AUROC (max softmax probability): {ood_auroc(id_conf, ood_conf):.4f}')


if __name__ == '__main__':
    main()