Skip to content

Commit

Permalink
CLIC dataset v2.3.0: fix target/truth momentum, st=1 to target more i…
Browse files Browse the repository at this point in the history
…nclusive (#352)

* fix target by using only st=1 particles from ROOT

* update datasets to 2.3.0

* update dataset

---------

Co-authored-by: Joosep Pata <[email protected]>
  • Loading branch information
jpata and Joosep Pata authored Oct 22, 2024
1 parent e2433d7 commit fe4cbd5
Show file tree
Hide file tree
Showing 10 changed files with 166 additions and 227 deletions.
164 changes: 88 additions & 76 deletions mlpf/data/clic/postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
import uproot
import vector
import tqdm
import pyhepmc
import bz2
import fastjet
from scipy.sparse import coo_matrix
import math
Expand Down Expand Up @@ -273,7 +271,7 @@ def get_calohit_matrix_and_genadj(hit_data, calohit_links, iev, collectionIDs):

return (
hit_feature_matrix,
(genparticle_to_hit_matrix_coo0, genparticle_to_hit_matrix_coo1, genparticle_to_hit_matrix_w),
(np.array(genparticle_to_hit_matrix_coo0), np.array(genparticle_to_hit_matrix_coo1), np.array(genparticle_to_hit_matrix_w)),
hit_idx_local_to_global,
)

Expand Down Expand Up @@ -341,6 +339,9 @@ def gen_to_features(prop_data, iev):
"gp_to_track": np.zeros(len(gen_arr["PDG"]), dtype=np.float64),
"gp_to_cluster": np.zeros(len(gen_arr["PDG"]), dtype=np.float64),
"jet_idx": np.zeros(len(gen_arr["PDG"]), dtype=np.int64),
"daughters_begin": gen_arr["daughters_begin"],
"daughters_end": gen_arr["daughters_end"],
"index": prop_data["MCParticles#1.index"][iev],
}


Expand Down Expand Up @@ -474,75 +475,108 @@ def filter_adj(adj, all_to_filtered):
return np.array(i0s_new), np.array(i1s_new), np.array(ws_new)


def add_daughters_to_status1(gen_features, genparticle_to_hit, genparticle_to_trk):
    """Attach the hits and tracks of immediate daughters to their status-1 parent.

    For every generator particle with ``generatorStatus == 1`` whose |PDG| is
    not 12, 14 or 16, each hit/track association of its immediate daughters is
    duplicated with the parent as the genparticle index.

    Args:
        gen_features: mapping with the per-particle arrays "generatorStatus",
            "PDG", "daughters_begin", "daughters_end", plus the array "index"
            that is sliced by ``daughters_begin:daughters_end`` to obtain the
            daughter particle indices.
        genparticle_to_hit: tuple with 3 arrays
            (genparticle_indices, hit_indices, weights).
        genparticle_to_trk: tuple with 3 arrays
            (genparticle_indices, track_indices, weights).

    Returns:
        ``(genparticle_to_hit, genparticle_to_trk)``, each a tuple of 3 arrays
        with the new parent associations appended after the original entries.
        Every output array keeps the dtype of the corresponding input array.
    """
    neutrino_pdgs = (12, 14, 16)
    dau_beg = gen_features["daughters_begin"]
    dau_end = gen_features["daughters_end"]
    dau_ind = gen_features["index"]

    # Hits and tracks are handled identically, so process both adjacencies in one loop.
    adjacencies = [tuple(np.asarray(arr) for arr in adj) for adj in (genparticle_to_hit, genparticle_to_trk)]
    additions = [([], [], []) for _ in adjacencies]

    for idx_st1 in np.where(gen_features["generatorStatus"] == 1)[0]:
        if abs(gen_features["PDG"][idx_st1]) in neutrino_pdgs:
            continue
        for dau in dau_ind[dau_beg[idx_st1] : dau_end[idx_st1]]:
            for (gp_idx, obj_idx, weights), (add_gp, add_obj, add_w) in zip(adjacencies, additions):
                # Compute the daughter selection once and reuse it for indices and weights.
                sel = gp_idx == dau
                matched = obj_idx[sel]
                add_gp.extend([idx_st1] * len(matched))
                add_obj.extend(matched)
                add_w.extend(weights[sel])

    # Cast the additions to the dtype of the original arrays: concatenating with a
    # bare empty list would silently promote integer index arrays to float64.
    merged = [
        tuple(np.concatenate([orig, np.asarray(extra, dtype=orig.dtype)]) for orig, extra in zip(adj, add))
        for adj, add in zip(adjacencies, additions)
    ]
    return merged[0], merged[1]


def get_genparticles_and_adjacencies(prop_data, hit_data, calohit_links, sitrack_links, iev, collectionIDs):
gen_features = gen_to_features(prop_data, iev)
hit_features, genparticle_to_hit, hit_idx_local_to_global = get_calohit_matrix_and_genadj(hit_data, calohit_links, iev, collectionIDs)
hit_to_cluster = hit_cluster_adj(prop_data, hit_idx_local_to_global, iev)
cluster_features = cluster_to_features(prop_data, hit_features, hit_to_cluster, iev)
track_features = track_to_features(prop_data, iev)
genparticle_to_track = genparticle_track_adj(sitrack_links, iev)
genparticle_to_trk = genparticle_track_adj(sitrack_links, iev)

# collect hits of st=1 daughters to the st=1 particles
mask_status1 = gen_features["generatorStatus"] == 1
genparticle_to_hit, genparticle_to_trk = add_daughters_to_status1(gen_features, genparticle_to_hit, genparticle_to_trk)

n_gp = awkward.count(gen_features["PDG"])
n_track = awkward.count(track_features["type"])
n_hit = awkward.count(hit_features["type"])
n_cluster = awkward.count(cluster_features["type"])

if len(genparticle_to_track[0]) > 0:
gp_to_track = (
coo_matrix((genparticle_to_track[2], (genparticle_to_track[0], genparticle_to_track[1])), shape=(n_gp, n_track)).max(axis=1).todense()
)
if len(genparticle_to_trk[0]) > 0:
gp_to_track = coo_matrix((genparticle_to_trk[2], (genparticle_to_trk[0], genparticle_to_trk[1])), shape=(n_gp, n_track)).max(axis=1).todense()
else:
gp_to_track = np.zeros((n_gp, 1))

gp_to_calohit = coo_matrix((genparticle_to_hit[2], (genparticle_to_hit[0], genparticle_to_hit[1])), shape=(n_gp, n_hit))
calohit_to_cluster = coo_matrix((hit_to_cluster[2], (hit_to_cluster[0], hit_to_cluster[1])), shape=(n_hit, n_cluster))
gp_to_cluster = (gp_to_calohit * calohit_to_cluster).sum(axis=1)

# 60% of the hits of a track must come from the genparticle
gp_in_tracker = np.array(gp_to_track >= 0.6)[:, 0]
# 20% of the hits of a track must come from the genparticle
gp_in_tracker = np.array(gp_to_track >= 0.2)[:, 0]

# at least 10% of the energy of the genparticle should be matched to a calorimeter cluster
gp_in_calo = (np.array(gp_to_cluster)[:, 0] / gen_features["energy"]) > 0.1
# at least 5% of the energy of the genparticle should be matched to a calorimeter cluster
gp_in_calo = (np.array(gp_to_cluster)[:, 0] / gen_features["energy"]) > 0.05

gp_interacted_with_detector = gp_in_tracker | gp_in_calo

gen_features["gp_to_track"] = np.asarray(gp_to_track)[:, 0]
gen_features["gp_to_cluster"] = np.asarray(gp_to_cluster)[:, 0]

mask_status1 = (gen_features["energy"] > 0.001) & (gen_features["generatorStatus"] == 1)
mask_visible = np.asarray(mask_status1 & gp_interacted_with_detector)

# some status=1 particles from Pythia did not interact with the detector
# we could look for their status=0 children, but daughter indices are messed up
# find dr-nearby status=0 particles that interacted with the detector and add them to the list of reconstructable particles
not_visible_particles = np.argsort(gen_features["pt"][mask_status1 & ~mask_visible])[::-1]
nvp_pdg = gen_features["PDG"][mask_status1 & ~mask_visible][not_visible_particles]
nvp_pt = gen_features["pt"][mask_status1 & ~mask_visible][not_visible_particles]
nvp_eta = gen_features["eta"][mask_status1 & ~mask_visible][not_visible_particles]
nvp_phi = gen_features["phi"][mask_status1 & ~mask_visible][not_visible_particles]
for pdg, pt, eta, phi in zip(nvp_pdg, nvp_pt, nvp_eta, nvp_phi):
if (abs(pdg) != 12) & (abs(pdg) != 14) & (abs(pdg) != 16):
dr = deltar(eta, phi, gen_features["eta"], gen_features["phi"])
gen_mask = (dr < 0.2) & (gen_features["generatorStatus"] != 1) & gp_interacted_with_detector
mask_visible[gen_mask] = 1
# print(pdg, pt, eta, phi, gen_features["pt"][gen_mask], gp_interacted_with_detector[gen_mask], np.where(gen_mask))

# print("gps total={} visible={}".format(n_gp, np.sum(mask_visible)))
mask_visible = awkward.to_numpy(mask_status1 & gp_interacted_with_detector)

idx_all_masked = np.where(mask_visible)[0]
genpart_idx_all_to_filtered = {idx_all: idx_filtered for idx_filtered, idx_all in enumerate(idx_all_masked)}

gen_features = awkward.Record({feat: gen_features[feat][mask_visible] for feat in gen_features.keys()})

genparticle_to_hit = filter_adj(genparticle_to_hit, genpart_idx_all_to_filtered)
genparticle_to_track = filter_adj(genparticle_to_track, genpart_idx_all_to_filtered)
genparticle_to_trk = filter_adj(genparticle_to_trk, genpart_idx_all_to_filtered)

return EventData(
gen_features,
hit_features,
cluster_features,
track_features,
genparticle_to_hit,
genparticle_to_track,
genparticle_to_trk,
hit_to_cluster,
([], []),
)
Expand Down Expand Up @@ -810,23 +844,6 @@ def compute_jets(particles_p4, min_pt=jet_ptcut, with_indices=False):
return ret


def load_hepmc(hepmc_file_path):
    """Read a bz2-compressed HepMC file into an awkward array of status-1 particles.

    Each event in the file becomes one record holding the momentum components,
    mass and PDG id of the particles whose ``status`` equals 1.
    """
    per_event = []
    with pyhepmc.open(bz2.BZ2File(hepmc_file_path, "rb")) as reader:
        for evt in reader:
            # keep only the status == 1 particles of this event
            stable = [particle for particle in evt.particles if particle.status == 1]
            record = {}
            record["MCParticles.momentum.x"] = [particle.momentum.x for particle in stable]
            record["MCParticles.momentum.y"] = [particle.momentum.y for particle in stable]
            record["MCParticles.momentum.z"] = [particle.momentum.z for particle in stable]
            record["MCParticles.mass"] = [particle.momentum.m() for particle in stable]
            record["MCParticles.PDG"] = [particle.pid for particle in stable]
            per_event.append(record)
    return awkward.from_iter(per_event)


def process_one_file(fn, ofn):

# output exists, do not recreate
Expand All @@ -838,19 +855,6 @@ def process_one_file(fn, ofn):
fi = uproot.open(fn)
arrs = fi["events"]

# load .hepmc file corresponding to the .root file
hepmc_file_path = fn.replace("/root/", "/sim/").replace(".root", ".hepmc.bz2").replace("reco_", "sim_")
hepmc_mcp = load_hepmc(hepmc_file_path)

