-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathModel_for_native
75 lines (61 loc) · 2.88 KB
/
Model_for_native
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
def rmsd_dist_burn(distances, burn, sample_nr=10, target_accept_prob=0.4,
                   num_samples=1000):
    """
    Run the NUTS sampler on the protein structure model restrained by pairwise
    distances and return the last sampled structure.

    distances: number of distinct random pairwise distances to be additionally
        restrained (drawn with sample_distance from native_coords);
    burn: warm up size (number of NUTS warm up steps);
    sample_nr: number of times the whole procedure (draw restraints, run the
        sampler) is repeated; must be at least 1;
    target_accept_prob: target acceptance probability, NUTS Sampler parameter;
    num_samples: number of posterior samples drawn per run after warm up.
    Returns: the final posterior sample of the last run, concatenated after
        the fixed leading coordinates M_first. No RMSD statistics or timings
        are computed by this function.
    Raises: ValueError if sample_nr is smaller than 1.
    """
    if sample_nr < 1:
        raise ValueError("sample_nr must be at least 1, got %r" % (sample_nr,))

    # Repeat the procedure sample_nr times.
    # NOTE(review): every run overwrites M, so only the structure of the last
    # run is returned -- confirm that discarding earlier runs is intended.
    for _ in range(sample_nr):
        # Randomly draw `distances` distinct restraints from native_coords.
        # Each draw is used below as (observed distance, pair of point indices).
        restraints = []
        used_pairs = []
        # NOTE(review): this loop does not terminate when `distances` exceeds
        # the number of distinct pairs sample_distance can produce -- verify
        # the value passed by callers.
        while len(restraints) < distances:
            candidate = sample_distance(native_coords)
            pair = candidate[1]
            # Reject a pair that was already drawn, in either orientation
            if pair in used_pairs or pair[::-1] in used_pairs:
                continue
            restraints.append(candidate)
            used_pairs.append(pair)

        def model(N=20):
            # Sample the N-3 free points according to a Normal distribution;
            # the leading points are fixed and come from M_first
            M_last = pyro.sample('M', dist.Normal(0, 20).expand_by([N-3, 3]).to_event(1))
            M = torch.cat((M_first, M_last))
            # Make sure bond distances are around 3.8 Å
            for i in pyro.plate('bonds', N-1):
                bond = torch.dist(M[i], M[i+1])
                pyro.sample('bond_%i' % i, dist.Normal(bond, 0.001), obs=torch.tensor(3.8))
            # Add distance restraints:
            ## a distance restraint between first and last point
            d = torch.dist(M[0], M[N-1])
            pyro.sample('d_obs', dist.Normal(d, 0.001),
                        obs=torch.dist(native_coords_t[0], native_coords_t[N-1]))
            ## the randomly drawn restraints of this run
            for i in pyro.plate('dist_nr', distances):
                d = torch.dist(M[restraints[i][1][0]], M[restraints[i][1][1]])
                pyro.sample('d%s_obs' % i, dist.Normal(d, 0.001), obs=restraints[i][0])

        # Do NUTS sampling
        nuts_kernel = NUTS(model, adapt_step_size=True, target_accept_prob=target_accept_prob)
        mcmc_sampler = MCMC(nuts_kernel, num_samples=num_samples, warmup_steps=burn)
        # NOTE(review): relies on run() returning the posterior object that
        # get_samples consumes; newer Pyro releases appear to expose samples
        # via MCMC.get_samples() instead -- confirm the installed version.
        posterior = mcmc_sampler.run()
        # Get the last sampled points
        samples = get_samples(posterior, 'M')
        M_last = samples[num_samples - 1]
        M = torch.cat((M_first, M_last))  # Add the fixed leading coordinates
    # Return the structure so it can be written to a pdb file
    return M
# Sample one structure (190 extra distance restraints, 60 warm up steps)
# and save it to a PDB file
_structure = rmsd_dist_burn(distances=190, sample_nr=1, burn=60)
save_M(_structure, 'test.pdb')