Skip to content

Commit

Permalink
Allow black to clean up (and uglify) my code.
Browse files Browse the repository at this point in the history
Put a long format string in parentheses to make a line shorter and
satisfy flake8.
  • Loading branch information
kdolum committed Nov 16, 2023
1 parent 138689b commit 92bbdf6
Showing 1 changed file with 52 additions and 41 deletions.
93 changes: 52 additions & 41 deletions PTMCMCSampler/PTMCMCSampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ def __init__(
resume=False,
seed=None,
):

# MPI initialization
self.comm = comm
self.MPIrank = self.comm.Get_rank()
Expand Down Expand Up @@ -204,12 +203,12 @@ def initialize(
self.neff = neff
self.tstart = 0

N = int(maxIter / thin) + 1 # first sample + those we generate
N = int(maxIter / thin) + 1 # first sample + those we generate

self._lnprob = np.zeros(N)
self._lnlike = np.zeros(N)
self._chain = np.zeros((N, self.ndim))
self.ind_next_write = 0 # Next index in these arrays to write out
self.ind_next_write = 0 # Next index in these arrays to write out
self.naccepted = 0
self.swapProposed = 0
self.nswap_accepted = 0
Expand Down Expand Up @@ -292,15 +291,27 @@ def initialize(
print("Resuming run from chain file {0}".format(self.fname))
try:
self.resumechain = np.loadtxt(self.fname)
self.resumeLength = self.resumechain.shape[0] # Number of samples read from old chain
self.resumeLength = self.resumechain.shape[0] # Number of samples read from old chain
except ValueError as error:
print("Reading old chain files failed with error", error)
raise Exception("Couldn't read old chain to resume")
self._chainfile = open(self.fname, "a")
if (self.isave != self.thin and # This special case is always OK
self.resumeLength % (self.isave/self.thin) != 1): # Initial sample plus blocks of isave/thin
raise Exception("Old chain has {0} rows, which is not the initial sample plus a multiple of isave/thin = {1}".format(self.resumeLength, self.isave//self.thin))
print("Resuming with", self.resumeLength, "samples from file representing", (self.resumeLength-1)*self.thin+1, "original samples")
if (
self.isave != self.thin
and self.resumeLength % (self.isave / self.thin) != 1 # This special case is always OK
): # Initial sample plus blocks of isave/thin
raise Exception(
(
"Old chain has {0} rows, which is not the initial sample plus a multiple of isave/thin = {1}"
).format(self.resumeLength, self.isave // self.thin)
)
print(
"Resuming with",
self.resumeLength,
"samples from file representing",
(self.resumeLength - 1) * self.thin + 1,
"original samples",
)
else:
self._chainfile = open(self.fname, "w")
self._chainfile.close()
Expand Down Expand Up @@ -330,7 +341,6 @@ def writeOutput(self, iter):
Write chains and covariance matrix. Called every isave on samples or at end.
"""
if iter // self.thin >= self.ind_next_write:

if self.writeHotChains or self.MPIrank == 0:
self._writeToFile(iter)

Expand All @@ -339,21 +349,24 @@ def writeOutput(self, iter):
np.save(self.outDir + "/cov.npy", self.cov)

if self.MPIrank == 0 and self.verbose:
if iter > 0: sys.stdout.write("\r")
percent = iter / self.Niter * 100 # Percent of total work finished
if iter > 0:
sys.stdout.write("\r")
percent = iter / self.Niter * 100 # Percent of total work finished
acceptance = self.naccepted / iter if iter > 0 else 0
elapsed = time.time() - self.tstart
if self.resume:
# Percentage of new work done
percentnew = ((iter - self.resumeLength*self.thin)
/ (self.Niter - self.resumeLength*self.thin) * 100)
percentnew = (
(iter - self.resumeLength * self.thin) / (self.Niter - self.resumeLength * self.thin) * 100
)
sys.stdout.write(
"Finished %2.2f percent (%2.2f percent of new work) in %f s Acceptance rate = %g"
% (percent, percentnew, elapsed, acceptance))
% (percent, percentnew, elapsed, acceptance)
)
else:
sys.stdout.write("Finished %2.2f percent in %f s Acceptance rate = %g"
% (percent, elapsed, acceptance)
)
sys.stdout.write(
"Finished %2.2f percent in %f s Acceptance rate = %g" % (percent, elapsed, acceptance)
)
sys.stdout.flush()

def sample(
Expand Down Expand Up @@ -416,12 +429,14 @@ def sample(
elif maxIter is None and self.MPIrank == 0:
maxIter = Niter

if (isave % thin != 0):
if isave % thin != 0:
raise ValueError("isave = %d is not a multiple of thin = %d" % (isave, thin))

if (Niter % thin != 0):
print("Niter = %d is not a multiple of thin = %d. The last %d samples will be lost"
% (Niter, thin, Niter % thin))
if Niter % thin != 0:
print(
"Niter = %d is not a multiple of thin = %d. The last %d samples will be lost"
% (Niter, thin, Niter % thin)
)

# set up arrays to store lnprob, lnlike and chain
# if picking up from previous run, don't re-initialize
Expand Down Expand Up @@ -462,12 +477,10 @@ def sample(
lp = self.logp(p0)

if lp == float(-np.inf):

lnprob0 = -np.inf
lnlike0 = -np.inf

else:

lnlike0 = self.logl(p0)
lnprob0 = 1 / self.temp * lnlike0 + lp

Expand All @@ -479,7 +492,7 @@ def sample(

# start iterations
iter = i0

runComplete = False
Neff = 0
while runComplete is False:
Expand All @@ -502,17 +515,17 @@ def sample(

# rank 0 decides whether to stop
if self.MPIrank == 0:
if iter >= self.Niter: # stop if reached maximum number of iterations
if iter >= self.Niter: # stop if reached maximum number of iterations
message = "\nRun Complete"
runComplete = True
elif int(Neff) > self.neff: # stop if reached maximum number of iterations
elif int(Neff) > self.neff: # stop if reached maximum number of iterations
message = "\nRun Complete with {0} effective samples".format(int(Neff))
runComplete = True

runComplete = self.comm.bcast(runComplete, root=0) # rank 0 tells others whether to stop
runComplete = self.comm.bcast(runComplete, root=0) # rank 0 tells others whether to stop

if runComplete:
self.writeOutput(iter) # Possibly write partial block
self.writeOutput(iter) # Possibly write partial block
if self.MPIrank == 0 and self.verbose:
print(message)

Expand Down Expand Up @@ -577,11 +590,15 @@ def PTMCMCOneStep(self, p0, lnlike0, lnprob0, iter):

# if resuming, just use previous chain points. Use each one thin times to compensate for
# thinning when they were written out
if self.resume and self.resumeLength > 0 and iter < self.resumeLength*self.thin:
p0, lnlike0, lnprob0 = self.resumechain[iter//self.thin, :-4], self.resumechain[iter//self.thin, -3], self.resumechain[iter//self.thin, -4]
if self.resume and self.resumeLength > 0 and iter < self.resumeLength * self.thin:
p0, lnlike0, lnprob0 = (
self.resumechain[iter // self.thin, :-4],
self.resumechain[iter // self.thin, -3],
self.resumechain[iter // self.thin, -4],
)

# update acceptance counter
self.naccepted = iter * self.resumechain[iter//self.thin, -2]
self.naccepted = iter * self.resumechain[iter // self.thin, -2]
else:
y, qxy, jump_name = self._jump(p0, iter)
self.jumpDict[jump_name][0] += 1
Expand All @@ -590,18 +607,15 @@ def PTMCMCOneStep(self, p0, lnlike0, lnprob0, iter):
lp = self.logp(y)

if lp == -np.inf:

newlnprob = -np.inf

else:

newlnlike = self.logl(y)
newlnprob = 1 / self.temp * newlnlike + lp

# hastings step
diff = newlnprob - lnprob0 + qxy
if diff > np.log(self.stream.random()):

# accept jump
p0, lnlike0, lnprob0 = y, newlnlike, newlnprob

Expand Down Expand Up @@ -710,25 +724,24 @@ def _writeToFile(self, iter):

self._chainfile = open(self.fname, "a+")
# index 0 is the initial element. So after 10*thin iterations we need to write elements 1..10
write_end = iter // self.thin + 1 # First element not to write.
write_end = iter // self.thin + 1 # First element not to write.
for ind in range(self.ind_next_write, write_end):
pt_acc = 1
if self.MPIrank < self.nchain - 1 and self.swapProposed != 0:
pt_acc = self.nswap_accepted / self.swapProposed

self._chainfile.write("\t".join(["%22.22f" % (self._chain[ind, kk]) for kk in range(self.ndim)]))
self._chainfile.write(
"\t%f\t%f\t%f\t%f\n" % (self._lnprob[ind], self._lnlike[ind],
self.naccepted / iter if iter > 0 else 0, pt_acc)
"\t%f\t%f\t%f\t%f\n"
% (self._lnprob[ind], self._lnlike[ind], self.naccepted / iter if iter > 0 else 0, pt_acc)
)
self._chainfile.close()
self.ind_next_write = write_end # Ready for next write
self.ind_next_write = write_end # Ready for next write

# write jump statistics files ####

# only for T=1 chain
if self.MPIrank == 0:

# first write file contaning jump names and jump rates
fout = open(self.outDir + "/jumps.txt", "w")
njumps = len(self.propCycle)
Expand Down Expand Up @@ -765,7 +778,6 @@ def _updateRecursive(self, iter, mem):
diff = np.zeros(ndim)
it += 1
for jj in range(ndim):

diff[jj] = self._AMbuffer[ii, jj] - self.mu[jj]
self.mu[jj] += diff[jj] / it

Expand Down Expand Up @@ -956,7 +968,6 @@ def DEJump(self, x, iter, beta):
scale = self.stream.random() * 2.4 / np.sqrt(2 * ndim) * np.sqrt(1 / beta)

for ii in range(ndim):

# jump size
sigma = self._DEbuffer[mm, self.groups[jumpind][ii]] - self._DEbuffer[nn, self.groups[jumpind][ii]]

Expand Down

0 comments on commit 92bbdf6

Please sign in to comment.