Source code for diffnets.analysis

import mdtraj as md
import numpy as np
import itertools
from . import utils
import multiprocessing as mp
import os
import functools
from torch.autograd import Variable
import torch
from scipy import stats
from scipy.stats import pearsonr
from sklearn.metrics import roc_auc_score, roc_curve
import matplotlib.pyplot as plt
import enspara
import enspara.cluster as cluster
import enspara.info_theory as infotheor
import enspara.msm as msm
import enspara.cluster as cluster
import enspara.info_theory as infotheor
import pickle
import scipy.sparse
import sys
from pylab import *
from torch.autograd import Variable
from collections import defaultdict

class Analysis:
    """Core object for running analysis.

    Parameters
    ----------
    net : nnutils object
        Neural network to perform analysis with
    netdir : str
        path to directory with neural network results
    datadir : str
        path to directory with data required to train the data. Includes
        cm.npy, wm.npy, uwm.npy, master.pdb, an aligned_xtcs dir, and an
        indicators dir.
    """

    def __init__(self, net, netdir, datadir):
        self.net = net
        self.netdir = netdir
        self.datadir = datadir
        # Reference structure; its topology is reused by the other methods.
        self.top = md.load(os.path.join(self.datadir, "master.pdb"))
        # Array subtracted from the flattened coordinates before encoding
        # and added back after decoding.
        self.cm = np.load(os.path.join(self.datadir, "cm.npy"))
        self.n_cores = mp.cpu_count()

    def encode_data(self):
        """Calculate the latent space for all trajectory frames.

        Encodings are written to ``<netdir>/encodings``, one ``.npy`` file
        per ``.xtc`` file found in ``<datadir>/aligned_xtcs``.
        """
        enc_dir = os.path.join(self.netdir, "encodings")
        utils.mkdir(enc_dir)
        xtc_dir = os.path.join(self.datadir, "aligned_xtcs")
        encode_dir(self.net, xtc_dir, enc_dir, self.top, self.n_cores,
                   self.cm)

    def recon_traj(self):
        """Reconstruct all trajectory frames using the trained neural network.

        Reads ``<netdir>/encodings`` and writes ``<netdir>/recon_trajs``.
        """
        recon_dir = os.path.join(self.netdir, "recon_trajs")
        utils.mkdir(recon_dir)
        enc_dir = os.path.join(self.netdir, "encodings")
        recon_traj_dir(self.net, enc_dir, recon_dir, self.top.top,
                       self.cm, self.n_cores)
        print("trajectories reconstructed")

    def get_labels(self):
        """Calculate the classification score for all trajectory frames.

        Reads ``<netdir>/encodings`` and writes ``<netdir>/labels``.
        """
        label_dir = os.path.join(self.netdir, "labels")
        utils.mkdir(label_dir)
        enc_dir = os.path.join(self.netdir, "encodings")
        calc_labels(self.net, enc_dir, label_dir, self.n_cores)
        print("labels calculated for all states")

    def get_rmsd(self):
        """Calculate RMSD between actual trajectory frames and autoencoder
        reconstructed frames.

        The result is saved to ``<netdir>/rmsd.npy``.
        """
        rmsd_fn = os.path.join(self.netdir, "rmsd.npy")
        recon_dir = os.path.join(self.netdir, "recon_trajs")
        orig_xtc_dir = os.path.join(self.datadir, "aligned_xtcs")
        rmsd = rmsd_dists_dir(recon_dir, orig_xtc_dir, self.top,
                              self.n_cores)
        np.save(rmsd_fn, rmsd)

    def morph(self, n_frames=10):
        """Get representative structures for classification scores from
        0 to 1.

        Parameters
        ----------
        n_frames : int
            How many representative structures to output. Bins between
            0 and 1 will be calculated with this number.
        """
        morph_label(self.net, self.netdir, self.datadir, n_frames=n_frames)

    def assign_labels_to_variants(self, plot_labels=False):
        """Map DiffNet labels to each variant with option to plot a
        histogram of the labels.

        Parameters
        ----------
        plot_labels : optional, boolean
            Save a matplotlib figure of the label histogram to
            ``<netdir>/label_plot.png``.

        Returns
        -------
        lab_v : dictionary
            Dictionary mapping each variant (key of ``traj_dict.pkl``) to
            the list of label arrays of its trajectories.
        """
        lab_fns = utils.get_fns(os.path.join(self.netdir, "labels"), "*.npy")
        traj_d_path = os.path.join(self.datadir, "traj_dict.pkl")
        # Context manager so the pickle file handle is always released.
        with open(traj_d_path, 'rb') as f:
            traj_d = pickle.load(f)

        lab_v = defaultdict(list)
        for key, item in traj_d.items():
            # item[0]:item[1] is used as a half-open range of indices into
            # the label files.
            for traj_ind in range(item[0], item[1]):
                lab_v[key].append(np.load(lab_fns[traj_ind]))

        if plot_labels:
            fig = plt.figure(figsize=(16, 16))
            axes = plt.gca()
            lw = 8
            for k in traj_d.keys():
                t = np.concatenate(lab_v[k])
                n, x = np.histogram(t, range=(0, 1), bins=50)
                plt.plot(x[:-1], n, label=k, linewidth=lw)
            plt.xticks(fontsize=36)
            plt.yticks(fontsize=36)
            axes.set_xlabel('DiffNet Label', labelpad=40, fontsize=36)
            axes.set_ylabel('# of Simulation Frames', labelpad=40,
                            fontsize=36)
            axes.tick_params(direction='out', length=20, width=5,
                             grid_color='r', grid_alpha=0.5)
            plt.legend(fontsize=36)
            for axis in ['top', 'bottom', 'left', 'right']:
                axes.spines[axis].set_linewidth(5)
            plt.savefig(os.path.join(self.netdir, "label_plot.png"))
            # Release the figure so repeated calls do not pile up figures.
            plt.close(fig)
        return lab_v

    def find_feats(self, inds, out_fn, n_states=2000, num2plot=100,
                   clusters=None):
        """Generate a .pml file that will show the distances that change
        in a way that is most correlated with changes in the
        classification score.

        Parameters
        ----------
        inds : np.ndarray,
            Indices of the topology file that are to be included in
            calculating what distances are most correlated with
            classification score.
        out_fn : str
            Name of the output file.
        n_states : int (default=2000)
            How many cluster centers to calculate and use for correlation
            measurement.
        num2plot : int (default=100)
            Number of distances to be shown.
        clusters : enspara cluster object
            Cluster object with center_indices attribute. When omitted,
            the encodings are clustered here and the result is pickled to
            ``<netdir>/cluster_<n_states>/clusters.pkl``.
        """
        # Explicit None test: only a missing argument should trigger
        # re-clustering, not an object that merely evaluates as falsy.
        if clusters is None:
            cc_dir = os.path.join(self.netdir, "cluster_%d" % n_states)
            utils.mkdir(cc_dir)
            enc = utils.load_npy_dir(os.path.join(self.netdir, "encodings"),
                                     "*npy")
            if hasattr(self.net, "split_inds"):
                # Keep only the first encoder's latent columns.
                x = self.net.encoder1[-1].out_features
                enc = enc[:, :x]
            clusters = cluster.hybrid.hybrid(enc, euc_dist,
                                             n_clusters=n_states,
                                             n_iters=1)
            cluster_fn = os.path.join(cc_dir, "clusters.pkl")
            with open(cluster_fn, 'wb') as f:
                pickle.dump(clusters, f)
        find_features(self.net, self.datadir, self.netdir,
                      clusters.center_indices, inds, out_fn,
                      num2plot=num2plot)

    def run_core(self):
        """Wrapper to run the analysis functions that should be run after
        training.
        """
        self.encode_data()
        self.recon_traj()
        self.get_labels()
        self.get_rmsd()
