Skip to content

Commit

Permalink
Excludes kernel when inferring the common input dtype to multitask GP
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 714988845
  • Loading branch information
Googler authored and tensorflower-gardener committed Jan 13, 2025
1 parent fe99149 commit f1dd1c7
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -205,12 +205,12 @@ multi_substrate_py_test(
srcs = ["multitask_gaussian_process_test.py"],
shard_count = 3,
deps = [
":multitask_gaussian_process",
":multitask_gaussian_process_regression_model",
# absl/testing:parameterized dep,
# numpy dep,
# tensorflow dep,
"//tensorflow_probability/python/distributions:gaussian_process",
"//tensorflow_probability/python/experimental/distributions:multitask_gaussian_process",
"//tensorflow_probability/python/experimental/distributions:multitask_gaussian_process_regression_model",
"//tensorflow_probability/python/experimental/psd_kernels:multitask_kernel",
"//tensorflow_probability/python/internal:test_util",
"//tensorflow_probability/python/math/psd_kernels:exponentiated_quadratic",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,11 +249,11 @@ def __init__(self,
parameters = dict(locals())
with tf.name_scope(name) as name:
input_dtype = dtype_util.common_dtype(
dict(
kernel=kernel,
index_points=index_points),
dict(index_points=index_points),
dtype_hint=nest_util.broadcast_structure(
kernel.feature_ndims, tf.float32))
kernel.feature_ndims, tf.float32
),
)

# If the input dtype is non-nested float, we infer a single dtype for the
# input and the float parameters, which is also the dtype of the MTGP's
Expand Down

0 comments on commit f1dd1c7

Please sign in to comment.