This repository has been archived by the owner on Feb 3, 2023. It is now read-only.

Commit

feature function loading upgrade
j6mes committed Apr 30, 2018
1 parent 89415b6 commit add471f
Showing 3 changed files with 44 additions and 22 deletions.
29 changes: 20 additions & 9 deletions src/common/features/feature_function.py
@@ -14,11 +14,8 @@ def __init__(self,model_name, features=list(), label_name="label",base_path="fea
        self.logger = LogHelper.get_logger(Features.__name__)
        self.mname = model_name

    def load(self,train,dev=None,test=None):
        train_fs = []
        dev_fs = []
        test_fs = []

    def check_needs_generate(self,train,dev,test):
        for ff in self.feature_functions:
            ffpath = os.path.join(self.base_path, ff.get_name())

@@ -31,12 +28,31 @@ def load(self,train,dev=None,test=None):
                    or (test is not None and not os.path.exists(os.path.join(ffpath, "test"))) or \
                    os.getenv("GENERATE","").lower() in ["y", "1", "t", "yes"]:

                return True

        return False

    def load(self,train,dev=None,test=None):
        train_fs = []
        dev_fs = []
        test_fs = []

        if self.check_needs_generate(train,dev,test):
            self.inform(train,dev,test)
        else:
            try:
                self.load_vocab(self.mname)
            except:
                self.logger.info("Could not load vocab. Regenerating")
                self.inform(train,dev,test)


        for ff in self.feature_functions:
            train_fs.append(self.generate_or_load(ff, train, "train"))
            dev_fs.append(self.generate_or_load(ff, dev, "dev"))
            test_fs.append(self.generate_or_load(ff, test, "test"))

        self.save_vocab(self.mname)

        return self.out(train_fs,train), self.out(dev_fs,dev), self.out(test_fs,test)

@@ -53,17 +69,12 @@ def generate_or_load(self,feature,dataset,name)
                self.logger.info("Loading Features for {0}.{1}".format(feature, name))
                with open(os.path.join(ffpath, name), "rb") as f:
                    features = pickle.load(f)

                feature.load_vocab(self.mname)

            else:
                self.logger.info("Generating Features for {0}.{1}".format(feature,name))
                features = feature.lookup(dataset.data)

                with open(os.path.join(ffpath, name), "wb+") as f:
                    pickle.dump(features, f)

                feature.save_vocab(self.mname)
            return features

        return None
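
The net effect of this refactor: the decision to regenerate cached features now lives in check_needs_generate (including the GENERATE environment override), and the vocabulary is loaded and saved once in load instead of per feature inside generate_or_load. A minimal standalone sketch of the cache-or-regenerate check, assuming a simplified layout in which every split is always expected on disk (needs_generate and its parameters are illustrative, not the project's API):

    import os

    def needs_generate(base_path, feature_names, splits=("train", "dev", "test")):
        # Hypothetical helper mirroring the idea behind Features.check_needs_generate:
        # rebuild when the GENERATE env var asks for it or any cached split is missing.
        if os.getenv("GENERATE", "").lower() in ["y", "1", "t", "yes"]:
            return True
        for name in feature_names:
            ffpath = os.path.join(base_path, name)
            if not all(os.path.exists(os.path.join(ffpath, s)) for s in splits):
                return True
        return False

    # e.g. force a rebuild for one run:  GENERATE=yes python train.py

A caller such as load can then branch once on this result rather than deciding separately for every feature function.
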
30 changes: 22 additions & 8 deletions src/rte/riedel/fever_features.py
@@ -105,14 +105,28 @@ def save(self,mname)
    def load(self,mname):
        self.logger.info("Loading TFIDF features from disk")

        with open("features/{0}-bowv".format(mname), "rb") as f:
            self.bow_vectorizer = pickle.load(f)
        with open("features/{0}-bow".format(mname), "rb") as f:
            self.bow = pickle.load(f)
        with open("features/{0}-tfidf".format(mname), "rb") as f:
            self.tfidf_vectorizer = pickle.load(f)
        with open("features/{0}-tfreq".format(mname), "rb") as f:
            self.tfreq_vectorizer = pickle.load(f)
        try:
            with open("features/{0}-bowv".format(mname), "rb") as f:
                bow_vectorizer = pickle.load(f)
            with open("features/{0}-bow".format(mname), "rb") as f:
                bow = pickle.load(f)
            with open("features/{0}-tfidf".format(mname), "rb") as f:
                tfidf_vectorizer = pickle.load(f)
            with open("features/{0}-tfreq".format(mname), "rb") as f:
                tfreq_vectorizer = pickle.load(f)

            self.bow = bow
            self.bow_vectorizer = bow_vectorizer
            self.tfidf_vectorizer = tfidf_vectorizer
            self.tfreq_vectorizer = tfreq_vectorizer


        except Exception as e:
            raise e



    def lookup(self,data):
        return self.process(data)
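
The rewritten load reads all four pickles into locals and assigns them to self only after every read has succeeded, so a missing or corrupt file no longer leaves the feature object half-initialised; the except block simply re-raises, which lets the caller (the try/except added around load_vocab in Features.load above) fall back to regeneration. A small sketch of that load-then-commit pattern, with an illustrative helper name:

    import pickle

    def load_all_or_nothing(obj, paths):
        # paths maps attribute name -> pickle file path (illustrative, not the project's API).
        loaded = {}
        for attr, path in paths.items():
            with open(path, "rb") as f:
                loaded[attr] = pickle.load(f)   # any failure raises before obj is touched
        for attr, value in loaded.items():
            setattr(obj, attr, value)           # commit attributes only once all reads succeed
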
7 changes: 2 additions & 5 deletions src/scripts/rte/mlp/train_mlp.py
@@ -45,7 +45,7 @@ def str2bool(v)
parser.add_argument("--filtering",type=str, default=None)
args = parser.parse_args()

if os.path.exists("models"):
if not os.path.exists("models"):
    os.mkdir("models")

logger.info("Loading DB {0}".format(args.db))
@@ -80,17 +80,14 @@ def str2bool(v)
test_ds.read()

train_feats, dev_feats, test_feats = f.load(train_ds, dev_ds, test_ds)
f.save_vocab(mname)

input_shape = train_feats[0].shape[1]

model = SimpleMLP(input_shape,100,3)

if gpu():
    model.cuda()


if model_exists(mname) and os.getenv("TRAIN").lower() not in ["y","1","t","yes"]:
if model_exists(mname) and os.getenv("TRAIN","").lower() not in ["y","1","t","yes"]:
    model.load_state_dict(torch.load("models/{0}.model".format(mname)))
else:
    train(model, train_feats, 500, 1e-2, 90,dev_feats,early_stopping=EarlyStopping(mname))
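
Three small fixes in the training script: the models directory is now created when it does not exist (the previous condition was inverted, so a fresh checkout never got the directory), os.getenv("TRAIN","") is given a default so .lower() cannot fail on an unset variable, and the explicit f.save_vocab(mname) call is dropped because Features.load now saves the vocabulary itself. The first two idioms in isolation, as a sketch rather than the script's code:

    import os

    # Create the directory only if it is missing; exist_ok=True sidesteps the check entirely.
    os.makedirs("models", exist_ok=True)

    # os.getenv returns None for an unset variable, so .lower() would raise AttributeError
    # without the "" default.
    retrain = os.getenv("TRAIN", "").lower() in ["y", "1", "t", "yes"]
    print("retrain requested:", retrain)
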
