Custom models
Diart is also compatible with models that were not trained with pyannote.audio. However, the user must wrap such a third-party model so that it satisfies the interface diart expects. Because diart only interacts with models through this interface, it can run them without knowing how they work internally.
A segmentation model must ingest a batch of audio chunks and return the corresponding per-speaker activity probabilities across time. It must also define the expected sample rate and duration of its inputs so that the pipeline knows how to format the audio stream.
import torch
from diart.models import SegmentationModel

class MySegmentationModel(SegmentationModel):
    def __init__(self):
        super().__init__()
        # `load` is a placeholder for your own framework's loading function
        self.my_pretrained_model = load("my_segmentation.ckpt")

    def get_sample_rate(self) -> int:
        return 16000

    def get_duration(self) -> float:
        return 2.0  # seconds

    def __call__(self, waveform: torch.Tensor) -> torch.Tensor:
        # waveform has shape (batch, channels, samples)
        # ... operations to adapt the input to this specific model (e.g. converting to TensorFlow)
        output = self.my_pretrained_model(waveform)
        # ... operations to adapt the output to a torch.Tensor of shape (batch, frames, speakers)
        return output
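The sample rate and duration returned by the model are enough to turn a continuous stream into model-ready batches. The following is a minimal NumPy sketch of that idea, not diart's actual streaming code: `chunk_stream` and its `step` parameter (the hop between consecutive chunks, in seconds) are hypothetical, and NumPy arrays stand in for torch tensors so that it runs on its own.

```python
import numpy as np


def chunk_stream(
    samples: np.ndarray, sample_rate: int, duration: float, step: float
) -> np.ndarray:
    """Cut a mono stream into overlapping chunks of shape (batch, channels, samples).

    Hypothetical helper (not part of diart): the chunk length comes from the
    model's sample rate and duration, while `step` is an assumed hop size.
    """
    chunk_size = int(round(duration * sample_rate))
    hop = int(round(step * sample_rate))
    if len(samples) < chunk_size:
        # Not enough audio yet for a single chunk
        return np.empty((0, 1, chunk_size), dtype=samples.dtype)
    starts = range(0, len(samples) - chunk_size + 1, hop)
    chunks = [samples[s:s + chunk_size] for s in starts]
    # Add the channel axis expected by the model: (batch, 1, samples)
    return np.stack(chunks)[:, np.newaxis, :]
```

With the values of `MySegmentationModel` (16 kHz, 2 s) and an assumed 0.5 s step, 5 s of audio yields 7 overlapping chunks, each of shape (1, 32000).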
A speaker embedding model must ingest a batch of audio chunks and output one speaker embedding per chunk. It may also receive optional per-frame weights that tell it which frames of each chunk it should focus on.
from typing import Optional

import torch
from diart.models import EmbeddingModel

class MyEmbeddingModel(EmbeddingModel):
    def __init__(self):
        super().__init__()
        # `load` is a placeholder for your own framework's loading function
        self.my_pretrained_model = load("my_embedding_model.ckpt")

    def __call__(
        self,
        waveform: torch.Tensor,
        weights: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        # waveform has shape (batch, channels, samples)
        # weights have shape (batch, frames)
        # The output should have shape (batch, embedding_dim)
        return self.my_pretrained_model(waveform, weights)
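One common way for an embedding model to use such weights is weighted pooling of frame-level features, so that highly weighted frames dominate the embedding. The sketch below illustrates that technique with NumPy; `weighted_mean_pooling` is a hypothetical helper, the frame-level features are assumed to come from your own model, and it makes no claim about how any particular pretrained model uses its weights.

```python
from typing import Optional

import numpy as np


def weighted_mean_pooling(
    frame_features: np.ndarray, weights: Optional[np.ndarray] = None
) -> np.ndarray:
    """Pool (batch, frames, dim) features into (batch, dim) embeddings.

    Hypothetical helper (not part of diart). `weights` has shape (batch, frames)
    and is assumed to share the frame resolution of `frame_features`; a real
    model may need to resample it first. Without weights, frames count equally.
    """
    batch, frames, _ = frame_features.shape
    if weights is None:
        weights = np.ones((batch, frames), dtype=frame_features.dtype)
    # Normalize weights per chunk, guarding against all-zero rows
    totals = weights.sum(axis=1, keepdims=True)
    norm = weights / np.maximum(totals, 1e-8)
    # Weighted sum over the frame axis: (batch, frames, 1) * (batch, frames, dim)
    return (norm[:, :, np.newaxis] * frame_features).sum(axis=1)
```

Normalizing the weights keeps the embedding scale independent of how long the speaker is active in the chunk.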
Models can be easily replaced through the PipelineConfig configuration object:
from diart.pipelines import PipelineConfig, OnlineSpeakerDiarization
config = PipelineConfig(segmentation=MySegmentationModel(), embedding=MyEmbeddingModel())
diarization = OnlineSpeakerDiarization(config)
...
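Since both models run in the same pipeline, their shapes must agree: the (batch, frames, speakers) output of the segmentation model can supply the (batch, frames) weights of the embedding model, one speaker at a time. The sketch below shows one plausible way to wire them together; it is an assumption for illustration rather than diart's internal code, `embed_speakers` and the type aliases are hypothetical, and NumPy arrays stand in for torch tensors.

```python
from typing import Callable, Optional

import numpy as np

# Hypothetical aliases: any callables following the shapes described above
SegmentationFn = Callable[[np.ndarray], np.ndarray]
EmbeddingFn = Callable[[np.ndarray, Optional[np.ndarray]], np.ndarray]


def embed_speakers(
    waveform: np.ndarray, segmentation: SegmentationFn, embedding: EmbeddingFn
) -> np.ndarray:
    """Return one embedding per local speaker, shape (batch, speakers, dim)."""
    # Per-speaker activity: (batch, frames, speakers)
    activity = segmentation(waveform)
    batch, frames, speakers = activity.shape
    # One weight vector per (chunk, speaker) pair: (batch * speakers, frames)
    weights = activity.transpose(0, 2, 1).reshape(batch * speakers, frames)
    # Repeat each chunk once per speaker, in the same order as the weights
    repeated = np.repeat(waveform, speakers, axis=0)
    emb = embedding(repeated, weights)  # (batch * speakers, dim)
    return emb.reshape(batch, speakers, -1)
```

Repeating each chunk per speaker trades memory for a single batched call to the embedding model.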