Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TOSA] Update tosa.cast check according to TOSA v1.0 spec #3948

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

justin-ngo-arm
Copy link
Contributor

Change-Id: I41209c698a694bca57ebf49ed3608cf89a0d8ba8

@sjarus sjarus requested a review from sahas3 January 8, 2025 19:27
@justin-ngo-arm
Copy link
Contributor Author

I'm quoting my comment regarding unifying tosa.cast usage within TorchToTOSA from #3949:

After some discussion with @sjarus, I've decided to combine tosaCastTensorToType and promoteType together into one function wrapped around the tosa.cast op creation process. The combined function will still be called tosaCastTensorToType as this name is more descriptive about what this function does. I plan to replace all usages of tosa.cast creation with this "new" function to make everything uniform.

The following are the differences between the current promoteType and tosaCastTensorToType functions:

  • promoteType checks if the input and output types are same before performing the cast, which will prevent unnecessary cast operation if both types are the same.
  • tosaCastTensorToType has the cast validity check.

My plan is to take the same-type check from promoteType and add it to tosaCastTensorToType, then remove the promoteType function altogether. As for the cast validity check, I will tighten it more so that it will strictly follow the TOSA v1.0 spec (i.e. removing all I64 and F64 casting as they are not allowed in TOSA). I will update that in #3948. However, with this tightening, many e2e tests will fail. This is not desirable as these operations, although are not congruent with the spec,are still permissible. Therefore, after tightening up the checkValidityOfCast function, I will leave it out of the tosaCastTensorToType function for now and enable it later with a potential --strict mode that is defaulted to off. TOSA validation should flag illegal constructs based on each profile anyway, so this --strict mode is just another guard before that.

I will begin the new tosaCastTensorToType work in a separate PR from #3948, so that it is easier to keep track of changes and progress.

* Update checkValidityOfCast function for tosa.cast according to the
  latest TOSA v1.0 spec: https://www.mlplatform.org/tosa/tosa_spec.html#_cast
* Clean up some dead code in TorchToTosa

Signed-off-by: Justin Ngo <[email protected]>
Change-Id: I41209c698a694bca57ebf49ed3608cf89a0d8ba8
"FullModuleFalsePinMemory_basic",
"FullModuleInt2D_basic",
"MaskedFillScalarFloatValueModule_basic",
"MaskedFillScalarFloatValueStaticModule_basic",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Clarification: all these tests are now passing because checkValidityOfCast is not enforced?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants