Skip to content

Commit

Permalink
Added greedy policy based on alternative familiarity definition
Browse files Browse the repository at this point in the history
  • Loading branch information
AlanKerstjens committed Dec 16, 2023
1 parent a54c4d3 commit ea95003
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 15 deletions.
7 changes: 4 additions & 3 deletions AutoCorrectMolecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ def ParseArgs():
parser.add_argument("smiles", type=str,
help="Input molecule SMILES string.")
parser.add_argument("-p", "--policy", type=str, default="MLR",
choices=["Familiarity", "BFS", "DistanceNormalizedFamiliarity",
"UCT", "Astar", "MLR"],
choices=["BFS", "Familiarity1", "Familiarity2",
"DistanceNormalizedFamiliarity", "UCT", "Astar", "MLR"],
help="Tree search vertex selection policy type.")
parser.add_argument("-s", "--max_tree_size", type=int, default=1000,
help="Maximum tree size.")
Expand Down Expand Up @@ -52,8 +52,9 @@ def Main():
settings.attempt_environment_correction_with_atom_insertions = False

policy_types = {
"Familiarity": mac.Policy.Type.Familiarity,
"BFS": mac.Policy.Type.BFS,
"Familiarity1": mac.Policy.Type.Familiarity1,
"Familiarity2": mac.Policy.Type.Familiarity2,
"DistanceNormalizedFamiliarity": mac.Policy.Type.DistanceNormalizedFamiliarity,
"UCT": mac.Policy.Type.UCT,
"Astar": mac.Policy.Type.Astar,
Expand Down
33 changes: 23 additions & 10 deletions source/MoleculeAutoCorrect.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -483,8 +483,9 @@ class Vertex : public RDKit::ROMol {
};

double Familiarity2() const {
double n_atoms = getNumAtoms();
return (n_atoms - n_foreign_environments) / n_atoms;
double n_foreign_keys =
n_foreign_atoms + n_foreign_bonds + n_foreign_environments;
return 1.0 / (n_foreign_keys + 1.0);
};

bool IsFamiliar() const {
Expand All @@ -508,10 +509,14 @@ class std::hash<MoleculeAutoCorrect::Vertex> {

namespace MoleculeAutoCorrect {

double Familiarity(const Vertex& vertex) {
double Familiarity1(const Vertex& vertex) {
return vertex.Familiarity1();
};

// Free-function form of Vertex::Familiarity2(): forwards to the member
// function and returns its result unchanged.
double Familiarity2(const Vertex& vertex) {
  const double familiarity = vertex.Familiarity2();
  return familiarity;
}

// Free-function form of Vertex::Expand(): forwards to the member function
// and returns its result as a std::optional<Vertex>.
// NOTE(review): the vertex is taken by non-const reference, so Expand()
// presumably mutates it -- confirm against Vertex::Expand.
std::optional<Vertex> Expansion(Vertex& vertex) {
  std::optional<Vertex> expanded = vertex.Expand();
  return expanded;
}
Expand Down Expand Up @@ -620,7 +625,8 @@ struct Constant {

enum class Type {
BFS,
Familiarity,
Familiarity1,
Familiarity2,
DistanceNormalizedFamiliarity,
Astar,
UCT,
Expand All @@ -635,8 +641,12 @@ struct BFS : GreedyPolicy<Vertex> {
GreedyPolicy<Vertex>(Objective::TopologicalSimilarity(root_molecule)) {};
};

struct Familiarity : GreedyPolicy<Vertex> {
Familiarity() : GreedyPolicy<Vertex>(Objective::Familiarity1) {};
struct Familiarity1 : GreedyPolicy<Vertex> {
Familiarity1() : GreedyPolicy<Vertex>(Objective::Familiarity1) {};
};

// Greedy policy whose base GreedyPolicy<Vertex> is constructed with
// Objective::Familiarity2.
struct Familiarity2 : GreedyPolicy<Vertex> {
  Familiarity2()
    : GreedyPolicy<Vertex>(Objective::Familiarity2) {}
};

struct DistanceNormalizedFamiliarity : GreedyPolicy<Vertex> {
Expand Down Expand Up @@ -727,7 +737,7 @@ MoleculeAutoCorrect::Result AutoCorrectMolecule(
TerminationPolicy(settings.n_solutions)) - 1;

auto top_vertices = tree_search.TopVertices(
Familiarity, settings.n_top_solutions);
Familiarity1, settings.n_top_solutions);
const auto& vertex_depths = tree_search.GetVertexDepths();
for (auto [v, familiarity] : top_vertices) {
const Vertex& vertex = tree_search.GetVertex(v);
Expand All @@ -749,8 +759,11 @@ MoleculeAutoCorrect::Result AutoCorrectMolecule(
case Policy::Type::BFS:
selection_policy = Policy::BFS(molecule);
break;
case Policy::Type::Familiarity:
selection_policy = Policy::Familiarity();
case Policy::Type::Familiarity1:
selection_policy = Policy::Familiarity1();
break;
case Policy::Type::Familiarity2:
selection_policy = Policy::Familiarity2();
break;
case Policy::Type::DistanceNormalizedFamiliarity:
selection_policy = Policy::DistanceNormalizedFamiliarity(molecule);
Expand All @@ -760,7 +773,7 @@ MoleculeAutoCorrect::Result AutoCorrectMolecule(
break;
case Policy::Type::UCT:
selection_policy = Policy::UCT(0.5);
reward_function = Familiarity;
reward_function = Familiarity1;
break;
case Policy::Type::MLR:
selection_policy = Policy::MLR(molecule);
Expand Down
6 changes: 4 additions & 2 deletions source/wrap/pyMoleculeAutoCorrect.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,10 @@ void WrapMoleculeAutoCorrect() {
python::enum_<MoleculeAutoCorrect::Policy::Type>("Type")
.value("BFS",
MoleculeAutoCorrect::Policy::Type::BFS)
.value("Familiarity",
MoleculeAutoCorrect::Policy::Type::Familiarity)
.value("Familiarity1",
MoleculeAutoCorrect::Policy::Type::Familiarity1)
.value("Familiarity2",
MoleculeAutoCorrect::Policy::Type::Familiarity2)
.value("DistanceNormalizedFamiliarity",
MoleculeAutoCorrect::Policy::Type::DistanceNormalizedFamiliarity)
.value("Astar",
Expand Down

0 comments on commit ea95003

Please sign in to comment.