From 88c1c474705e2bd168f1110ff2561b3de1a38ecc Mon Sep 17 00:00:00 2001 From: Vivek Miglani Date: Mon, 30 Dec 2024 16:00:41 -0800 Subject: [PATCH] Fix progress pyre fixme issues Summary: Fixing unresolved pyre fixme issues in corresponding file Differential Revision: D67725994 --- captum/_utils/progress.py | 126 +++++++++++------- captum/influence/_core/tracincp.py | 29 ++-- .../_core/tracincp_fast_rand_proj.py | 22 ++- captum/influence/_utils/common.py | 13 +- 4 files changed, 126 insertions(+), 64 deletions(-) diff --git a/captum/_utils/progress.py b/captum/_utils/progress.py index 2e025006c3..e00c2fcc7c 100644 --- a/captum/_utils/progress.py +++ b/captum/_utils/progress.py @@ -3,15 +3,34 @@ # pyre-strict import sys +import typing import warnings from time import time -from typing import Any, cast, Iterable, Literal, Optional, Sized, TextIO +from types import TracebackType +from typing import ( + Any, + Callable, + cast, + Generic, + Iterable, + Iterator, + Literal, + Optional, + Sized, + TextIO, + Type, + TypeVar, + Union, +) try: from tqdm.auto import tqdm except ImportError: tqdm = None +T = TypeVar("T") +IterableType = TypeVar("IterableType") + class DisableErrorIOWrapper(object): def __init__(self, wrapped: TextIO) -> None: @@ -21,15 +40,13 @@ def __init__(self, wrapped: TextIO) -> None: """ self._wrapped = wrapped - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def __getattr__(self, name): + def __getattr__(self, name: str) -> object: return getattr(self._wrapped, name) @staticmethod - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def _wrapped_run(func, *args, **kwargs): + def _wrapped_run( + func: Callable[..., T], *args: object, **kwargs: object + ) -> Union[T, None]: try: return func(*args, **kwargs) except OSError as e: @@ -38,19 +55,16 @@ def _wrapped_run(func, *args, **kwargs): except ValueError as e: if "closed" not in str(e): raise + return None - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def write(self, *args, **kwargs): + def write(self, *args: object, **kwargs: object) -> Optional[int]: return self._wrapped_run(self._wrapped.write, *args, **kwargs) - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def flush(self, *args, **kwargs): + def flush(self, *args: object, **kwargs: object) -> None: return self._wrapped_run(self._wrapped.flush, *args, **kwargs) -class NullProgress: +class NullProgress(Iterable[IterableType]): """Passthrough class that implements the progress API. This class implements the tqdm and SimpleProgressBar api but @@ -61,27 +75,28 @@ class NullProgress: def __init__( self, - # pyre-fixme[24]: Generic type `Iterable` expects 1 type parameter. - iterable: Optional[Iterable] = None, + iterable: Optional[Iterable[IterableType]] = None, *args: Any, **kwargs: Any, ) -> None: del args, kwargs self.iterable = iterable - def __enter__(self) -> "NullProgress": + def __enter__(self) -> "NullProgress[IterableType]": return self - # pyre-fixme[2]: Parameter must be annotated. - def __exit__(self, exc_type, exc_value, exc_traceback) -> Literal[False]: + def __exit__( + self, + exc_type: Union[Type[BaseException], None], + exc_value: Union[BaseException, None], + exc_traceback: Union[TracebackType, None], + ) -> Literal[False]: return False - # pyre-fixme[3]: Return type must be annotated. - def __iter__(self): + def __iter__(self) -> Iterator[IterableType]: if not self.iterable: return - # pyre-fixme[16]: `Optional` has no attribute `__iter__`. - for it in self.iterable: + for it in cast(Iterable[IterableType], self.iterable): yield it def update(self, amount: int = 1) -> None: @@ -91,11 +106,10 @@ def close(self) -> None: pass -class SimpleProgress: +class SimpleProgress(Iterable[IterableType]): def __init__( self, - # pyre-fixme[24]: Generic type `Iterable` expects 1 type parameter. - iterable: Optional[Iterable] = None, + iterable: Optional[Iterable[IterableType]] = None, desc: Optional[str] = None, total: Optional[int] = None, file: Optional[TextIO] = None, @@ -117,34 +131,33 @@ def __init__( self.desc = desc - # pyre-fixme[9]: file has type `Optional[TextIO]`; used as - # `DisableErrorIOWrapper`. - file = DisableErrorIOWrapper(file if file else sys.stderr) - cast(TextIO, file) - self.file = file + file_wrapper = DisableErrorIOWrapper(file if file else sys.stderr) + self.file: DisableErrorIOWrapper = file_wrapper self.mininterval = mininterval self.last_print_t = 0.0 self.closed = False self._is_parent = False - def __enter__(self) -> "SimpleProgress": + def __enter__(self) -> "SimpleProgress[IterableType]": self._is_parent = True self._refresh() return self - # pyre-fixme[2]: Parameter must be annotated. - def __exit__(self, exc_type, exc_value, exc_traceback) -> Literal[False]: + def __exit__( + self, + exc_type: Union[Type[BaseException], None], + exc_value: Union[BaseException, None], + exc_traceback: Union[TracebackType, None], + ) -> Literal[False]: self.close() return False - # pyre-fixme[3]: Return type must be annotated. - def __iter__(self): + def __iter__(self) -> Iterator[IterableType]: if self.closed or not self.iterable: return self._refresh() - # pyre-fixme[16]: `Optional` has no attribute `__iter__`. - for it in self.iterable: + for it in cast(Iterable[IterableType], self.iterable): yield it self.update() self.close() @@ -153,9 +166,7 @@ def _refresh(self) -> None: progress_str = self.desc + ": " if self.desc else "" if self.total: # e.g., progress: 60% 3/5 - # pyre-fixme[58]: `//` is not supported for operand types `int` and - # `Optional[int]`. - progress_str += f"{100 * self.cur // self.total}% {self.cur}/{self.total}" + progress_str += f"{100 * self.cur // cast(int, self.total)}% {self.cur}/{cast(int, self.total)}" else: # e.g., progress: ..... progress_str += "." * self.cur @@ -179,18 +190,39 @@ def close(self) -> None: self.closed = True -# pyre-fixme[3]: Return type must be annotated. +@typing.overload +def progress( + iterable: None = None, + desc: Optional[str] = None, + total: Optional[int] = None, + use_tqdm: bool = True, + file: Optional[TextIO] = None, + mininterval: float = 0.5, + **kwargs: object, +) -> Union[SimpleProgress[None], tqdm]: ... + + +@typing.overload +def progress( + iterable: Iterable[IterableType], + desc: Optional[str] = None, + total: Optional[int] = None, + use_tqdm: bool = True, + file: Optional[TextIO] = None, + mininterval: float = 0.5, + **kwargs: object, +) -> Union[SimpleProgress[IterableType], tqdm]: ... + + def progress( - # pyre-fixme[24]: Generic type `Iterable` expects 1 type parameter. - iterable: Optional[Iterable] = None, + iterable: Optional[Iterable[IterableType]] = None, desc: Optional[str] = None, total: Optional[int] = None, use_tqdm: bool = True, file: Optional[TextIO] = None, mininterval: float = 0.5, - # pyre-fixme[2]: Parameter must be annotated. - **kwargs, -): + **kwargs: object, +) -> Union[SimpleProgress[IterableType], tqdm]: # Try to use tqdm is possible. Fall back to simple progress print if tqdm and use_tqdm: return tqdm( diff --git a/captum/influence/_core/tracincp.py b/captum/influence/_core/tracincp.py index f36d383237..ef8104cb97 100644 --- a/captum/influence/_core/tracincp.py +++ b/captum/influence/_core/tracincp.py @@ -6,7 +6,18 @@ import warnings from abc import abstractmethod from os.path import join -from typing import Any, Callable, Iterator, List, Optional, Tuple, Type, Union +from typing import ( + Any, + Callable, + cast, + Iterable, + Iterator, + List, + Optional, + Tuple, + Type, + Union, +) import torch from captum._utils.av import AV @@ -1033,10 +1044,12 @@ def _influence( inputs = _format_inputs_dataset(inputs) train_dataloader = self.train_dataloader - + data_iterable: Union[Iterable[Tuple[object, ...]], DataLoader] = ( + train_dataloader + ) if show_progress: - train_dataloader = progress( - train_dataloader, + data_iterable = progress( + cast(Iterable[Tuple[object, ...]], train_dataloader), desc=( f"Using {self.get_name()} to compute " "influence for training batches" @@ -1053,7 +1066,7 @@ def _influence( return torch.cat( [ self._influence_batch_tracincp(inputs_checkpoint_jacobians, batch) - for batch in train_dataloader + for batch in data_iterable ], dim=1, ) @@ -1250,7 +1263,7 @@ def get_checkpoint_contribution(checkpoint: str) -> Tensor: # the same) checkpoint_contribution = [] - _inputs = inputs + _inputs: Union[DataLoader, Iterable[Tuple[Tensor, ...]]] = inputs # If `show_progress` is true, create an inner progress bar that keeps track # of how many batches have been processed for the current checkpoint if show_progress: @@ -1266,8 +1279,8 @@ def get_checkpoint_contribution(checkpoint: str) -> Tensor: for batch in _inputs: layer_jacobians = self._basic_computation_tracincp( - batch[0:-1], - batch[-1], + cast(Tuple[Tensor, ...], batch)[0:-1], + cast(Tuple[Tensor, ...], batch)[-1], self.loss_fn, self.reduction_type, ) diff --git a/captum/influence/_core/tracincp_fast_rand_proj.py b/captum/influence/_core/tracincp_fast_rand_proj.py index 8c679266e4..6d430fa8f6 100644 --- a/captum/influence/_core/tracincp_fast_rand_proj.py +++ b/captum/influence/_core/tracincp_fast_rand_proj.py @@ -5,7 +5,18 @@ import threading import warnings from collections import defaultdict -from typing import Any, Callable, cast, Dict, Iterator, List, Optional, Tuple, Union +from typing import ( + Any, + Callable, + cast, + Dict, + Iterable, + Iterator, + List, + Optional, + Tuple, + Union, +) import torch from captum._utils.common import _get_module_from_name, _sort_key_list @@ -418,10 +429,13 @@ def _influence( # type: ignore[override] """ train_dataloader = self.train_dataloader + train_dataloader_iterable: Union[DataLoader, Iterable[Tuple[object, ...]]] = ( + train_dataloader + ) if show_progress: - train_dataloader = progress( - train_dataloader, + train_dataloader_iterable = progress( + cast(Iterable[Tuple[object, ...]], train_dataloader), desc=( f"Using {self.get_name()} to compute " "influence for training batches" @@ -432,7 +446,7 @@ def _influence( # type: ignore[override] return torch.cat( [ self._influence_batch_tracincp_fast(inputs, batch) - for batch in train_dataloader + for batch in train_dataloader_iterable ], dim=1, ) diff --git a/captum/influence/_utils/common.py b/captum/influence/_utils/common.py index ba3ba0f85e..8a966f9e21 100644 --- a/captum/influence/_utils/common.py +++ b/captum/influence/_utils/common.py @@ -6,6 +6,7 @@ from typing import ( Any, Callable, + cast, Dict, Iterable, List, @@ -273,18 +274,19 @@ def _get_k_most_influential_helper( # if show_progress, create progress bar total: Optional[int] = None + data_iterator: Union[Iterable[object], DataLoader] = influence_src_dataloader if show_progress: try: total = len(influence_src_dataloader) except AttributeError: pass - influence_src_dataloader = progress( - influence_src_dataloader, + data_iterator = progress( + cast(Iterable[object], influence_src_dataloader), desc=desc, total=total, ) - for batch in influence_src_dataloader: + for batch in data_iterator: # calculate tracin_scores for the batch batch_tracin_scores = influence_batch_fn(inputs, batch) @@ -406,6 +408,7 @@ def _self_influence_by_batches_helper( """ # If `inputs_dataset` is not a `DataLoader`, turn it into one. inputs_dataset = _format_inputs_dataset(inputs_dataset) + inputs_dataset_iterator: Union[Iterable[object], DataLoader] = inputs_dataset # If `show_progress` is true, create a progress bar that keeps track of how # many batches have been processed @@ -425,7 +428,7 @@ def _self_influence_by_batches_helper( stacklevel=1, ) # then create the progress bar - inputs_dataset = progress( + inputs_dataset_iterator = progress( inputs_dataset, desc=f"Using {instance_name} to compute self influence. Processing batch", total=inputs_dataset_len, @@ -440,7 +443,7 @@ def _self_influence_by_batches_helper( return torch.cat( [ self_influence_batch_fn(batch, show_progress=False) - for batch in inputs_dataset + for batch in inputs_dataset_iterator ] )