Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Img-Diff ops. #550

Merged
merged 5 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions configs/config_all.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -297,9 +297,21 @@ process:
cv_classifier: '' # OpenCV classifier path for face detection. By default, we will use 'haarcascade_frontalface_alt.xml'.
blur_type: 'gaussian' # type of blur kernel, including ['mean', 'box', 'gaussian']
radius: 2 # radius of blur kernel
- image_segment_mapper: # perform segment-anything on images and return the bounding boxes.
imgsz: 1024 # image resolution after image resizing
conf: 0.05 # confidence score threshold
iou: 0.5 # IoU (Intersection over Union) score threshold
mem_required: '800MB' # This operation (Op) utilizes deep neural network models that consume a significant amount of memory for computation, hence the system's available memory might constrain the maximum number of processes that can be launched
- image_tagging_mapper: # Mapper to generate image tags.
tag_field_name: 'image_tags' # the field name to store the tags. It's "image_tags" in default.
mem_required: '9GB'
- mllm_mapper: # use MLLMs for visual question answering tasks
hf_model: 'liuhaotian/llava-v1.6-vicuna-7b' # model name of the MLLM on huggingface
max_new_tokens: 256 # the maximum number of new tokens generated by the model
temperature: 0.2 # used to control the randomness of the generated text
top_p: None # randomly select the next word from the group of words whose cumulative probability reaches p
num_beams: 1 # the larger the beam search size, the higher the quality of the generated text
mem_required: '32GB' # This operation (Op) utilizes deep neural network models that consume a significant amount of memory for computation, hence the system's available memory might constrain the maximum number of processes that can be launched
- nlpaug_en_mapper: # simply augment texts in English based on the nlpaug library
sequential: false # whether combine all augmentation methods to a sequence. If it's True, a sample will be augmented by all opened augmentation methods sequentially. If it's False, each opened augmentation method would generate its augmented samples independently.
aug_num: 1 # number of augmented samples to be generated. If `sequential` is True, there will be total aug_num augmented samples generated. If it's False, there will be (aug_num * #opened_aug_method) augmented samples generated.
Expand Down Expand Up @@ -414,6 +426,22 @@ process:
lang: en # sample in which language
tokenization: false # whether to use model to tokenize documents
substrings: ['http', 'www', '.com', 'href', '//'] # incorrect substrings to remove
- sdxl_prompt2prompt_mapper: # use the generative model SDXL and image editing technique Prompt-to-Prompt to generate pairs of similar images.
hf_diffusion: 'stabilityai/stable-diffusion-xl-base-1.0' # model name of the SDXL model on huggingface
num_inference_steps: 50 # the larger the value, the better the image generation quality
guidance_scale: 7.5 # a higher guidance_scale value encourages the model to generate images closely linked to the text prompt at the expense of lower image quality
text_key_second: None # used to store the first caption in the caption pair
text_key_third: None # used to store the second caption in the caption pair
mem_required: '38GB' # This operation (Op) utilizes deep neural network models that consume a significant amount of memory for computation, hence the system's available memory might constrain the maximum number of processes that can be launched
- sentence_augmentation_mapper: # augment sentences using LLMs.
hf_model: 'Qwen/Qwen2-7B-Instruct' # model name of the LLM on huggingface
system_prompt: None # system prompt
task_sentence: None # the instruction for the current task
max_new_tokens: 256 # the maximum number of new tokens generated by the model
temperature: 0.2 # used to control the randomness of generated text
top_p: None # randomly select the next word from the group of words whose cumulative probability reaches p
num_beams: 1 # the larger the beam search size, the higher the quality of the generated text
mem_required: '31GB' # This operation (Op) utilizes deep neural network models that consume a significant amount of memory for computation, hence the system's available memory might constrain the maximum number of processes that can be launched
- sentence_split_mapper: # split text to multiple sentences and join them with '\n'
lang: 'en' # split text in what language
- text_chunk_mapper: # Split input text to chunks.
Expand Down Expand Up @@ -658,6 +686,13 @@ process:
- text_length_filter: # filter text with length out of specific range
min_len: 10 # the min length of filter range
max_len: 10000 # the max length of filter range
- text_pair_similarity_filter: # filter samples according to the similarity score between the text pair.
hf_clip: 'openai/clip-vit-base-patch32' # model name of the CLIP model on huggingface
min_score: 0.1 # the min similarity score of filter range
max_score: 1.0 # the max similarity score of filter range
text_key_second: None # used to store the other sentence in the text pair
any_or_all: "any" # keep this sample when any/all text pairs meet the filter condition
mem_required: '1500MB' # This operation (Op) utilizes deep neural network models that consume a significant amount of memory for computation, hence the system's available memory might constrain the maximum number of processes that can be launched
- token_num_filter: # filter text with total token number out of specific range
hf_tokenizer: EleutherAI/pythia-6.9b-deduped # name of used Hugging Face tokenizer
min_num: 10 # the min number of filter range
Expand Down
13 changes: 7 additions & 6 deletions data_juicer/ops/filter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from .text_action_filter import TextActionFilter
from .text_entity_dependency_filter import TextEntityDependencyFilter
from .text_length_filter import TextLengthFilter
from .text_pair_similarity_filter import TextPairSimilarityFilter
from .token_num_filter import TokenNumFilter
from .video_aesthetics_filter import VideoAestheticsFilter
from .video_aspect_ratio_filter import VideoAspectRatioFilter
Expand Down Expand Up @@ -56,12 +57,12 @@
'SpecialCharactersFilter', 'SpecifiedFieldFilter',
'SpecifiedNumericFieldFilter', 'StopWordsFilter', 'SuffixFilter',
'TextActionFilter', 'TextEntityDependencyFilter', 'TextLengthFilter',
'TokenNumFilter', 'VideoAestheticsFilter', 'VideoAspectRatioFilter',
'VideoDurationFilter', 'VideoFramesTextSimilarityFilter',
'VideoMotionScoreFilter', 'VideoMotionScoreRaftFilter', 'VideoNSFWFilter',
'VideoOcrAreaRatioFilter', 'VideoResolutionFilter',
'VideoTaggingFromFramesFilter', 'VideoWatermarkFilter',
'WordRepetitionFilter', 'WordsNumFilter'
'TextPairSimilarityFilter', 'TokenNumFilter', 'VideoAestheticsFilter',
'VideoAspectRatioFilter', 'VideoDurationFilter',
'VideoFramesTextSimilarityFilter', 'VideoMotionScoreFilter',
'VideoMotionScoreRaftFilter', 'VideoNSFWFilter', 'VideoOcrAreaRatioFilter',
'VideoResolutionFilter', 'VideoTaggingFromFramesFilter',
'VideoWatermarkFilter', 'WordRepetitionFilter', 'WordsNumFilter'
]

