Skip to content

Commit

Permalink
Merge pull request #69 from aimclub/add-tests
Browse files Browse the repository at this point in the history
Add tests
  • Loading branch information
MedAI-Lab-ITMO authored Dec 21, 2024
2 parents 4291f54 + 4438c51 commit a3d6432
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 41 deletions.
73 changes: 45 additions & 28 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,40 +14,57 @@ def compare_values(expected, got, message_header=None):
), f"{_form_message_header(message_header)}: expected {expected}, got {got}"


def create_testing_data():
N = 20
dim = 256
data = torch.randn((N, dim))
return N, dim, data
def create_testing_data(architecture='fcn'):
architecture = architecture.lower()
if architecture == 'fcn':
return torch.randn((20, 256))
elif architecture == 'cnn':
return torch.randn((20, 3, 32, 32))
else:
raise Exception(f'Unsupported architecture type: {architecture}')


def create_testing_model(num_classes=10):
return nn.Sequential(
OrderedDict(
[
("first_layer", nn.Linear(256, 128)),
("second_layer", nn.Linear(128, 64)),
("third_layer", nn.Linear(64, num_classes)),
],
),
)
def create_testing_model(architecture='fcn', num_classes=10):
architecture = architecture.lower()
if architecture == 'fcn':
return nn.Sequential(
OrderedDict(
[
("first_layer", nn.Linear(256, 128)),
("second_layer", nn.Linear(128, 64)),
("third_layer", nn.Linear(64, num_classes)),
],
),
)
elif architecture == 'cnn':
return nn.Sequential(
OrderedDict(
[
("first_layer", nn.Conv2d(in_channels=3, out_channels=10, kernel_size=7)),
("second_layer", nn.Conv2d(in_channels=10, out_channels=20, kernel_size=7)),
("avgpool", nn.AdaptiveAvgPool2d(1)),
("flatten", nn.Flatten()),
("fc", nn.Linear(20, num_classes)),
],
),
)
elif architecture == 'rnn':
return nn.Sequential(
OrderedDict(
[
('first_layer', nn.LSTM(256, 128, 1, batch_first=True)),
('extract', ExtractTensor()),
('second_layer', nn.Linear(128, 64)),
('third_layer', nn.Linear(64, num_classes)),
],
),
)
else:
raise Exception(f'Unsupported architecture type: {architecture}')


class ExtractTensor(nn.Module):
def forward(self, x):
tensor, _ = x
x = x.to(torch.float32)
return tensor[:, :]


def create_testing_model_lstm(num_classes=10):
return nn.Sequential(
OrderedDict(
[
('first_layer', nn.LSTM(256, 128, 1, batch_first=True)),
('extract', ExtractTensor()),
('second_layer', nn.Linear(128, 64)),
('third_layer', nn.Linear(64, num_classes)),
],
),
)
68 changes: 55 additions & 13 deletions tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ def test_check_random_input():


def _check_reduce_dim(mode):
N, dim, data = utils.create_testing_data()
data = utils.create_testing_data()
reduced_data = viz_api.reduce_dim(data, mode)
utils.compare_values(np.ndarray, type(reduced_data), "Wrong result type")
utils.compare_values((N, 2), reduced_data.shape, "Wrong result shape")
utils.compare_values((len(data), 2), reduced_data.shape, "Wrong result shape")


def test_reduce_dim_umap():
Expand All @@ -31,8 +31,8 @@ def test_reduce_dim_pca():
_check_reduce_dim("pca")


def test_visualization():
N, dim, data = utils.create_testing_data()
def test_visualization_fcn():
data = utils.create_testing_data()
model = utils.create_testing_model()
layers = ["second_layer", "third_layer"]
res = viz_api.visualize_layer_manifolds(model, "umap", data, layers=layers)
Expand All @@ -52,10 +52,31 @@ def test_visualization():
)


def test_visualization_cnn():
data = utils.create_testing_data(architecture='cnn')
model = utils.create_testing_model(architecture='cnn')
layers = ["first_layer", "second_layer", "avgpool", "flatten", "fc"]
res = viz_api.visualize_layer_manifolds(model, "umap", data, layers=layers)

utils.compare_values(dict, type(res), "Wrong result type")
utils.compare_values(6, len(res), "Wrong dictionary length")
utils.compare_values(
set(["input"] + layers),
set(res.keys()),
"Wrong dictionary keys",
)
for key, plot in res.items():
utils.compare_values(
matplotlib.figure.Figure,
type(plot),
f"Wrong value type for key {key}",
)