def euc_dist(trj, frame):
    """Euclidean distance between every row of ``trj`` and ``frame``.

    Parameters
    ----------
    trj : np.ndarray
        Either a 2-D array with one point per row or a single 1-D point.
    frame : np.ndarray
        Reference point, broadcast against ``trj``.

    Returns
    -------
    d : np.ndarray
        One distance per row of ``trj``; a 1-element array when ``trj``
        is a single point.
    """
    # atleast_2d turns a single point into one row, so the same axis=1
    # reduction covers both cases without a try/except fallback.
    diff = np.atleast_2d(trj - frame)
    return np.sqrt(np.sum(diff * diff, axis=1))


def recon_traj(enc, net, top, cm):
    """Decode latent vectors into an ``md.Trajectory``.

    Parameters
    ----------
    enc : np.ndarray
        Latent vectors, one row per frame.
    net : object
        Network exposing ``decode``.
    top : topology
        Topology passed to ``md.Trajectory``; its ``n_atoms`` sets the
        shape of the reshaped coordinates.
    cm : np.ndarray
        Array added to the decoded, flattened coordinates.
    """
    n = len(enc)
    n_atoms = top.n_atoms
    x = Variable(torch.from_numpy(enc).type(torch.FloatTensor))
    coords = net.decode(x).detach().numpy()
    coords += cm
    coords = coords.reshape((n, n_atoms, 3))
    return md.Trajectory(coords, top)


def _recon_traj_dir(enc_fn, net, recon_dir, top, cm):
    """Reconstruct one encoding file and save it as ``<basename>.xtc``."""
    enc = np.load(enc_fn)
    traj = recon_traj(enc, net, top, cm)
    base_fn = os.path.splitext(os.path.split(enc_fn)[1])[0]
    traj.save(os.path.join(recon_dir, base_fn + ".xtc"))


def recon_traj_dir(net, enc_dir, recon_dir, top, cm, n_cores):
    """Reconstruct every ``*.npy`` encoding in ``enc_dir`` in parallel."""
    enc_fns = utils.get_fns(enc_dir, "*.npy")
    f = functools.partial(_recon_traj_dir, net=net, recon_dir=recon_dir,
                          top=top, cm=cm)
    # The context manager tears the workers down once map() has returned.
    with mp.Pool(processes=n_cores) as pool:
        pool.map(f, enc_fns)


def _calc_labels(enc_fn, net, label_dir):
    """Classify one encoding file and save the labels as ``lab<name>``."""
    enc = np.load(enc_fn)
    if hasattr(net, "split_inds"):
        # Only the first encoder's latent columns are fed to the classifier.
        x = net.encoder1[-1].out_features
        enc = enc[:, :x]
    enc = Variable(torch.from_numpy(enc).type(torch.FloatTensor))
    labels = net.classify(enc).detach().numpy()
    new_fn = os.path.join(label_dir, "lab" + os.path.split(enc_fn)[1])
    np.save(new_fn, labels)