# compute Pythia jets and MET with visible particles
hepmc_p4 = get_p4(hepmc_mcp, "MCParticles")
hepmc_pid = np.abs(hepmc_mcp["MCParticles.PDG"])
hepmc_p4_visible = hepmc_p4[(hepmc_pid != 12) & (hepmc_pid != 14) & (hepmc_pid != 16)]
met_hepmc = compute_met(hepmc_p4_visible)
genjets_hepmc = compute_jets(hepmc_p4_visible)

assert len(hepmc_mcp) == arrs.num_entries

collectionIDs = {
k: v
for k, v in zip(
Expand All @@ -869,6 +873,9 @@ def process_one_file(fn, ofn):
"MCParticles.charge",
"MCParticles.generatorStatus",
"MCParticles.simulatorStatus",
"MCParticles.daughters_begin",
"MCParticles.daughters_end",
"MCParticles#1.index",
track_coll,
"SiTracks_1",
"PandoraClusters",
Expand All @@ -878,17 +885,6 @@ def process_one_file(fn, ofn):
]
)

# fix status 1 particles momentum from hepmc
eq = prop_data["MCParticles.PDG"][prop_data["MCParticles.generatorStatus"] == 1] == hepmc_mcp["MCParticles.PDG"]
assert np.all(eq)

msk = prop_data["MCParticles.generatorStatus"] == 1
counts = awkward.count(msk, axis=1)
for branch in ["MCParticles.momentum.x", "MCParticles.momentum.y", "MCParticles.momentum.z"]:
arr_as_np = np.asarray(awkward.flatten(prop_data[branch]))
arr_as_np[awkward.flatten(msk)] = awkward.flatten(hepmc_mcp[branch])
prop_data[branch] = awkward.unflatten(arr_as_np, counts)

calohit_links = arrs.arrays(
[
"CalohitMCTruthLink.weight",
Expand Down Expand Up @@ -923,6 +919,22 @@ def process_one_file(fn, ofn):
"MUON": arrs["MUON"].array(),
}

# Compute truth MET and jets from status=1 pythia particles
mc_pdg = np.abs(prop_data["MCParticles.PDG"])
mc_st1_mask = (prop_data["MCParticles.generatorStatus"] == 1) & (mc_pdg != 12) & (mc_pdg != 14) & (mc_pdg != 16)
mc_st1_p4 = vector.awk(
awkward.zip(
{
"px": prop_data["MCParticles.momentum.x"][mc_st1_mask],
"py": prop_data["MCParticles.momentum.y"][mc_st1_mask],
"pz": prop_data["MCParticles.momentum.z"][mc_st1_mask],
"mass": prop_data["MCParticles.mass"][mc_st1_mask],
}
)
)
met_st1 = compute_met(mc_st1_p4)
genjets_st1 = compute_jets(mc_st1_p4)

ret = []
for iev in tqdm.tqdm(range(arrs.num_entries), total=arrs.num_entries):

Expand Down Expand Up @@ -1056,8 +1068,8 @@ def process_one_file(fn, ofn):
"ytarget_cluster": ytarget_cluster,
"ycand_track": ycand_track,
"ycand_cluster": ycand_cluster,
"genmet": met_hepmc[iev],
"genjet": get_feature_matrix(genjets_hepmc[iev], ["pt", "eta", "phi", "energy"]),
"genmet": met_st1[iev],
"genjet": get_feature_matrix(genjets_st1[iev], ["pt", "eta", "phi", "energy"]),
"targetjet": get_feature_matrix(target_jets, ["pt", "eta", "phi", "energy"]),
}
)
Expand Down
6 changes: 3 additions & 3 deletions mlpf/data/clic/postprocessing_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ def write_script(infiles, outpath):

