Skip to content

Commit

Permalink
Remove unnecessary use of xla_client.OpMetadata class.
Browse files Browse the repository at this point in the history
We create this object and immediately turn it into a different object. We can cut out a step here!

PiperOrigin-RevId: 718023353
  • Loading branch information
hawkinsp authored and Google-ML-Automation committed Jan 21, 2025
1 parent 79bd72e commit afb750c
Showing 1 changed file with 4 additions and 10 deletions.
14 changes: 4 additions & 10 deletions jax/experimental/jax2tf/jax2tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1293,11 +1293,11 @@ def full_lower(self):
def _make_op_metadata(primitive: core.Primitive,
params: dict, *,
source_info: source_info_util.SourceInfo,
) -> xla_client.OpMetadata:
) -> xla_data_pb2.OpMetadata:
eqn_str = (str(source_info.name_stack) + '/'
+ core.str_eqn_compact(primitive, params))
frame = source_info_util.user_frame(source_info)
return xla_client.OpMetadata(
return xla_data_pb2.OpMetadata(
op_type=primitive.name,
op_name=eqn_str,
source_file=mlir.get_canonical_source_file(
Expand Down Expand Up @@ -1380,14 +1380,8 @@ def invoke_impl() -> TfVal:

with tf.name_scope(_sanitize_scope_name(scope)):
if _thread_local_state.include_xla_op_metadata:
op_metadata = _make_op_metadata(primitive, params,
source_info=source_info_util.current())
op_metadata_proto = xla_data_pb2.OpMetadata(
op_type=op_metadata.op_type,
op_name=op_metadata.op_name,
source_file=op_metadata.source_file,
source_line=op_metadata.source_line
)
op_metadata_proto = _make_op_metadata(
primitive, params, source_info=source_info_util.current())
with tf_ops.get_default_graph()._attr_scope(
{"_XlaOpMetadata": attr_value_pb2.AttrValue(
s=op_metadata_proto.SerializeToString())}):
Expand Down

0 comments on commit afb750c

Please sign in to comment.