def calc_labels(net, enc_dir, label_dir, n_cores):
    """Compute labels for every encoding file in ``enc_dir`` in parallel."""
    enc_fns = utils.get_fns(enc_dir, "*npy")
    f = functools.partial(_calc_labels, net=net, label_dir=label_dir)
    with mp.Pool(processes=n_cores) as pool:
        pool.map(f, enc_fns)


def get_rmsd_dists(orig_traj, recon_traj, stride=10):
    """RMSD between matching frames of two equally long trajectories.

    Parameters
    ----------
    orig_traj, recon_traj : md.Trajectory
        Original and reconstructed trajectories.
    stride : int (default=10)
        Only every ``stride``-th frame is compared.

    Returns
    -------
    pairwise_rmsd : np.ndarray
        One RMSD per compared frame.

    Raises
    ------
    ValueError
        If the trajectories differ in length. Previously this printed a
        message and returned None, which only failed later when the
        per-file results were concatenated.
    """
    n_frames = len(recon_traj)
    if n_frames != len(orig_traj):
        raise ValueError(
            "Can't get rmsds between trajectories of different lengths")
    return np.array([
        md.rmsd(recon_traj[i], orig_traj[i], parallel=False)[0]
        for i in range(0, n_frames, stride)
    ])


def _rmsd_dists_dir(recon_fn, orig_xtc_dir, ref_pdb):
    """RMSDs between one reconstructed xtc and the same-named original."""
    recon_traj = md.load(recon_fn, top=ref_pdb.top)
    base_fn = os.path.split(recon_fn)[1]
    orig_traj = md.load(os.path.join(orig_xtc_dir, base_fn),
                        top=ref_pdb.top)
    return get_rmsd_dists(orig_traj, recon_traj)


def rmsd_dists_dir(recon_dir, orig_xtc_dir, ref_pdb, n_cores):
    """Concatenated RMSDs for every ``*.xtc`` file in ``recon_dir``."""
    recon_fns = utils.get_fns(recon_dir, "*.xtc")
    f = functools.partial(_rmsd_dists_dir, orig_xtc_dir=orig_xtc_dir,
                          ref_pdb=ref_pdb)
    with mp.Pool(processes=n_cores) as pool:
        res = pool.map(f, recon_fns)
    return np.concatenate(res)


def _encode_dir(xtc_fn, net, outdir, top, cm):
    """Encode one xtc file and save the latent vectors as ``<basename>.npy``."""
    traj = md.load(xtc_fn, top=top)
    n = len(traj)
    n_atoms = traj.top.n_atoms
    x = traj.xyz.reshape((n, 3 * n_atoms)) - cm
    x = Variable(torch.from_numpy(x).type(torch.FloatTensor))
    if hasattr(net, 'split_inds'):
        # Split networks return two latent blocks; store them side by side.
        lat1, lat2 = net.encode(x)
        output = torch.cat((lat1, lat2), 1)
    else:
        output = net.encode(x)
    output = output.detach().numpy()
    base_fn = os.path.splitext(os.path.split(xtc_fn)[1])[0]
    np.save(os.path.join(outdir, base_fn + ".npy"), output)


def encode_dir(net, xtc_dir, outdir, top, n_cores, cm):
    """Encode every ``*.xtc`` file in ``xtc_dir`` in parallel."""
    xtc_fns = utils.get_fns(xtc_dir, "*.xtc")
    f = functools.partial(_encode_dir, net=net, outdir=outdir, top=top,
                          cm=cm)
    with mp.Pool(processes=n_cores) as pool:
        pool.map(f, xtc_fns)


def morph_label(net, nn_dir, data_dir, n_frames=10):
    """Write representative structures spanning the observed label range.

    The label range is divided into ``n_frames`` evenly spaced values; for
    each value the encodings of all frames whose label lies within half a
    spacing are averaged and decoded. The result is saved to
    ``<nn_dir>/morph_label/morph_0-1.pdb`` with per-atom RMSF as b-factors.

    Raises
    ------
    ValueError
        If ``n_frames`` is smaller than 2 (no bin spacing can be defined).
    """
    if n_frames < 2:
        raise ValueError("n_frames must be at least 2")
    ref_s = md.load(os.path.join(data_dir, "master.pdb"))
    cm = np.load(os.path.join(data_dir, "cm.npy"))
    enc = utils.load_npy_dir(os.path.join(nn_dir, "encodings"), "*npy")
    n_latent = int(enc.shape[1])
    morph_dir = os.path.join(nn_dir, "morph_label")
    os.makedirs(morph_dir, exist_ok=True)
    labels = utils.load_npy_dir(os.path.join(nn_dir, "labels"), "*.npy")
    labels = labels.flatten()

    vals = np.linspace(np.min(labels), np.max(labels), n_frames)
    delta = (vals[1] - vals[0]) * 0.5
    morph_enc = np.zeros((n_frames, n_latent))
    for i, val in enumerate(vals):
        in_bin = np.logical_and(labels >= val - delta, labels <= val + delta)
        # Mean latent vector of the bin (NaN if no frame falls in it).
        morph_enc[i] = enc[in_bin].mean(axis=0)

    traj = recon_traj(morph_enc, net, ref_s.top, cm)
    rmsf = get_rmsf(traj)
    out_fn = os.path.join(morph_dir, "morph_0-1.pdb")
    traj.save_pdb(out_fn, bfactors=rmsf)