samples = [
("/local/joosep/clic_edm4hep/2024_07/p8_ee_qq_ecm380/root/", "/local/joosep/mlpf/clic_edm4hep/p8_ee_qq_ecm380/"),
# ("/local/joosep/clic_edm4hep/2024_07/p8_ee_tt_ecm380/root/", "/local/joosep/mlpf/clic_edm4hep/p8_ee_tt_ecm380/"),
("/local/joosep/clic_edm4hep/2024_07/p8_ee_tt_ecm380/root/", "/local/joosep/mlpf/clic_edm4hep/p8_ee_tt_ecm380/"),
("/local/joosep/clic_edm4hep/2024_07/p8_ee_WW_fullhad_ecm380/root/", "/local/joosep/mlpf/clic_edm4hep/p8_ee_WW_fullhad_ecm380/"),
# ("/local/joosep/clic_edm4hep/2024_07/p8_ee_ZH_Htautau_ecm380/root/", "/local/joosep/mlpf/clic_edm4hep/p8_ee_ZH_Htautau_ecm380/"),
# ("/local/joosep/clic_edm4hep/2024_07/p8_ee_Z_Ztautau_ecm380/root/", "/local/joosep/mlpf/clic_edm4hep/p8_ee_Z_Ztautau_ecm380/"),
("/local/joosep/clic_edm4hep/2024_07/p8_ee_ZH_Htautau_ecm380/root/", "/local/joosep/mlpf/clic_edm4hep/p8_ee_ZH_Htautau_ecm380/"),
("/local/joosep/clic_edm4hep/2024_07/p8_ee_Z_Ztautau_ecm380/root/", "/local/joosep/mlpf/clic_edm4hep/p8_ee_Z_Ztautau_ecm380/"),
]

ichunk = 1
Expand Down
3 changes: 2 additions & 1 deletion mlpf/heptfds/clic_pf_edm4hep/qq.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@


class ClicEdmQqPf(tfds.core.GeneratorBasedBuilder):
VERSION = tfds.core.Version("2.2.0")
VERSION = tfds.core.Version("2.3.0")
RELEASE_NOTES = {
"1.0.0": "Initial release.",
"1.1.0": "update stats, move to 380 GeV",
Expand All @@ -38,6 +38,7 @@ class ClicEdmQqPf(tfds.core.GeneratorBasedBuilder):
"2.0.0": "Add ispu, genjets, genmet; disable genjet_idx; truth def not based on gp.status==1",
"2.1.0": "Bump dataset size",
"2.2.0": "New target definition, fix truth jets, add targetjets and jet idx",
"2.3.0": "Fix target/truth momentum, st=1 more inclusive: PR352",
}
MANUAL_DOWNLOAD_INSTRUCTIONS = """
For the raw input files in ROOT EDM4HEP format, please see the citation above.
Expand Down
3 changes: 2 additions & 1 deletion mlpf/heptfds/clic_pf_edm4hep/ttbar.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@


class ClicEdmTtbarPf(tfds.core.GeneratorBasedBuilder):
VERSION = tfds.core.Version("2.2.0")
VERSION = tfds.core.Version("2.3.0")
RELEASE_NOTES = {
"1.0.0": "Initial release.",
"1.1.0": "update stats, move to 380 GeV",
Expand All @@ -37,6 +37,7 @@ class ClicEdmTtbarPf(tfds.core.GeneratorBasedBuilder):
"2.0.0": "Add ispu, genjets, genmet; disable genjet_idx; truth def not based on gp.status==1",
"2.1.0": "Bump dataset size",
"2.2.0": "New target definition, fix truth jets, add targetjets and jet idx",
"2.3.0": "Fix target/truth momentum, st=1 more inclusive: PR352",
}
MANUAL_DOWNLOAD_INSTRUCTIONS = """
For the raw input files in ROOT EDM4HEP format, please see the citation above.
Expand Down
3 changes: 2 additions & 1 deletion mlpf/heptfds/clic_pf_edm4hep/ww_fullhad.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,14 @@


