diff --git a/baselines/common/distributions.py b/baselines/common/distributions.py index 8366eb5edc..491b9ff9be 100644 --- a/baselines/common/distributions.py +++ b/baselines/common/distributions.py @@ -53,6 +53,9 @@ def param_placeholder(self, prepend_shape, name=None): def sample_placeholder(self, prepend_shape, name=None): return tf.placeholder(dtype=self.sample_dtype(), shape=prepend_shape+self.sample_shape(), name=name) + def __eq__(self, other): + return (type(self) == type(other)) and (self.__dict__ == other.__dict__) + class CategoricalPdType(PdType): def __init__(self, ncat): self.ncat = ncat