def get_rmsf(traj):
    """Per-atom root-mean-square fluctuation about the mean structure.

    ``traj`` needs an ``xyz`` array indexed (frame, atom, coordinate) and
    must support ``len``.
    """
    delta = traj.xyz - traj.xyz.mean(axis=0)
    d2 = np.einsum('ijk,ijk->ij', delta, delta)
    # Accumulate in float64, as the original weighted sum did.
    return np.sqrt(d2.mean(axis=0, dtype=np.float64))


def find_features(net, data_dir, nn_dir, clust_cents, inds, out_fn,
                  num2plot=100):
    """Write a PyMOL script showing the distances most correlated with label.

    Structures are reconstructed for the cluster centers ``clust_cents``,
    every ordered pair of ``inds`` is measured, and each distance is
    regressed against the labels of the same frames. The ``num2plot``
    distances with the largest r^2 are written to ``<nn_dir>/<out_fn>``:
    red ``dc*`` objects for negative slopes, blue ``df*`` objects otherwise.
    """
    encs = utils.load_npy_dir(os.path.join(nn_dir, "encodings"), "*.npy")
    encs = encs[clust_cents]
    cm = np.load(os.path.join(data_dir, "cm.npy"))
    top = md.load(os.path.join(data_dir, "master.pdb"))
    traj = recon_traj(encs, net, top.top, cm)
    print("trajectory calculated")

    # NOTE(review): product(..., repeat=2) yields (j, k), (k, j) and (j, j),
    # so each physical distance appears twice among the candidates --
    # confirm whether unordered pairs were intended.
    # Built once as an array instead of on every iteration of the write loop.
    all_pairs = np.array(list(itertools.product(inds, repeat=2)))
    distances = md.compute_distances(traj, all_pairs)
    labels = utils.load_npy_dir(os.path.join(nn_dir, "labels"), "*.npy")
    labels = labels[clust_cents].flatten()

    n = len(inds)
    print(n, " distances being calculated")
    n_pairs = n * n
    slopes = np.empty(n_pairs)
    r_values = np.empty(n_pairs)
    for i in range(n_pairs):
        fit = stats.linregress(labels, distances[:, i])
        slopes[i] = fit[0]
        r_values[i] = fit[2]
    r2_values = r_values ** 2

    print("Starting to write pymol file")
    line_fmt = ("distance %s%s, master and resi %s and name %s, "
                "master and resi %s and name %s\n")
    with open(os.path.join(nn_dir, out_fn), "w") as f:
        best = np.argsort(r2_values)[-num2plot:]
        for count, i in enumerate(best):
            j, k = all_pairs[i]
            atom_j = top.top.atom(j)
            atom_k = top.top.atom(k)
            if slopes[i] < 0:
                prefix, color = "dc", "red"
            else:
                prefix, color = "df", "blue"
            f.write(line_fmt % (prefix, count,
                                atom_j.residue.resSeq, atom_j.name,
                                atom_k.residue.resSeq, atom_k.name))
            f.write("color %s, %s%s\n" % (color, prefix, count))
            f.write("hide label\n")


#########################################################
#                                                       #
#             Extra analysis functions                  #
#                                                       #
#########################################################

def calc_auc(net_fn, out_fn, data, labels):
    """Compute the ROC curve and AUC of a pickled network on ``data``.

    Networks with an ``encode`` attribute are fed inputs viewed as
    ``(-1, 784)`` and must return ``(pred_x, latents, pred_class)``; any
    other network is fed inputs viewed as ``(-1, 3, 32, 32)``.
    ``out_fn`` is accepted for backward compatibility but is not used.

    Returns
    -------
    auc, fpr, tpr
    """
    # NOTE: pickle.load executes arbitrary code; only load trusted files.
    with open(net_fn, 'rb') as f:
        net = pickle.load(f)
    net.cpu()
    full_x = torch.from_numpy(data).type(torch.FloatTensor)
    if hasattr(net, "encode"):
        full_x = Variable(full_x.view(-1, 784).float())
        pred_x, latents, pred_class = net(full_x)
        preds = pred_class.detach().numpy()
    else:
        full_x = Variable(full_x.view(-1, 3, 32, 32).float())
        preds = net(full_x).detach().numpy()
    # Flatten once so both metrics receive the same 1-D score vector.
    preds = preds.flatten()
    fpr, tpr, thresh = roc_curve(labels, preds)
    auc = roc_auc_score(labels, preds)
    print("AUC: %f" % auc)
    return auc, fpr, tpr


def split_vars(d, vars):
    """Split ``d`` into equal consecutive chunks, one per entry of ``vars``.

    The chunk length is ``len(d) // len(vars)``; trailing items that do not
    fill a whole chunk are dropped.

    Returns
    -------
    dict mapping each entry of ``vars`` to its slice of ``d``.
    """
    n_per_var = len(d) // len(vars)
    return {v: d[i * n_per_var:(i + 1) * n_per_var]
            for i, v in enumerate(vars)}


def get_extrema(lst_lst):
    """Global minimum and maximum over a collection of sequences.

    Returns ``(inf, -inf)`` when ``lst_lst`` is empty.
    """
    my_min = np.inf
    my_max = -np.inf
    for lst in lst_lst:
        my_min = np.min((my_min, np.min(lst)))
        my_max = np.max((my_max, np.max(lst)))
    return my_min, my_max


