Skip to content

Commit

Permalink
update docs
Browse files Browse the repository at this point in the history
  • Loading branch information
fschlatt committed Nov 4, 2024
1 parent 34b7be3 commit bcaeb5f
Show file tree
Hide file tree
Showing 8 changed files with 91 additions and 19 deletions.
7 changes: 7 additions & 0 deletions lightning_ir/base/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
"""
Lightning IR base module.
This module provides the main classes and functions for the Lightning IR library, including
factories, configurations, models, modules, and tokenizers.
"""

from .class_factory import (
LightningIRClassFactory,
LightningIRConfigClassFactory,
Expand Down
21 changes: 14 additions & 7 deletions lightning_ir/base/class_factory.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
"""
Class factory module for Lightning IR.
This module provides factory classes for creating various components of the Lightning IR library
by extending Hugging Face Transformers classes.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
Expand All @@ -22,7 +29,7 @@


class LightningIRClassFactory(ABC):
"""Base class for creating derived LightningIR classes from HuggingFace classes."""
"""Base class for creating derived Lightning IR classes from HuggingFace classes."""

def __init__(self, MixinConfig: Type[LightningIRConfig]) -> None:
"""Creates a new LightningIRClassFactory.
Expand All @@ -48,7 +55,7 @@ def get_backbone_config(model_name_or_path: str | Path) -> Type[PretrainedConfig

@staticmethod
def get_lightning_ir_config(model_name_or_path: str | Path) -> Type[LightningIRConfig] | None:
"""Grabs the LightningIR configuration class from a checkpoint of a pretrained Lightning IR model.
"""Grabs the Lightning IR configuration class from a checkpoint of a pretrained Lightning IR model.
:param model_name_or_path: Path to the model or its name
:type model_name_or_path: str | Path
Expand Down Expand Up @@ -91,27 +98,27 @@ def get_lightning_ir_model_type(model_name_or_path: str | Path) -> str | None:

@property
def cc_lir_model_type(self) -> str:
"""Camel case model type of the LightningIR model."""
"""Camel case model type of the Lightning IR model."""
return "".join(s.title() for s in self.MixinConfig.model_type.split("-"))

@abstractmethod
def from_pretrained(self, model_name_or_path: str | Path, *args, **kwargs) -> Any:
"""Loads a derived LightningIR class from a pretrained HuggingFace model. Must be implemented by subclasses.
"""Loads a derived Lightning IR class from a pretrained HuggingFace model. Must be implemented by subclasses.
:param model_name_or_path: Path to the model or its name
:type model_name_or_path: str | Path
:return: Derived LightningIR class
:return: Derived Lightning IR class
:rtype: Any
"""
...

@abstractmethod
def from_backbone_class(self, BackboneClass: Type) -> Type:
"""Creates a derived LightningIR class from a backbone HuggingFace class. Must be implemented by subclasses.
"""Creates a derived Lightning IR class from a backbone HuggingFace class. Must be implemented by subclasses.
:param BackboneClass: Backbone class
:type BackboneClass: Type
:return: Derived LightningIR class
:return: Derived Lightning IR class
:rtype: Type
"""
...
Expand Down
24 changes: 22 additions & 2 deletions lightning_ir/base/config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
"""
Base configuration class for Lightning IR models.
This module defines the configuration class `LightningIRConfig` which is used to instantiate
a Lightning IR model. The configuration class acts as a mixin for the `transformers.PretrainedConfig`
class from the Hugging Face Transformers library.
"""

from pathlib import Path
from typing import Any, Dict, Set

Expand All @@ -8,7 +16,7 @@


class LightningIRConfig:
"""The configuration class to instantiate a LightningIR model. Acts as a mixin for the
"""The configuration class to instantiate a Lightning IR model. Acts as a mixin for the
transformers.PretrainedConfig_ class.
.. _transformers.PretrainedConfig: \
Expand Down Expand Up @@ -58,7 +66,7 @@ def to_dict(self) -> Dict[str, Any]:
model type.
.. _transformers.PretrainedConfig.to_dict: \
https://huggingface.co/docs/transformers/main_classes/configuration.html#transformers.PretrainedConfig.to_dict
https://huggingface.co/docs/transformers/en/main_classes/configuration#transformers.PretrainedConfig.to_dict
:return: Configuration dictionary
:rtype: Dict[str, Any]
Expand All @@ -73,6 +81,18 @@ def to_dict(self) -> Dict[str, Any]:

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str | Path, *args, **kwargs) -> "LightningIRConfig":
"""Loads the configuration from a pretrained model. Wraps the transformers.PretrainedConfig.from_pretrained_
.. _transformers.PretrainedConfig.from_pretrained: \
https://huggingface.co/docs/transformers/en/main_classes/configuration#transformers.PretrainedConfig.from_pretrained
:param pretrained_model_name_or_path: Pretrained model name or path
:type pretrained_model_name_or_path: str | Path
:raises ValueError: If `pre_trained_model_name_or_path` is not a Lightning IR model and no
:py:class:`LightningIRConfig` is passed
:return: Derived LightningIRConfig class
:rtype: LightningIRConfig
"""
if cls is LightningIRConfig or all(issubclass(base, LightningIRConfig) for base in cls.__bases__):
config = None
if pretrained_model_name_or_path in CHECKPOINT_MAPPING:
Expand Down
13 changes: 13 additions & 0 deletions lightning_ir/base/external_model_hub.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
"""
Model hub for loading external models.
This module contains mappings and callbacks for external model checkpoints used in the Lightning IR library.
Attributes:
CHECKPOINT_MAPPING (Dict[str, LightningIRConfig]): Mapping of model checkpoint identifiers to their configurations.
STATE_DICT_KEY_MAPPING (Dict[str, List[Tuple[str | None, str]]]): Mapping of state dictionary keys for model
checkpoints.
POST_LOAD_CALLBACKS (Dict[str, Callable[[LightningIRModel], LightningIRModel]]): Callbacks to be executed after
loading a model.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Callable, Dict, List, Tuple
Expand Down
13 changes: 10 additions & 3 deletions lightning_ir/base/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
import warnings
"""
Model module for Lightning IR.
This module contains the main model class and output class for the Lightning IR library.
"""

from collections import defaultdict
from dataclasses import dataclass
from functools import partial, wraps
Expand All @@ -17,7 +22,7 @@

@dataclass
class LightningIROutput(ModelOutput):
"""Base class for the output of the LightningIR model. It is a subclass of transformers.ModelOutput_.
"""Base class for the output of the Lightning IR model. It is a subclass of transformers.ModelOutput_.
.. _transformers.ModelOutput: https://huggingface.co/transformers/main_classes/output.html#transformers.ModelOutput
Expand All @@ -29,7 +34,7 @@ class LightningIROutput(ModelOutput):


class LightningIRModel:
"""Base class for LightningIR models. Derived classes implement the forward method for handling query
"""Base class for Lightning IR models. Derived classes implement the forward method for handling query
and document embeddings. It acts as mixin for a transformers.PreTrainedModel_ backbone model.
.. _transformers.PreTrainedModel: \
Expand Down Expand Up @@ -211,6 +216,7 @@ def from_pretrained(cls, model_name_or_path: str | Path, *args, **kwargs) -> "Li
def _cat_outputs(
outputs: Sequence[Mapping] | Sequence[torch.Tensor] | Sequence[None], OutputClass: Type[T] | None
) -> torch.Tensor | T | None:
"""Helper method to concatenate outputs of the model."""
if len(outputs) == 1:
return outputs[0]
if len(outputs) == 0 or outputs[0] is None or OutputClass is None:
Expand All @@ -229,6 +235,7 @@ def _cat_outputs(
def _batch_encoding(
func: Callable[[LightningIRModel, BatchEncoding, ...], Any]
) -> Callable[[LightningIRModel, BatchEncoding, ...], Any]:
"""Decorator to enable sub-batching for models that support it."""

@wraps(func)
def wrapper(self, encoding: BatchEncoding, *args, **kwargs) -> Any:
Expand Down
19 changes: 14 additions & 5 deletions lightning_ir/base/module.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
"""LightningModule for Lightning IR.
This module contains the main module class deriving from a LightningModule_.
.. _LightningModule: https://lightning.ai/docs/pytorch/stable/common/lightning_module.html
"""

from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Sequence, Tuple, Type
Expand All @@ -15,9 +22,11 @@


class LightningIRModule(LightningModule):
"""LightningIRModule base class. LightningIRModules contain a LightningIRModel and a LightningIRTokenizer and
implements the training, validation, and testing steps for the model. Derived classes must implement the forward
method for the model.
"""LightningIRModule base class. It dervies from a LightningModule_. LightningIRModules contain a
LightningIRModel and a LightningIRTokenizer and implements the training, validation, and testing steps for the
model. Derived classes must implement the forward method for the model.
.. _LightningModule: https://lightning.ai/docs/pytorch/stable/common/lightning_module.html
"""

def __init__(
Expand All @@ -32,11 +41,11 @@ def __init__(
.. _ir-measures: https://ir-measur.es/en/latest/index.html
:param model_name_or_path: Name or path of backbone model or fine-tuned LightningIR model, defaults to None
:param model_name_or_path: Name or path of backbone model or fine-tuned Lightning IR model, defaults to None
:type model_name_or_path: str | None, optional
:param config: LightningIRConfig to apply when loading from backbone model, defaults to None
:type config: LightningIRConfig | None, optional
:param model: Already instantiated LightningIR model, defaults to None
:param model: Already instantiated Lightning IR model, defaults to None
:type model: LightningIRModel | None, optional
:param loss_functions: Loss functions to apply during fine-tuning, optional loss weights can be provided per
loss function, defaults to None
Expand Down
9 changes: 7 additions & 2 deletions lightning_ir/base/tokenizer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
import warnings
"""
Tokenizer module for Lightning IR.
This module contains the main tokenizer class for the Lightning IR library.
"""

from typing import Dict, Sequence, Type

from transformers import TOKENIZER_MAPPING, BatchEncoding
Expand All @@ -9,7 +14,7 @@


class LightningIRTokenizer:
"""Base class for LightningIR tokenizers. Derived classes implement the tokenize method for handling query
"""Base class for Lightning IR tokenizers. Derived classes implement the tokenize method for handling query
and document tokenization. It acts as mixin for a transformers.PreTrainedTokenizer_ backbone tokenizer.
.. _transformers.PreTrainedTokenizer: \
Expand Down
4 changes: 4 additions & 0 deletions lightning_ir/base/validation_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
"""Validation utilities module for Lightning IR.
This module contains utility functions for validation and evaluation of Lightning IR models."""

from typing import Dict, Sequence

import ir_measures
Expand Down

0 comments on commit bcaeb5f

Please sign in to comment.