This repo provides a PyTorch re-implementation of the Basenji2 model published in "Cross-species regulatory sequence activity prediction" by David Kelley. The implementation was checked by verifying that the TensorFlow and PyTorch versions produce the same output on random data, up to small numerical deviations that are likely due to differences in the underlying algorithms used by TensorFlow and PyTorch (e.g. different matrix multiplication algorithms).
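Agreement between two implementations is usually judged with a numerical tolerance rather than exact equality, because floating-point results depend on accumulation order. The sketch below is not the script used to check this port: `compare_outputs` is a hypothetical helper, and the tolerances are assumptions chosen for this float32 example. It computes the same matrix product two ways (a float32 matmul and a float64 reference) to show that rounding alone produces small deviations that still pass an `np.allclose` check.

```python
import numpy as np

def compare_outputs(a, b, rtol=1e-4, atol=1e-3):
    """Hypothetical helper: summarize how closely two model outputs agree."""
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    diff = np.abs(a - b)
    return {
        "max_abs_diff": float(diff.max()),
        "allclose": bool(np.allclose(a, b, rtol=rtol, atol=atol)),
    }

rng = np.random.default_rng(0)
x = rng.standard_normal((64, 512)).astype(np.float32)
w = rng.standard_normal((512, 256)).astype(np.float32)

# The same product computed two ways: a float32 matmul and a float64
# reference cast back to float32. They differ only through rounding.
out_fast = x @ w
out_ref = (x.astype(np.float64) @ w.astype(np.float64)).astype(np.float32)

print(compare_outputs(out_fast, out_ref))
```

In practice you would pass the TensorFlow and PyTorch outputs for the same random input to a check like this and report the maximum absolute difference alongside the pass/fail result.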
On Linux with conda/mamba:
- Clone the repository.
- Add it to your `PYTHONPATH` environment variable (e.g. in your `.bashrc` file).
- Use conda/mamba to install the dependencies from the `environment.yml` file found in the repo.
- Download the PyTorch weights.
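After the steps above, a quick way to confirm that the `PYTHONPATH` change took effect is to ask Python whether it can locate the package without importing it. This is a minimal sketch: `is_importable` is a hypothetical helper built on `importlib.util.find_spec`, and it assumes the package is importable under the top-level name `basenji2_pytorch` used in the example below.

```python
import importlib.util
import os

def is_importable(name):
    """Hypothetical helper: True if a top-level module can be found on sys.path."""
    return importlib.util.find_spec(name) is not None

print("PYTHONPATH =", os.environ.get("PYTHONPATH", "<unset>"))
# False here usually means the repo directory is not on PYTHONPATH
# (e.g. .bashrc was edited but the shell was not restarted).
print("basenji2_pytorch importable:", is_importable("basenji2_pytorch"))
```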
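```python
import json
import torch
from basenji2_pytorch import Basenji2, params  # or PLBasenji2 to use training parameters from Kelley et al. 2020

model_weights = 'path/to/basenji2.pth'
# `params` is the path to the JSON file with the model hyperparameters
with open(params) as params_open:
    model_params = json.load(params_open)['model']
# to use a headless model, e.g. for transfer learning
# model_params.pop("head_human", None)
basenji2 = Basenji2(model_params)
# map_location='cpu' lets weights saved on a GPU load on CPU-only machines;
# strict=False ignores checkpoint keys the model lacks (e.g. a removed head)
basenji2.load_state_dict(torch.load(model_weights, map_location='cpu'), strict=False)
basenji2.eval()  # put dropout/batch norm layers into inference mode
```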
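Loading with `strict=False` is what allows the full checkpoint to be loaded into a headless model: the head's weights become "unexpected keys" and are skipped instead of raising an error. The catch is that `strict=False` also silently tolerates missing keys, so it is worth inspecting the value `load_state_dict` returns. The sketch below uses a hypothetical `ToyModel` (a trunk plus an optional head) as a stand-in for Basenji2; it is not the real architecture.

```python
import torch
from torch import nn

class ToyModel(nn.Module):
    """Hypothetical stand-in for Basenji2: a trunk plus an optional head."""

    def __init__(self, with_head=True):
        super().__init__()
        self.trunk = nn.Linear(8, 4)
        self.head = nn.Linear(4, 2) if with_head else None

    def forward(self, x):
        x = self.trunk(x)
        return self.head(x) if self.head is not None else x

# Checkpoint from a model that has a head
state = ToyModel(with_head=True).state_dict()

# Loading it into a headless model: strict=True would raise, strict=False skips the head
headless = ToyModel(with_head=False)
result = headless.load_state_dict(state, strict=False)
print("unexpected:", result.unexpected_keys)  # the head's weight and bias
print("missing:", result.missing_keys)        # should be empty: the trunk loaded fully
```

For the real model, an empty `missing_keys` list is a useful sanity check that every trunk parameter was actually restored.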
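To run predictions, DNA sequences need to be one-hot encoded. The sketch below makes several assumptions not stated in this README: the channel order A, C, G, T (the order used by the original Basenji code), unknown bases such as N encoded as all-zero columns, and a channels-first `(4, length)` layout as expected by PyTorch `Conv1d` layers. Check the model's first layer before relying on the layout. `one_hot_dna` is a hypothetical helper. Note that Kelley et al. 2020 trained Basenji2 on 131,072 bp input sequences, so real inputs should have that length.

```python
import numpy as np

# Assumed channel order (A, C, G, T), as in the original Basenji code.
BASE_TO_INDEX = {"A": 0, "C": 1, "G": 2, "T": 3}

def one_hot_dna(seq):
    """Hypothetical helper: one-hot encode DNA as a (4, len(seq)) float32 array.

    Unknown bases such as N become all-zero columns.
    """
    x = np.zeros((4, len(seq)), dtype=np.float32)
    for pos, base in enumerate(seq.upper()):
        idx = BASE_TO_INDEX.get(base)
        if idx is not None:
            x[idx, pos] = 1.0
    return x

x = one_hot_dna("ACGTN")
print(x.shape)  # (4, 5)

# Assumed usage with a 131,072 bp sequence `seq` and the model loaded above:
# with torch.no_grad():
#     preds = basenji2(torch.from_numpy(one_hot_dna(seq))[None])  # add batch dim
```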