def common_hist(lst_lst, labels, bins):
    """Histogram several data sets on one shared range.

    Returns
    -------
    all_h : dict
        Maps ``labels[i]`` to the counts of ``lst_lst[i]``.
    x : np.ndarray
        The shared bin edges.
    """
    my_min, my_max = get_extrema(lst_lst)
    all_h = {}
    for i, lst in enumerate(lst_lst):
        h, x = np.histogram(lst, bins=bins, range=(my_min, my_max))
        all_h[labels[i]] = h
    return all_h, x


def calc_overlap(d1, d2, bins):
    """Per-feature JS divergence and Shannon entropies of two data sets.

    Each column of ``d1`` and ``d2`` is histogrammed on a shared range and
    normalised before being passed to the ``infotheor`` routines.

    Returns
    -------
    js, ent1, ent2 : np.ndarray
        One value per column.
    """
    n_feat = d1.shape[1]
    js = np.zeros(n_feat)
    ent1 = np.zeros(n_feat)
    ent2 = np.zeros(n_feat)
    for i in range(n_feat):
        h, x = common_hist([d1[:, i], d2[:, i]], ["d1", "d2"], bins)
        p1 = np.array(h["d1"]) / h["d1"].sum()
        p2 = np.array(h["d2"]) / h["d2"].sum()
        js[i] = infotheor.js_divergence(p1, p2)
        ent1[i] = infotheor.shannon_entropy(p1)
        ent2[i] = infotheor.shannon_entropy(p2)
    return js, ent1, ent2


def _norm_hist2d(a, b, hist_range, bins):
    """Normalised 2-D histogram of ``a`` vs ``b`` plus its bin edges."""
    p, x, y = np.histogram2d(a, b, range=hist_range, bins=bins)
    p /= p.sum()
    return p, x, y


def project(enc, lab, vars, i1, i2, bins, my_title, cutoff=0.8):
    """Plot -log(probability) maps of two latent dimensions per variant.

    ``enc`` and ``lab`` are split into equal consecutive chunks, one per
    entry of ``vars``. Every variant gets one panel showing the negative
    log of its normalised 2-D histogram over dimensions ``i1`` and ``i2``,
    restricted to mean +/- 3 std of the pooled data. The mean and std of
    every variant, and of all frames with ``lab > cutoff`` (legend entry
    'act'), are overlaid as error bars.
    """
    all_act_inds = np.where(lab > cutoff)[0]
    act_i1_mu = enc[all_act_inds, i1].mean()
    act_i1_std = enc[all_act_inds, i1].std()
    act_i2_mu = enc[all_act_inds, i2].mean()
    act_i2_std = enc[all_act_inds, i2].std()

    n_vars = len(vars)
    enc_dict = split_vars(enc, vars)
    i1_dict = {v: enc_dict[v][:, i1] for v in vars}
    i2_dict = {v: enc_dict[v][:, i2] for v in vars}

    # drop outliers by only showing data within const*std
    const = 3
    i1_mu = enc[:, i1].mean()
    i1_std = enc[:, i1].std()
    i1_min = np.max((i1_mu - const * i1_std, enc[:, i1].min()))
    i1_max = np.min((i1_mu + const * i1_std, enc[:, i1].max()))
    i2_mu = enc[:, i2].mean()
    i2_std = enc[:, i2].std()
    i2_min = np.max((i2_mu - const * i2_std, enc[:, i2].min()))
    i2_max = np.min((i2_mu + const * i2_std, enc[:, i2].max()))
    hist_range = ([i1_min, i1_max], [i2_min, i2_max])

    # First pass: common colour scale across all panels. The original
    # referenced an undefined name ``n_bins`` here; the ``bins`` argument
    # is what was meant.
    cmin = np.inf
    cmax = -np.inf
    for v in vars:
        p, x, y = _norm_hist2d(i1_dict[v], i2_dict[v], hist_range, bins)
        neg_log = -np.log(p[p > 0])
        cmin = np.min((cmin, neg_log.min()))
        cmax = np.max((cmax, neg_log.max()))

    height = 4
    width = height * n_vars
    fig = plt.figure(figsize=(width, height))
    fig.suptitle(my_title)
    lines = []
    line_labels = []
    for i, v in enumerate(vars):
        fig.add_subplot(1, n_vars, i + 1, aspect='auto',
                        xlim=x[[0, -1]], ylim=y[[0, -1]])
        p, x, y = _norm_hist2d(i1_dict[v], i2_dict[v], hist_range, bins)
        # Empty bins sit at cmax, so they map to 0 after the shift below.
        h = cmax * np.ones(p.shape)
        occupied = np.where(p > 0)
        h[occupied] = -np.log(p[occupied])
        h -= cmax
        delta_x = (x[1] - x[0]) / 2.0
        delta_y = (y[1] - y[0]) / 2.0
        # transpose to put first dimension (i1) on x axis; 'lower' is the
        # value matplotlib accepts for a bottom-left origin.
        plt.imshow(h.T, interpolation='bilinear', aspect='auto',
                   origin='lower',
                   extent=[x[0] + delta_x, x[-1] + delta_x,
                           y[0] + delta_y, y[-1] + delta_y],
                   vmin=cmin - cmax, vmax=0, cmap=plt.get_cmap('Blues_r'))
        plt.colorbar()

        # Legend handles are rebuilt per panel; the last set feeds the
        # figure-level legend.
        lines = []
        line_labels = []
        for v2 in vars:
            line, _, _ = plt.errorbar([i1_dict[v2].mean()],
                                      [i2_dict[v2].mean()],
                                      xerr=[i1_dict[v2].std()],
                                      yerr=[i2_dict[v2].std()],
                                      label=v2)
            lines.append(line)
            line_labels.append(v2)
        line, _, _ = plt.errorbar([act_i1_mu], [act_i2_mu],
                                  xerr=[act_i1_std], yerr=[act_i2_std],
                                  label='act', ecolor='k', fmt='k')
        lines.append(line)
        line_labels.append('act')
        plt.title(v)

    fig.legend(lines, line_labels)
    plt.show()