class ClicEdmWwFullhadPf(tfds.core.GeneratorBasedBuilder):
VERSION = tfds.core.Version("2.2.0")
VERSION = tfds.core.Version("2.3.0")
RELEASE_NOTES = {
"1.3.0": "Update stats to ~1M events",
"1.4.0": "Fix ycand matching",
"1.5.0": "Regenerate with ARRAY_RECORD",
"2.1.0": "Add ispu, genjets, genmet; disable genjet_idx; truth def not based on gp.status==1",
"2.2.0": "New target definition, fix truth jets, add targetjets and jet idx",
"2.3.0": "Fix target/truth momentum, st=1 more inclusive: PR352",
}
MANUAL_DOWNLOAD_INSTRUCTIONS = """
For the raw input files in ROOT EDM4HEP format, please see the citation above.
Expand Down
6 changes: 4 additions & 2 deletions mlpf/heptfds/clic_pf_edm4hep/z.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,13 @@


class ClicEdmZTautauPf(tfds.core.GeneratorBasedBuilder):
VERSION = tfds.core.Version("2.1.0")
VERSION = tfds.core.Version("2.3.0")
RELEASE_NOTES = {
"1.3.0": "First version",
"1.4.0": "Fix ycand matching",
"1.5.0": "Regenerate with ARRAY_RECORD",
"2.1.0": "Add ispu, genjets, genmet; disable genjet_idx; truth def not based on gp.status==1",
"2.3.0": "Fix target/truth momentum, st=1 more inclusive: PR352",
}
MANUAL_DOWNLOAD_INSTRUCTIONS = """
For the raw input files in ROOT EDM4HEP format, please see the citation above.
Expand All @@ -55,10 +56,11 @@ def _info(self) -> tfds.core.DatasetInfo:
),
dtype=tf.float32,
),
"ygen": tfds.features.Tensor(shape=(None, len(Y_FEATURES)), dtype=tf.float32),
"ytarget": tfds.features.Tensor(shape=(None, len(Y_FEATURES)), dtype=tf.float32),
"ycand": tfds.features.Tensor(shape=(None, len(Y_FEATURES)), dtype=tf.float32),
"genmet": tfds.features.Scalar(dtype=tf.float32),
"genjets": tfds.features.Tensor(shape=(None, 4), dtype=tf.float32),
"targetjets": tfds.features.Tensor(shape=(None, 4), dtype=tf.float32),
}
),
supervised_keys=None,
Expand Down
6 changes: 4 additions & 2 deletions mlpf/heptfds/clic_pf_edm4hep/zh.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@


class ClicEdmZhTautauPf(tfds.core.GeneratorBasedBuilder):
VERSION = tfds.core.Version("2.1.0")
VERSION = tfds.core.Version("2.3.0")
RELEASE_NOTES = {
"1.3.0": "First version",
"1.4.0": "Fix ycand matching",
"1.5.0": "Regenerate with ARRAY_RECORD",
"2.1.0": "Add ispu, genjets, genmet; disable genjet_idx; truth def not based on gp.status==1",
"2.3.0": "Fix target/truth momentum, st=1 more inclusive: PR352",
}
MANUAL_DOWNLOAD_INSTRUCTIONS = """
For the raw input files in ROOT EDM4HEP format, please see the citation above.
Expand All @@ -58,10 +59,11 @@ def _info(self) -> tfds.core.DatasetInfo:
),
dtype=tf.float32,
),
"ygen": tfds.features.Tensor(shape=(None, len(Y_FEATURES)), dtype=tf.float32),
"ytarget": tfds.features.Tensor(shape=(None, len(Y_FEATURES)), dtype=tf.float32),
"ycand": tfds.features.Tensor(shape=(None, len(Y_FEATURES)), dtype=tf.float32),
"genmet": tfds.features.Scalar(dtype=tf.float32),
"genjets": tfds.features.Tensor(shape=(None, 4), dtype=tf.float32),
"targetjets": tfds.features.Tensor(shape=(None, 4), dtype=tf.float32),
}
),
supervised_keys=None,
Expand Down
Loading

0 comments on commit fe4cbd5

Please sign in to comment.