diff --git a/edward/inferences/metropolis_hastings.py b/edward/inferences/metropolis_hastings.py
index 258400f9b..88abcee1a 100644
--- a/edward/inferences/metropolis_hastings.py
+++ b/edward/inferences/metropolis_hastings.py
@@ -12,6 +12,7 @@
 
 try:
   from edward.models import Uniform
+  from tensorflow.contrib.bayesflow.metropolis_hastings import evolve
 except Exception as e:
   raise ImportError("{0}. Your TensorFlow version is not supported.".format(e))
 
@@ -64,6 +65,19 @@ def __init__(self, latent_vars, proposal_vars, data=None):
 
   def initialize(self, *args, **kwargs):
     kwargs['auto_transform'] = False
+
+    # TODO In general, each latent variable has arbitrary shape and
+    # dtype. We cannot simply batch them into a single tf.Tensor with
+    # an extra dimension.
+    initial_sample = tf.stack([tf.gather(qz.params, 0)
+                               for qz in six.itervalues(self.latent_vars)])
+    self._state = tf.Variable(initial_sample, trainable=False, name="state")
+    self._state_log_density = tf.Variable(
+        self._log_joint(initial_sample),
+        trainable=False, name="state_log_density")
+    self._log_accept_ratio = tf.Variable(
+        tf.zeros_like(self._state_log_density.initialized_value()),
+        trainable=False, name="log_accept_ratio")
     return super(MetropolisHastings, self).initialize(*args, **kwargs)
 
   def build_update(self):
@@ -80,9 +94,77 @@ def build_update(self):
     The updates assume each Empirical random variable is directly
     parameterized by `tf.Variable`s.
     """
-    old_sample = {z: tf.gather(qz.params, tf.maximum(self.t - 1, 0))
-                  for z, qz in six.iteritems(self.latent_vars)}
-    old_sample = OrderedDict(old_sample)
+    # Snapshot the current state so the acceptance check below compares
+    # against the pre-step value. TODO It would be great if we could
+    # more naturally get the acceptance rate from ``evolve``.
+    old_state = tf.identity(self._state)
+    with tf.control_dependencies([old_state]):
+      forward_step = evolve(self._state,
+                            self._state_log_density,
+                            self._log_accept_ratio,
+                            self._log_joint,
+                            self._proposal_fn,
+                            n_steps=1)
+    assign_ops = [forward_step]
+
+    with tf.control_dependencies([forward_step]):
+      # Update Empirical random variables.
+      for state, qz in zip(tf.unstack(self._state),
+                           six.itervalues(self.latent_vars)):
+        variable = qz.get_variables()[0]
+        assign_ops.append(tf.scatter_update(variable, self.t, state))
+
+      # Increment n_accept (if accepted, i.e. if the state changed).
+      is_proposal_accepted = tf.where(
+          tf.reduce_any(tf.not_equal(old_state, self._state)), 1, 0)
+      assign_ops.append(self.n_accept.assign_add(is_proposal_accepted))
+
+    return tf.group(*assign_ops)
+
+  def _log_joint(self, state):
+    """Utility function to calculate the model's log joint density,
+    log p(x, z), for inputs z (and fixed data x).
+
+    Args:
+      state: tf.Tensor.
+    """
+    scope = self._scope + tf.get_default_graph().unique_name("sample")
+    # Form dictionary in order to replace conditioning on prior or
+    # observed variable with conditioning on a specific value.
+    # TODO verify ordering is preserved
+    dict_swap = {z: sample for z, sample in
+                 zip(six.iterkeys(self.latent_vars), tf.unstack(state))}
+    for x, qx in six.iteritems(self.data):
+      if isinstance(x, RandomVariable):
+        if isinstance(qx, RandomVariable):
+          qx_copy = copy(qx, scope=scope)
+          dict_swap[x] = qx_copy.value()
+        else:
+          dict_swap[x] = qx
+
+    log_joint = 0.0
+    for z in six.iterkeys(self.latent_vars):
+      z_copy = copy(z, dict_swap, scope=scope)
+      log_joint += tf.reduce_sum(z_copy.log_prob(dict_swap[z]))
+
+    for x in six.iterkeys(self.data):
+      if isinstance(x, RandomVariable):
+        x_copy = copy(x, dict_swap, scope=scope)
+        log_joint += tf.reduce_sum(x_copy.log_prob(dict_swap[x]))
+
+    return log_joint
+
+  def _proposal_fn(self, state):
+    """Utility function to propose a new state,
+    znew ~ g(znew | zold), for inputs zold, and return the log density
+    ratio log g(znew | zold) - log g(zold | znew).
+
+    Args:
+      state: tf.Tensor.
+    """
+    # TODO verify ordering is preserved
+    old_sample = {z: sample for z, sample in
+                  zip(six.iterkeys(self.latent_vars), tf.unstack(state))}
 
     # Form dictionary in order to replace conditioning on prior or
     # observed variable with conditioning on a specific value.
@@ -99,7 +181,6 @@ def build_update(self):
     dict_swap_old.update(old_sample)
     base_scope = tf.get_default_graph().unique_name("inference") + '/'
     scope_old = base_scope + 'old'
-    scope_new = base_scope + 'new'
 
     # Draw proposed sample and calculate acceptance ratio.
     new_sample = old_sample.copy()  # copy to ensure same order
@@ -114,6 +195,7 @@ def build_update(self):
 
     dict_swap_new = dict_swap.copy()
     dict_swap_new.update(new_sample)
+    scope_new = base_scope + 'new'
 
     for z, proposal_z in six.iteritems(self.proposal_vars):
       # Build proposal g(zold | znew).
@@ -121,42 +203,6 @@ def build_update(self):
       # Increment ratio.
       ratio -= tf.reduce_sum(proposal_zold.log_prob(dict_swap_old[z]))
 
-    for z in six.iterkeys(self.latent_vars):
-      # Build priors p(znew) and p(zold).
-      znew = copy(z, dict_swap_new, scope=scope_new)
-      zold = copy(z, dict_swap_old, scope=scope_old)
-      # Increment ratio.
-      ratio += tf.reduce_sum(znew.log_prob(dict_swap_new[z]))
-      ratio -= tf.reduce_sum(zold.log_prob(dict_swap_old[z]))
-
-    for x in six.iterkeys(self.data):
-      if isinstance(x, RandomVariable):
-        # Build likelihoods p(x | znew) and p(x | zold).
-        x_znew = copy(x, dict_swap_new, scope=scope_new)
-        x_zold = copy(x, dict_swap_old, scope=scope_old)
-        # Increment ratio.
-        ratio += tf.reduce_sum(x_znew.log_prob(dict_swap[x]))
-        ratio -= tf.reduce_sum(x_zold.log_prob(dict_swap[x]))
-
-    # Accept or reject sample.
-    u = Uniform(low=tf.constant(0.0, dtype=ratio.dtype),
-                high=tf.constant(1.0, dtype=ratio.dtype)).sample()
-    accept = tf.log(u) < ratio
-    sample_values = tf.cond(accept, lambda: list(six.itervalues(new_sample)),
-                            lambda: list(six.itervalues(old_sample)))
-    if not isinstance(sample_values, list):
-      # `tf.cond` returns tf.Tensor if output is a list of size 1.
-      sample_values = [sample_values]
-
-    sample = {z: sample_value for z, sample_value in
-              zip(six.iterkeys(new_sample), sample_values)}
-
-    # Update Empirical random variables.
-    assign_ops = []
-    for z, qz in six.iteritems(self.latent_vars):
-      variable = qz.get_variables()[0]
-      assign_ops.append(tf.scatter_update(variable, self.t, sample[z]))
-
-    # Increment n_accept (if accepted).
-    assign_ops.append(self.n_accept.assign_add(tf.where(accept, 1, 0)))
-    return tf.group(*assign_ops)
+    # TODO verify ordering is preserved
+    new_sample = tf.stack(list(six.itervalues(new_sample)))
+    return (new_sample, ratio)
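For reference, `evolve` with `n_steps=1` performs the same accept/reject logic that the removed block implemented by hand with `Uniform`. A minimal NumPy sketch of one textbook Metropolis-Hastings step (the names `mh_step`, `log_joint`, and `propose` are illustrative, not part of Edward or `tf.contrib.bayesflow`):

```python
import numpy as np

def mh_step(z_old, log_joint, propose):
  """One textbook Metropolis-Hastings step.

  propose(z) returns (z_new, log_g_new_given_old, log_g_old_given_new).
  """
  z_new, log_fwd, log_rev = propose(z_old)
  # log acceptance ratio: log p(x, znew) - log p(x, zold)
  #                       + log g(zold | znew) - log g(znew | zold).
  log_ratio = log_joint(z_new) - log_joint(z_old) + log_rev - log_fwd
  if np.log(np.random.uniform()) < log_ratio:
    return z_new, True   # accept: move to the proposed state
  return z_old, False    # reject: keep the current state
```

Note that `_proposal_fn` above returns log g(znew | zold) - log g(zold | znew), the negation of the correction term in this sketch, so the sign convention `evolve` expects for the returned ratio is worth double-checking against the `tf.contrib.bayesflow.metropolis_hastings` documentation.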
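As a usage sketch, the class is still driven through Edward's existing `MetropolisHastings(latent_vars, proposal_vars, data)` API; the normal-normal toy model below is illustrative, not part of this diff:

```python
import edward as ed
import numpy as np
import tensorflow as tf
from edward.models import Empirical, Normal

# Toy model: z ~ N(0, 1); x | z ~ N(z, 1) with 50 observations.
z = Normal(loc=0.0, scale=1.0)
x = Normal(loc=z, scale=1.0, sample_shape=50)

T = 5000  # length of the chain stored by the Empirical approximation
qz = Empirical(params=tf.Variable(tf.zeros(T)))
proposal_z = Normal(loc=z, scale=0.5)  # random-walk proposal g(znew | zold)

x_data = np.random.normal(0.0, 1.0, 50).astype(np.float32)
inference = ed.MetropolisHastings({z: qz}, {z: proposal_z}, data={x: x_data})
inference.run()  # drives initialize() and build_update() defined above
```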