def morph_conditional(nn_dir, data_dir, n_frames=10):
    """Morph along each latent variable using conditional modes.

    For every latent variable ``i`` and each of ``n_frames`` evenly spaced
    values across its range, every other latent variable is set to the
    centre of its most populated histogram bin among frames whose
    variable ``i`` lies within half a spacing of that value. The decoded
    structures are saved to ``<nn_dir>/morph/m<i>.pdb``.
    """
    # NOTE: pickle.load executes arbitrary code; only load trusted files.
    with open("%s/nn_best_polish.pkl" % nn_dir, 'rb') as f:
        net = pickle.load(f)
    net.cpu()
    ref_s = md.load(os.path.join(nn_dir, "master.pdb"))
    n_atoms = ref_s.top.n_atoms
    uwm = np.load(os.path.join(data_dir, "uwm.npy"))
    cm = np.load(os.path.join(data_dir, "cm.npy"))
    # The bare name load_npy_dir is not defined in this module; the helper
    # lives in utils, as used by the other functions here.
    enc = utils.load_npy_dir(os.path.join(nn_dir, "encodings"), "*npy")
    n_latent = int(enc.shape[1])
    morph_dir = os.path.join(nn_dir, "morph")
    os.makedirs(morph_dir, exist_ok=True)

    for i in range(n_latent):
        my_min, my_max = get_extrema([enc[:, i]])
        print(i, my_min, my_max)
        morph_enc = np.zeros((n_frames, n_latent))
        vals = np.linspace(my_min, my_max, n_frames)
        delta = (vals[1] - vals[0]) * 0.5
        for j, val in enumerate(vals):
            # set each latent variable to most probable value given
            # latent(i) within delta of selected value
            inds = np.where(np.logical_and(enc[:, i] >= val - delta,
                                           enc[:, i] <= val + delta))[0]
            for k in range(n_latent):
                n, x = np.histogram(enc[inds, k], bins=20)
                offset = (x[1] - x[0]) * 0.5
                morph_enc[j, k] = x[n.argmax()] + offset
            # fix ref latent variable to val
            morph_enc[j, i] = val

        morph_x = Variable(torch.from_numpy(morph_enc).type(torch.FloatTensor))
        decoded = net.decode(morph_x)
        # decode may return (outputs, labels) or just outputs. Checking the
        # type avoids the old bare except, which could also mis-unpack a
        # two-row tensor.
        if isinstance(decoded, (tuple, list)):
            outputs = decoded[0]
        else:
            print("single")
            outputs = decoded
        outputs = outputs.data.numpy()
        # NOTE(review): `whiten` is not imported in the visible part of
        # this module -- confirm where apply_unwhitening is provided.
        coords = whiten.apply_unwhitening(outputs, uwm, cm)
        print("shape", coords.shape)
        recon_trj = md.Trajectory(coords.reshape((n_frames, n_atoms, 3)),
                                  ref_s.top)
        recon_trj.save(os.path.join(morph_dir, "m%d.pdb" % i))


def morph_cond_mean(nn_dir, data_dir, n_frames=10):
    """Morph along each latent variable using conditional means.

    Same binning as ``morph_conditional``, but the other latent variables
    are set to their mean within the bin. Structures are saved to
    ``<nn_dir>/morph_bin_mean/m<i>.pdb`` with RMSF as b-factors.
    """
    # NOTE: pickle.load executes arbitrary code; only load trusted files.
    with open("%s/nn_best_polish.pkl" % nn_dir, 'rb') as f:
        net = pickle.load(f)
    net.cpu()
    ref_s = md.load(os.path.join(nn_dir, "master.pdb"))
    cm = np.load(os.path.join(data_dir, "cm.npy"))
    # load_npy_dir is provided by utils; the bare name was undefined.
    enc = utils.load_npy_dir(os.path.join(nn_dir, "encodings"), "*npy")
    n_latent = int(enc.shape[1])
    morph_dir = os.path.join(nn_dir, "morph_bin_mean")
    os.makedirs(morph_dir, exist_ok=True)

    for i in range(n_latent):
        my_min, my_max = get_extrema([enc[:, i]])
        print(i, my_min, my_max)
        morph_enc = np.zeros((n_frames, n_latent))
        vals = np.linspace(my_min, my_max, n_frames)
        delta = (vals[1] - vals[0]) * 0.5
        for j, val in enumerate(vals):
            # set each latent variable to its mean given latent(i) within
            # delta of selected value
            inds = np.where(np.logical_and(enc[:, i] >= val - delta,
                                           enc[:, i] <= val + delta))[0]
            morph_enc[j] = enc[inds].mean(axis=0)
            # fix ref latent variable to val
            morph_enc[j, i] = val

        morph_x = Variable(torch.from_numpy(morph_enc).type(torch.FloatTensor))
        # NOTE(review): utils.recon_traj is not visible from this module;
        # confirm it exists and accepts a tensor (the recon_traj defined in
        # this file expects a numpy array).
        traj = utils.recon_traj(morph_x, net, ref_s.top, cm)
        rmsf = get_rmsf(traj)
        # The original wrote to an undefined ``outdir``; the directory
        # created above is ``morph_dir``.
        out_fn = os.path.join(morph_dir, "m%d.pdb" % i)
        traj.save_pdb(out_fn, bfactors=rmsf)


