Skip to content

Commit

Permalink
Add fix to test functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
Iainmon committed Jan 17, 2025
1 parent 42fb705 commit 8ddecf8
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 3 deletions.
6 changes: 3 additions & 3 deletions scripts/chai.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def get_summary(model,global_name,parent_name=None):
}
return d

def dump_model_parameters(model,path_prefix,model_name,with_json=True,verbose=True):
def dump_model_parameters(model,path_prefix,model_name,with_json=True,verbose=True,dtype=None):
Path(path_prefix).mkdir(exist_ok=True)
for param_tensor in model.state_dict():
if verbose: print("Serializing ", param_tensor)
Expand All @@ -82,8 +82,8 @@ def dump_model_parameters(model,path_prefix,model_name,with_json=True,verbose=Tr
f.write(json.dumps(get_summary(model,model_name),indent=2))


def chai_dump(self,path_prefix,model_name,with_json=True,verbose=True):
return dump_model_parameters(self,path_prefix,model_name,with_json,verbose)
def chai_dump(self,path_prefix,model_name,with_json=True,verbose=True,dtype=None):
return dump_model_parameters(self,path_prefix,model_name,with_json,verbose,dtype)

torch.nn.Module.chai_dump = chai_dump

Expand Down
13 changes: 13 additions & 0 deletions test/moduleTests/layerTest.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,18 @@ def test_dropout():
y.chai_save('./savedTensors', 'dropoutAnswer', with_json=False, verbose=False, dtype=torch.float64)


def test_all():
test_dummy()
test_load()
test_linear()
test_conv2d()
test_maxpool2d()
test_adaptiveavgpool2d()
test_flatten()
test_relu()
test_softmax()
test_dropout()

def main():
torch.manual_seed(5)

Expand All @@ -161,6 +173,7 @@ def main():
parser.add_argument('--relu', action='store_const', const=test_relu, help='Run test_relu')
parser.add_argument('--softmax', action='store_const', const=test_softmax, help='Run test_softmax')
parser.add_argument('--dropout', action='store_const', const=test_dropout, help='Run test_dropout')
parser.add_argument('--all', action='store_const', const=test_all, help='Run all tests')

args = parser.parse_args()

Expand Down
19 changes: 19 additions & 0 deletions test/moduleTests/savedTensors/dummyAnswer.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
{
"rank": 1,
"shape": [
10
],
"dtype": "float32",
"data": [
-0.6179778575897217,
0.29207125306129456,
0.5119034051895142,
-0.9148514866828918,
0.9341246485710144,
-0.13598701357841492,
0.26149165630340576,
0.5164260864257812,
-0.47369763255119324,
0.08458160609006882
]
}
21 changes: 21 additions & 0 deletions test/moduleTests/savedTensors/dummyModel/specification.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
{
"layerType": "Dummy",
"attributes": {
"training": "true"
},
"subModules": {
"model": {
"layerType": "Linear",
"attributes": {
"training": "true",
"in_features": "100",
"out_features": "10"
},
"subModules": {},
"subModuleOrder": []
}
},
"subModuleOrder": [
"model"
]
}
19 changes: 19 additions & 0 deletions test/moduleTests/savedTensors/loadInput.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
{
"rank": 1,
"shape": [
10
],
"dtype": "float64",
"data": [
0.0,
1.0,
2.0,
3.0,
4.0,
5.0,
6.0,
7.0,
8.0,
9.0
]
}

0 comments on commit 8ddecf8

Please sign in to comment.