Skip to content

Commit

Permalink
style: format code
Browse files Browse the repository at this point in the history
  • Loading branch information
34j committed Mar 16, 2024
1 parent ca95b8f commit 0b9426f
Showing 1 changed file with 18 additions and 19 deletions.
37 changes: 18 additions & 19 deletions torch_frame/utils/skorch.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,11 @@
import skorch.utils

# TODO: make it more safe
old_to_tensor = skorch.utils.to_tensor

def to_tensor(X, device, accept_sparse=False):
if isinstance(X, TensorFrame):
return X
return old_to_tensor(X, device, accept_sparse)

skorch.utils.to_tensor = to_tensor
import importlib
importlib.reload(skorch.net)

from typing import Any

import pandas as pd
import skorch.utils
import torch
import torch.nn as nn
from numpy.typing import ArrayLike
from pandas import DataFrame
from skorch import NeuralNet, NeuralNetClassifier
from skorch.dataset import Dataset as SkorchDataset
from skorch import NeuralNet
from torch import Tensor

import torch_frame
Expand All @@ -29,20 +14,34 @@ def to_tensor(X, device, accept_sparse=False):
TextEmbedderConfig,
TextTokenizerConfig,
)
from torch_frame.data.dataset import DataFrameToTensorFrameConverter, Dataset
from torch_frame.data.dataset import Dataset
from torch_frame.data.loader import DataLoader
from torch_frame.data.tensor_frame import TensorFrame
from torch_frame.typing import IndexSelectType
from torch_frame.utils import infer_df_stype

# TODO: make it more safe
old_to_tensor = skorch.utils.to_tensor


def to_tensor(X, device, accept_sparse=False):
if isinstance(X, TensorFrame):
return X
return old_to_tensor(X, device, accept_sparse)


skorch.utils.to_tensor = to_tensor

importlib.reload(skorch.net)


class NeuralNetPytorchFrameDataLoader(DataLoader):
def __init__(self, dataset: Dataset | TensorFrame, *args,
device: torch.device, **kwargs):
super().__init__(dataset, *args, **kwargs)
self.device = device

def collate_fn(
def collate_fn( # type: ignore
self, index: IndexSelectType) -> tuple[TensorFrame, Tensor | None]:
index = torch.tensor(index)
res = super().collate_fn(index).to(self.device)
Expand Down

0 comments on commit 0b9426f

Please sign in to comment.