def morph_std(nn_dir, data_dir, enc, n_frames=10):
    """Morph each latent variable across its range, others held at the mean.

    For each latent variable, ``n_frames`` structures are decoded with that
    variable swept linearly from its minimum to its maximum. Results are
    saved to ``<nn_dir>/morph_std/m<i>.pdb`` with RMSF as b-factors.
    """
    outdir = os.path.join(nn_dir, "morph_std")
    utils.mkdir(outdir)
    # NOTE: pickle.load executes arbitrary code; only load trusted files.
    with open("%s/nn_best_polish.pkl" % nn_dir, 'rb') as f:
        net = pickle.load(f)
    net.cpu()
    ref_s = md.load(os.path.join(nn_dir, "master.pdb"))
    cm = np.load(os.path.join(data_dir, "cm.npy"))
    n_latent = int(enc.shape[1])
    ave_enc = enc.mean(axis=0)
    max_enc = enc.max(axis=0)
    min_enc = enc.min(axis=0)
    for i in range(n_latent):
        morph_enc = np.zeros((n_frames, n_latent)) + ave_enc
        morph_enc[:, i] = np.linspace(min_enc[i], max_enc[i], n_frames)
        # NOTE(review): utils.recon_traj is not visible from this module.
        traj = utils.recon_traj(morph_enc, net, ref_s.top, cm)
        rmsf = get_rmsf(traj)
        out_fn = os.path.join(outdir, "m%d.pdb" % i)
        traj.save_pdb(out_fn, bfactors=rmsf)


def _save_extreme(traj, outdir, prefix, long_name, ca_inds, n_pdb=10):
    """Save an ensemble and return its CA-only slice and scaled RMSF.

    Writes ``<long_name>.xtc``, up to ``n_pdb`` single-frame
    ``<prefix><i>.pdb`` files and ``<prefix>_rmsf.npy``.
    """
    traj.save(os.path.join(outdir, "%s.xtc" % long_name))
    for i in range(min(n_pdb, len(traj))):
        traj[i].save(os.path.join(outdir, "%s%d.pdb" % (prefix, i)))
    ca_traj = traj.atom_slice(ca_inds)
    rmsf = 10 * get_rmsf(ca_traj)
    np.save(os.path.join(outdir, "%s_rmsf.npy" % prefix), rmsf)
    return ca_traj, rmsf


def get_act_inact(nn_dir, data_dir, enc, labels):
    """Save most active/inactive structures with RMSDs from target less
    than 2 Angstroms.

    Frames are filtered by ``<nn_dir>/rmsd.npy`` < 0.2, sorted by label and
    the 1000 highest ("active") and lowest ("inactive") are reconstructed
    and written to ``<nn_dir>/act_and_inact`` together with CA RMSF
    profiles, their difference and a summary plot.
    """
    outdir = os.path.join(nn_dir, "act_and_inact")
    utils.mkdir(outdir)
    n_extreme = 1000
    # NOTE: pickle.load executes arbitrary code; only load trusted files.
    with open("%s/nn_best_polish.pkl" % nn_dir, 'rb') as f:
        net = pickle.load(f)
    net.cpu()
    ref_s = md.load(os.path.join(nn_dir, "master.pdb"))
    ca_inds = ref_s.top.select('name CA')
    cm = np.load(os.path.join(data_dir, "cm.npy"))

    rmsd_cutoff = 0.2
    rmsd = np.load(os.path.join(nn_dir, "rmsd.npy"))
    # NOTE(review): get_rmsd_dists in this module compares only every 10th
    # frame by default, so rmsd.npy may not be index-aligned with enc and
    # labels -- confirm.
    good_inds = np.where(rmsd < rmsd_cutoff)
    enc = enc[good_inds]
    labels = labels[good_inds]
    inds = np.argsort(labels.flatten())

    # NOTE(review): utils.recon_traj is not visible from this module.
    act_traj = utils.recon_traj(enc[inds[-n_extreme:]], net, ref_s.top, cm)
    act_traj, act_rmsf = _save_extreme(act_traj, outdir, "act", "active",
                                       ca_inds)
    inact_traj = utils.recon_traj(enc[inds[:n_extreme]], net, ref_s.top, cm)
    inact_traj, inact_rmsf = _save_extreme(inact_traj, outdir, "inact",
                                           "inactive", ca_inds)

    fig = plt.figure(figsize=(4, 8))
    res_nums = [r.resSeq for r in act_traj.top.residues]
    ax = fig.add_subplot(211)
    ax.plot(res_nums, act_rmsf, label='act')
    ax.plot(res_nums, inact_rmsf, label='inact')
    ax.legend()
    ax = fig.add_subplot(212)
    d = act_rmsf - inact_rmsf
    ax.plot(res_nums, d, 'k')
    np.save(os.path.join(outdir, "act_minus_inact.npy"), d)
    plt.show()

    out_fn = os.path.join(outdir, "act_minus_inact.pdb")
    ref_s = ref_s.atom_slice(ca_inds)
    ref_s.save_pdb(out_fn, bfactors=d)
    print("rmsf delta extrema", d.min(), d.mean(), d.max())