NON_STATS_FILTERS = [
Expand Down
117 changes: 117 additions & 0 deletions data_juicer/ops/filter/text_pair_similarity_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import logging

import numpy as np
from jsonargparse.typing import ClosedUnitInterval

from data_juicer.ops.base_op import OPERATORS, Filter
from data_juicer.utils.constant import Fields, StatsKeys
from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.model_utils import get_model, prepare_model

torch = LazyLoader('torch', 'torch')
transformers = LazyLoader('transformers', 'transformers')
torch.set_num_threads(1)

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

OP_NAME = 'text_pair_similarity_filter'


@OPERATORS.register_module(OP_NAME)
class TextPairSimilarityFilter(Filter):
"""Filter to keep text pairs with similarities between texts
within a specific range."""

_accelerator = 'cuda'

def __init__(self,
hf_clip='openai/clip-vit-base-patch32',
trust_remote_code=False,
min_score: ClosedUnitInterval = 0.1,
max_score: ClosedUnitInterval = 1.0,
text_key_second=None,
any_or_all: str = 'any',
*args,
**kwargs):
"""
Initialization method.

:param hf_clip: clip model name on huggingface to compute
the similarity between image and text.
:param min_score: The min similarity to keep samples.
:param max_score: The max similarity to keep samples.
:param text_key_second: used to store the other sentence
in the text pair.
:param any_or_all: keep this sample with 'any' or 'all' strategy of
all images. 'any': keep this sample if any images meet the
condition. 'all': keep this sample only if all images meet the
condition.
:param args: extra args
:param kwargs: extra args
"""
kwargs.setdefault('mem_required', '1500MB')
super().__init__(*args, **kwargs)
self.min_score = min_score
self.max_score = max_score
if any_or_all not in ['any', 'all']:
raise ValueError(f'Keep strategy [{any_or_all}] is not supported. '
f'Can only be one of ["any", "all"].')
self.any = (any_or_all == 'any')
self.model_key = prepare_model(model_type='huggingface',
pretrained_model_name_or_path=hf_clip,
trust_remote_code=trust_remote_code)
self.text_key_second = text_key_second

def compute_stats_single(self, sample, rank=None, context=False):

# check if it's computed already
if StatsKeys.text_pair_similarity in sample[Fields.stats]:
return sample

# there is no target text
if self.text_key_second is None:
logger.error('This OP (text_pair_similarity_filter) requires \
processing multiple fields, and you need to specify \
valid `text_key_second`')

# there is no text in this sample
if (self.text_key not in sample or len(sample[self.text_key]) == 0
or self.text_key_second not in sample
or len(sample[self.text_key_second]) == 0):
sample[Fields.stats][StatsKeys.text_pair_similarity] = np.array(
[], dtype=np.float64)
return sample

model, processor = get_model(self.model_key, rank, self.use_cuda())

text1 = sample[self.text_key]
text2 = sample[self.text_key_second]

text_tensors = processor([text1, text2],
padding=True,
return_tensors='pt').to(model.device)
text_features = model.get_text_features(**text_tensors)

similarity = torch.cosine_similarity(text_features[0],
text_features[1],
dim=0)
sample[Fields.stats][StatsKeys.text_pair_similarity] = [similarity]

return sample

def process_single(self, sample, rank=None):
similarity = sample[Fields.stats][StatsKeys.text_pair_similarity]
if len(similarity) <= 0:
return True

keep_bools = np.array([
self.min_score <= sim_value <= self.max_score
for sim_value in similarity
])

# different strategies
if self.any:
return keep_bools.any()
else:
return keep_bools.all()
43 changes: 24 additions & 19 deletions data_juicer/ops/mapper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
from .image_captioning_mapper import ImageCaptioningMapper
from .image_diffusion_mapper import ImageDiffusionMapper
from .image_face_blur_mapper import ImageFaceBlurMapper
from .image_segment_mapper import ImageSegmentMapper
from .image_tagging_mapper import ImageTaggingMapper
from .mllm_mapper import MllmMapper
from .nlpaug_en_mapper import NlpaugEnMapper
from .nlpcda_zh_mapper import NlpcdaZhMapper
from .optimize_qa_mapper import OptimizeQAMapper
Expand All @@ -53,6 +55,8 @@
from .remove_words_with_incorrect_substrings_mapper import \
RemoveWordsWithIncorrectSubstringsMapper
from .replace_content_mapper import ReplaceContentMapper
from .sdxl_prompt2prompt_mapper import SDXLPrompt2PromptMapper
from .sentence_augmentation_mapper import SentenceAugmentationMapper
from .sentence_split_mapper import SentenceSplitMapper
from .text_chunk_mapper import TextChunkMapper
from .video_captioning_from_audio_mapper import VideoCaptioningFromAudioMapper
Expand Down Expand Up @@ -87,23 +91,24 @@
'GenerateQAFromExamplesMapper', 'GenerateQAFromTextMapper',
'ImageBlurMapper', 'ImageCaptioningFromGPT4VMapper',
'ImageCaptioningMapper', 'ImageDiffusionMapper', 'ImageFaceBlurMapper',
'ImageTaggingMapper', 'NlpaugEnMapper', 'NlpcdaZhMapper',
'OptimizeQAMapper', 'OptimizeQueryMapper', 'OptimizeResponseMapper',
'PairPreferenceMapper', 'PunctuationNormalizationMapper',
'PythonFileMapper', 'PythonLambdaMapper', 'QuerySentimentDetectionMapper',
'QueryIntentDetectionMapper', 'QueryTopicDetectionMapper',
'RelationIdentityMapper', 'RemoveBibliographyMapper',
'RemoveCommentsMapper', 'RemoveHeaderMapper', 'RemoveLongWordsMapper',
'RemoveNonChineseCharacterlMapper', 'RemoveRepeatSentencesMapper',
'RemoveSpecificCharsMapper', 'RemoveTableTextMapper',
'RemoveWordsWithIncorrectSubstringsMapper', 'ReplaceContentMapper',
'SentenceSplitMapper', 'TextChunkMapper', 'VideoCaptioningFromAudioMapper',
'VideoCaptioningFromFramesMapper', 'VideoCaptioningFromSummarizerMapper',
'VideoCaptioningFromVideoMapper', 'VideoExtractFramesMapper',
'VideoFFmpegWrappedMapper', 'VideoFaceBlurMapper',
'VideoRemoveWatermarkMapper', 'VideoResizeAspectRatioMapper',
'VideoResizeResolutionMapper', 'VideoSplitByDurationMapper',
'VideoSplitByKeyFrameMapper', 'VideoSplitBySceneMapper',
'VideoTaggingFromAudioMapper', 'VideoTaggingFromFramesMapper',
'WhitespaceNormalizationMapper'
'ImageSegmentMapper', 'ImageTaggingMapper', 'MllmMapper', 'NlpaugEnMapper',
'NlpcdaZhMapper', 'OptimizeQAMapper', 'OptimizeQueryMapper',
'OptimizeResponseMapper', 'PairPreferenceMapper',
'PunctuationNormalizationMapper', 'PythonFileMapper', 'PythonLambdaMapper',
'QuerySentimentDetectionMapper', 'QueryIntentDetectionMapper',
'QueryTopicDetectionMapper', 'RelationIdentityMapper',
'RemoveBibliographyMapper', 'RemoveCommentsMapper', 'RemoveHeaderMapper',
'RemoveLongWordsMapper', 'RemoveNonChineseCharacterlMapper',
'RemoveRepeatSentencesMapper', 'RemoveSpecificCharsMapper',
'RemoveTableTextMapper', 'RemoveWordsWithIncorrectSubstringsMapper',
'ReplaceContentMapper', 'SDXLPrompt2PromptMapper',
'SentenceAugmentationMapper', 'SentenceSplitMapper', 'TextChunkMapper',
'VideoCaptioningFromAudioMapper', 'VideoCaptioningFromFramesMapper',
'VideoCaptioningFromSummarizerMapper', 'VideoCaptioningFromVideoMapper',
'VideoExtractFramesMapper', 'VideoFFmpegWrappedMapper',
'VideoFaceBlurMapper', 'VideoRemoveWatermarkMapper',
'VideoResizeAspectRatioMapper', 'VideoResizeResolutionMapper',
'VideoSplitByDurationMapper', 'VideoSplitByKeyFrameMapper',
'VideoSplitBySceneMapper', 'VideoTaggingFromAudioMapper',
'VideoTaggingFromFramesMapper', 'WhitespaceNormalizationMapper'
]
Loading
Loading