import os
import pickle

from bio_embeddings.embed import ProtTransBertBFDEmbedder

# ProtTransBertBFD embedder from bio_embeddings, explicitly run on the CPU.
embedder = ProtTransBertBFDEmbedder(device='cpu')

# Load the pickled dataset; the embedding loop below reads its 'X_cc' entry.
with open('../pickles/flatiron_10_data.pkl', 'rb') as data_file:
    data_dict = pickle.load(data_file)
    
# Resume from the embedding cache written by a previous run, if there is one.
# Only the expected failure modes fall back to an empty cache; anything else
# (including Ctrl-C) propagates instead of silently discarding cached work.
try:
    with open('bfd_dict.pkl', 'rb') as f:
        bfd_dict = pickle.load(f)
except FileNotFoundError:
    # First run: nothing has been cached yet.
    bfd_dict = {}
except (EOFError, pickle.UnpicklingError) as err:
    # Truncated or corrupt cache (e.g. left behind by an interrupted write):
    # it cannot be recovered, so say so and start over.
    print(f'could not read bfd_dict.pkl ({err!r}); starting with an empty cache')
    bfd_dict = {}
        
def _save_cache(cache, path='bfd_dict.pkl'):
    """Pickle *cache* to *path* without ever leaving a half-written file.

    The data is written to a sibling temporary file first and then moved
    over *path* with ``os.replace``, so an interruption during the dump
    leaves the previous cache file intact.
    """
    tmp_path = f'{path}.tmp'
    with open(tmp_path, 'wb') as f:
        pickle.dump(cache, f)
    os.replace(tmp_path, path)


# Embed every sequence that is not in the cache yet, checkpointing as we go.
# NOTE(review): assumes data[0] of each 'X_cc' entry is the (hashable)
# protein sequence accepted by embedder.embed -- confirm against the dataset.
new_embeddings = 0

for data in data_dict['X_cc']:

    seq = data[0]

    if seq in bfd_dict:
        # Already embedded, either in a previous run or earlier in this one.
        continue

    bfd_dict[seq] = embedder.reduce_per_protein(embedder.embed(seq))
    new_embeddings += 1

    # Checkpoint on the number of embeddings actually computed in this run,
    # not on the dataset index: the index may land on cached entries, which
    # would skip checkpoints and make the progress message inaccurate.
    if new_embeddings % 100 == 0:
        print(f'embedded {new_embeddings} proteins')
        _save_cache(bfd_dict)

# Final save so embeddings computed since the last checkpoint are kept.
_save_cache(bfd_dict)