def enc_corr(enc):
    """Pearson correlation of every pair of columns ``i < j`` of ``enc``.

    Returns
    -------
    np.ndarray
        Correlations ordered (0,1), (0,2), ..., (1,2), ...
    """
    n_latent = enc.shape[1]
    return np.array([pearsonr(enc[:, i], enc[:, j])[0]
                     for i, j in itertools.combinations(range(n_latent), 2)])


def project_act(lab_v, vars, my_title):
    """Plot a 50-bin label histogram over [0, 1] for each variant.

    ``lab_v[v]`` must be an array; its mean is printed per variant.
    """
    print(my_title)
    fig = plt.figure(figsize=(4, 4))
    fig.suptitle(my_title)
    for v in vars:
        n, x = np.histogram(lab_v[v], range=(0, 1), bins=50)
        plt.plot(x[:-1], n, label=v)
        print(v, lab_v[v].mean())
    plt.legend()
    plt.show()


def check_loss(nn_dir):
    """Plot every ``test_loss_<i>.npy`` (i = 2, 3, ...) and the polish loss.

    Numbered files are read until the first missing index;
    ``test_loss_polish.npy`` must exist.
    """
    i = 2
    fn = os.path.join(nn_dir, "test_loss_%d.npy" % i)
    while os.path.exists(fn):
        plt.plot(np.load(fn), label=str(i))
        i += 1
        fn = os.path.join(nn_dir, "test_loss_%d.npy" % i)
    fn = os.path.join(nn_dir, "test_loss_polish.npy")
    # np.load spelled out rather than relying on the pylab star import.
    plt.plot(np.load(fn), label='p')
    plt.legend()
    plt.show()


def clust_encod(nn_dir, n_clusters, vars, lag_times, n_traj_per_var):
    """Cluster the encodings and build per-variant MSM quantities.

    All encodings are clustered together; assignments are reshaped into
    ``len(vars) * n_traj_per_var`` equally long trajectories, with
    consecutive groups of ``n_traj_per_var`` belonging to one variant.
    Implied timescales, raw counts and normalised C/T/p are saved per
    variant under ``<nn_dir>/msm_<n_clusters>``; the implied-timescale
    plot is saved as ``imp_times.png``.
    """
    msm_dir = os.path.join(nn_dir, "msm_%d" % n_clusters)
    utils.mkdir(msm_dir)
    enc = utils.load_npy_dir(os.path.join(nn_dir, "encodings"), "*npy")
    n_vars = len(vars)

    clusters = cluster.hybrid.hybrid(enc, euc_dist, n_clusters=n_clusters,
                                     n_iters=1)
    # clusters.assignments and clusters.centers most relevant vars
    with open(os.path.join(msm_dir, "clusters.pkl"), 'wb') as f:
        pickle.dump(clusters, f)

    # assuming n_traj_per_var traj of equal length per variant
    assigns = clusters.assignments.reshape((n_vars * n_traj_per_var, -1))

    # One MSM builder shared by every variant.
    normalize = functools.partial(msm.builders.normalize,
                                  prior_counts=1.0 / n_clusters,
                                  calculate_eq_probs=True)

    height = 4
    width = height * n_vars
    fig = plt.figure(figsize=(width, height))
    fig.suptitle(nn_dir)
    markov_lag = 10
    for i, v in enumerate(vars):
        print("Getting impolied timescales for", v)
        v_assigns = assigns[i * n_traj_per_var:(i + 1) * n_traj_per_var]
        imp_times = msm.implied_timescales(v_assigns, lag_times, normalize)
        np.save(os.path.join(msm_dir, "%s_imp_norm.npy" % v), imp_times)

        ax = fig.add_subplot(1, n_vars, i + 1, aspect='auto')
        # Separate index name so the variant index ``i`` is not shadowed.
        for t_ind, t in enumerate(lag_times):
            ax.scatter(t * np.ones(imp_times.shape[1]), imp_times[t_ind])
        ax.set_title(v)
        ax.set_yscale('log')

        # NOTE(review): counts are built at lag 1 although the file names
        # advertise markov_lag -- confirm which lag is intended.
        c = msm.assigns_to_counts(v_assigns, 1, max_n_states=n_clusters)
        c_fn = os.path.join(msm_dir, "%s_c_raw_lag%s.npz" % (v, markov_lag))
        scipy.sparse.save_npz(c_fn, c)
        C, T, p = normalize(c)
        np.save(os.path.join(msm_dir, "%s_p_norm_lag%d.npy" % (v, markov_lag)),
                p)
        np.save(os.path.join(msm_dir, "%s_T_norm_lag%d.npy" % (v, markov_lag)),
                T)
        np.save(os.path.join(msm_dir, "%s_C_norm_lag%d.npy" % (v, markov_lag)),
                C)

    plt.savefig(os.path.join(msm_dir, "imp_times.png"))
    plt.show()