def test_embed_visualization():
data = torch.randn((20, 1, 256))
labels = torch.randn((20))
model = utils.create_testing_model_lstm()
model = utils.create_testing_model('rnn')
layers = ["second_layer", "third_layer"]
res = viz_api.visualize_recurrent_layer_manifolds(model, "umap",
data, layers=layers, labels=labels)
Expand All @@ -74,16 +95,16 @@ def test_embed_visualization():
)


def _test_bayes_prediction(mode: str):
def _test_bayes_prediction(mode: str, architecture='fcn'):
params = {
"basic": dict(mode="basic", p=0.5),
"beta": dict(mode="beta", a=0.9, b=0.2),
"gauss": dict(sigma=1e-2),
}

N, dim, data = utils.create_testing_data()
data = utils.create_testing_data(architecture=architecture)
num_classes = 17
model = utils.create_testing_model(num_classes=num_classes)
model = utils.create_testing_model(architecture=architecture, num_classes=num_classes)
n_iter = 7
if mode != 'gauss':
res = bayes_api.DropoutBayesianWrapper(model, **(params[mode])).predict(data, n_iter=n_iter)
Expand All @@ -94,6 +115,7 @@ def _test_bayes_prediction(mode: str):
utils.compare_values(dict, type(res), "Wrong result type")
utils.compare_values(2, len(res), "Wrong dictionary length")
utils.compare_values(set(["mean", "std"]), set(res.keys()), "Wrong dictionary keys")
N = len(data)
utils.compare_values(torch.Size([N, num_classes]), res["mean"].shape, "Wrong mean shape")
utils.compare_values(torch.Size([N, num_classes]), res["std"].shape, "Wrong mean std")

Expand All @@ -110,14 +132,18 @@ def test_gauss_bayes_wrapper():
_test_bayes_prediction("gauss")


def test_bayes_wrapper_cnn():
_test_bayes_prediction("basic", architecture='cnn')


def test_data_barcode():
N, dim, data = utils.create_testing_data()
data = utils.create_testing_data()
res = topology_api.get_data_barcode(data, "standard", "3")
utils.compare_values(dict, type(res), "Wrong result type")


def test_nn_barcodes():
N, dim, data = utils.create_testing_data()
data = utils.create_testing_data()
model = utils.create_testing_model()
layers = ["second_layer", "third_layer"]
res = topology_api.get_nn_barcodes(model, data, layers, "standard", "3")
Expand All @@ -132,15 +158,31 @@ def test_nn_barcodes():
)


def test_nn_barcodes_cnn():
data = utils.create_testing_data(architecture='cnn')
model = utils.create_testing_model(architecture='cnn')
layers = ["second_layer", "flatten"]
res = topology_api.get_nn_barcodes(model, data, layers, "standard", "3")
utils.compare_values(dict, type(res), "Wrong result type")
utils.compare_values(2, len(res), "Wrong dictionary length")
utils.compare_values(set(layers), set(res.keys()), "Wrong dictionary keys")
for layer, barcode in res.items():
utils.compare_values(
dict,
type(barcode),
f"Wrong result type for key {layer}",
)


def test_barcode_plot():
N, dim, data = utils.create_testing_data()
data = utils.create_testing_data()
barcode = topology_api.get_data_barcode(data, "standard", "3")
plot = topology_api.plot_barcode(barcode)
utils.compare_values(matplotlib.figure.Figure, type(plot), "Wrong result type")


def test_barcode_evaluate_all_metrics():
N, dim, data = utils.create_testing_data()
data = utils.create_testing_data()
barcode = topology_api.get_data_barcode(data, "standard", "3")
result = topology_api.evaluate_barcode(barcode)
utils.compare_values(dict, type(result), "Wrong result type")
Expand All @@ -166,7 +208,7 @@ def test_barcode_evaluate_all_metrics():


def test_barcode_evaluate_one_metric():
N, dim, data = utils.create_testing_data()
data = utils.create_testing_data()
barcode = topology_api.get_data_barcode(data, "standard", "3")
result = topology_api.evaluate_barcode(barcode, metric_name="mean_length")
utils.compare_values(float, type(result), "Wrong result type")

0 comments on commit a3d6432

Please sign in to comment.