Commit

fix merge
thoglu committed Jan 2, 2024
2 parents 89fcd4d + 1e4bf50 commit afea31f
Showing 4 changed files with 108 additions and 61 deletions.
158 changes: 101 additions & 57 deletions jammy_flows/helper_fns/contours.py
@@ -124,7 +124,7 @@ def compute_contours(proportions, pdf_evals, areas, sample_points=None, manifold
## create 1 "joint" 1-d contour here
combined_list.append(numpy.array(contour)[...,None])
elif(sample_points.shape[1]==2):
contours_by_level = meander.euclidean_contours(sample_points, pdf_evals, levels)
contours_by_level = meander.planar_contours(sample_points, pdf_evals, levels)

elif(manifold=="sphere"):

@@ -331,11 +331,26 @@ def fake_plot_and_calc_eucl_contours(ax, colors, proportions, xvals, yvals, pdf_
"""
class custom_contour_generator(object):

def __init__(self, x,y, pdf_evals, areas, ax_obj):
self.joint_xy=np.concatenate([x[:,None],y[:,None]],axis=1)
self.pdf_evals=pdf_evals
self.areas=areas
self.ax_obj=ax_obj
def __init__(self, *args):

self.contour_type=args[0]
assert(self.contour_type=="euclidean" or self.contour_type=="zen_azi"), self.contour_type

if(len(args)==4):
self.has_precalculated_contours=True
self.contour_probs=args[1]
self.precalculated_contours=args[2]
self.ax_obj=args[3]
assert(len(self.contour_probs)==len(self.precalculated_contours))

elif(len(args)==6):
self.has_precalculated_contours=False
self.joint_xy=np.concatenate([args[1][:,None],args[2][:,None]],axis=1)
self.pdf_evals=args[3]
self.areas=args[4]
self.ax_obj=args[5]
else:
raise Exception("Require either 3 or 6 positional arguments for custom contour generator!")


def _get_azimuth_split_contours(self, c, is_azimuthal=True):
@@ -391,7 +406,7 @@ def _get_azimuth_split_contours(self, c, is_azimuthal=True):
if(num_splits==0):

return [c]

print("num splits .. ", num_splits)
return new_groups


@@ -407,57 +422,65 @@ def _make_isolated_kind(self,c):


def create_contour(self, contour_prob):
print("cprob ..", contour_prob)
if(self.has_precalculated_contours):
assert(contour_prob in self.contour_probs)

cindex=self.contour_probs.index(contour_prob)

contours=self.precalculated_contours[cindex]
else:
contours=compute_contours([contour_prob], self.pdf_evals, self.areas, sample_points=self.joint_xy, manifold="sphere")

contours=contours[0]

contours=compute_contours([contour_prob], self.pdf_evals, self.areas, sample_points=self.joint_xy, manifold="sphere")

contours=contours[0]

all_contours=[]

min_contour_len=2

for c in contours:
if(self.contour_type!="euclidean"):
all_contours=[]

sub_contours=self._get_azimuth_split_contours(c)
min_contour_len=2

for c in contours:

sub_contours=self._get_azimuth_split_contours(c)

for s in sub_contours:
if(len(s)>min_contour_len):
all_contours.append(s)
## rad to deg, zen->dec etc

for s in sub_contours:
if(len(s)>min_contour_len):
all_contours.append(s)
## rad to deg, zen->dec etc

transformed_contours=[]
for c in all_contours:
new_dec=90.0-c[:,0]*180.0/numpy.pi
new_ra=c[:,1]*180.0/numpy.pi
transformed_contours.append(numpy.concatenate([new_ra[:,None],new_dec[:,None]],axis=1))

## nans can appear for hidden points in current projection.. get rid of those...
transformed_contours=[]
for c in all_contours:
new_dec=90.0-c[:,0]*180.0/numpy.pi
new_ra=c[:,1]*180.0/numpy.pi
transformed_contours.append(numpy.concatenate([new_ra[:,None],new_dec[:,None]],axis=1))

## nans can appear for hidden points in current projection.. get rid of those...

contours=[]
contours=[]

for c in transformed_contours:
safe_c=self.ax_obj.wcs.all_world2pix(c,1)[~numpy.isnan(self.ax_obj.wcs.all_world2pix(c,1)[:,0])]

if(len(safe_c)>0):
#contours.append(safe_c)

sub_contours=self._get_azimuth_split_contours(safe_c, is_azimuthal=False)
for c in transformed_contours:

safe_c=self.ax_obj.wcs.all_world2pix(c,1)[~numpy.isnan(self.ax_obj.wcs.all_world2pix(c,1)[:,0])]

for sub_c in sub_contours:
if(len(sub_c)>min_contour_len): # only take contours longer than 2

contours.append(sub_c)
if(len(safe_c)>0):
#contours.append(safe_c)

sub_contours=self._get_azimuth_split_contours(safe_c, is_azimuthal=False)

for sub_c in sub_contours:
if(len(sub_c)>min_contour_len): # only take contours longer than 2

contours.append(sub_c)


#contours=[self.ax_obj.wcs.all_world2pix(c,1)[~numpy.isnan(self.ax_obj.wcs.all_world2pix(c,1))] for c in transformed_contours]
#print("transformed conts ", contours)

assert(type(contours)==list)


all_kinds=[]

for c in contours:

print(len(c))
## default is a repeating kind
new_kind=[1]+(len(c)-2)*[2]+[79]

Expand All @@ -471,7 +494,7 @@ def create_contour(self, contour_prob):



class CustomSphereContourSet(matplotlib.contour.ContourSet):
class ContourGenerator(matplotlib.contour.ContourSet):
"""
A custom contour set that has similar structure to QuadContourSet in matplotlib,
but is customized to work with variable resolution spherical data.
@@ -492,7 +515,6 @@ def _process_args(self, *args, corner_mask=None, algorithm=None, **kwargs):
self._maxs = args[0]._maxs
self._algorithm = args[0]._algorithm
else:
import contourpy

if algorithm is None:
algorithm = mpl.rcParams['contour.algorithm']
@@ -507,18 +529,40 @@ def _process_args(self, *args, corner_mask=None, algorithm=None, **kwargs):
else:
corner_mask = mpl.rcParams['contour.corner_mask']
self._corner_mask = corner_mask

assert(len(args)==4), args

x = args[0]
y = args[1]
log_evals = args[2]
areas = args[3]


self.plotted_clabels=False
print("INIT contour generator .. ", args[0])
if(len(args)==5):

contour_type=args[0]
x = args[1]
y = args[2]
log_evals = args[3]
areas = args[4]

self.zmin=min(log_evals)
self.zmax=max(log_evals)

## contours have to be generated first
contour_generator = custom_contour_generator(contour_type, x,y,log_evals, areas, self.axes)

self.zmin=min(log_evals)
self.zmax=max(log_evals)
elif(len(args)==3):

contour_type=args[0]
contour_probs=args[1]
given_contours = args[2]

self.zmin=0.001
self.zmax=1.0

## fake points
x=numpy.linspace(0,1,10)
y=numpy.linspace(0,1,10)

## contours are pre-calculated and only handed over for plotting
contour_generator = custom_contour_generator(contour_type, contour_probs, given_contours, self.axes)

contour_generator = custom_contour_generator(x,y,log_evals, areas, self.axes)

t = self.get_transform()

@@ -646,4 +690,4 @@ def _initialize_x_y(self, z):
y = y0 + (np.arange(Ny) + 0.5) * dy
if self.origin == 'upper':
y = y[::-1]
return np.meshgrid(x, y)
return np.meshgrid(x, y)
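
The reworked custom_contour_generator dispatches on the number of positional arguments: four arguments hand over precalculated contours for plotting only, while six arguments trigger the contour computation from sample points via compute_contours. A minimal sketch of the two call patterns; the input arrays, precalculated_contours and the WCS-aware axes ax are placeholder assumptions, not part of this commit:

    from jammy_flows.helper_fns.contours import custom_contour_generator

    # six positional args: contours are computed on the fly from spherical
    # sample points (zenith/azimuth in radians), their PDF evaluations and
    # per-point areas; ax must expose ax.wcs for the projection step
    gen = custom_contour_generator("zen_azi", zen, azi, pdf_evals, areas, ax)

    # four positional args: contours were precalculated elsewhere and are only
    # handed over for plotting at the given credible levels
    gen = custom_contour_generator("zen_azi", [0.68, 0.95], precalculated_contours, ax)

    # either way, a single credible level is then produced via
    contour_68 = gen.create_contour(0.68)
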
4 changes: 3 additions & 1 deletion jammy_flows/helper_fns/plotting/spherical.py
@@ -711,7 +711,9 @@ def _plot_multiresolution_healpy(eval_positions,

used_contour_colors=used_contour_colors[:len(contour_probs)]

ret=CustomSphereContourSet(ax,
ret=ContourGenerator(
ax,
"zen_azi",
eval_positions[:,0],
eval_positions[:,1],
pdf_evals,
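
Besides the five-argument path used here, the renamed ContourGenerator also accepts a three-argument path in _process_args for contours that were precalculated elsewhere. A hedged sketch of that second mode; the keyword arguments and the precalculated_contours variable are assumptions following the standard matplotlib ContourSet interface:

    from jammy_flows.helper_fns.contours import ContourGenerator

    contour_probs = [0.68, 0.95]

    # three positional args after the axes: contour type, credible levels and
    # the precalculated contours; nothing is recomputed, the set is only drawn
    ret = ContourGenerator(ax,
                           "zen_azi",
                           contour_probs,
                           precalculated_contours,
                           levels=contour_probs,
                           colors=used_contour_colors)
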
4 changes: 2 additions & 2 deletions jammy_flows/layers/euclidean/euclidean_do_nothing.py
@@ -13,12 +13,12 @@
normal_dist=tdist.Normal(0, 1)

class euclidean_do_nothing(euclidean_base.euclidean_base):
def __init__(self, dimension, use_permanent_parameters=True):
def __init__(self, dimension, use_permanent_parameters=True, add_offset=0):
"""
Identity transformation. Symbol "x"
"""

super().__init__(dimension=dimension, use_permanent_parameters=use_permanent_parameters, model_offset=0)
super().__init__(dimension=dimension, use_permanent_parameters=use_permanent_parameters, model_offset=add_offset)

def _flow_mapping(self, inputs, extra_inputs=None):

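
The identity layer now simply forwards the new add_offset argument to euclidean_base as model_offset, so an offset can be attached even to a layer that otherwise does nothing. A small sketch of the new signature; treating add_offset as a 0/1 flag is an assumption based on the model_offset default:

    from jammy_flows.layers.euclidean.euclidean_do_nothing import euclidean_do_nothing

    # behaves exactly as before: a pure identity transformation
    identity = euclidean_do_nothing(dimension=2)

    # identity transformation that additionally carries an offset, handled by
    # the euclidean_base parent class via model_offset
    identity_with_offset = euclidean_do_nothing(dimension=2,
                                                use_permanent_parameters=True,
                                                add_offset=1)
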
3 changes: 2 additions & 1 deletion jammy_flows/main/default.py
@@ -3623,7 +3623,7 @@ def newton_iter(arg, p, normed_length_summed_pts):

# a simple sampling is typically faster than whole entropy calculation, so this might be a viable alternative

samples,sample_logprobs,_,_=self.sample(conditional_input=data_summary_repeated, samplesize=samplesize, device=used_device, dtype=used_dtype, force_embedding_coordinates=True)
samples,_,sample_logprobs,_=self.sample(conditional_input=data_summary_repeated, samplesize=samplesize, device=used_device, dtype=used_dtype, force_embedding_coordinates=True)

target_dim_embedded=self.total_target_dim_embedded

@@ -3642,6 +3642,7 @@ def newton_iter(arg, p, normed_length_summed_pts):

else:
## we have sample_logprobs
print("sample logprobs shape", sample_logprobs.shape)
reshaped_log_pdfs=sample_logprobs.reshape(initial_batch_size, samplesize)

index_mask=torch.argmax(reshaped_log_pdfs, dim=1)
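
The one-line fix changes which element of the sample(...) return tuple is treated as the per-sample log-probability: it is the third of four returned values, not the second. A hedged sketch of the corrected unpacking; pdf stands for the jammy_flows pdf object (self in the diff) and the names of the other returned tensors are assumptions:

    import torch

    # sample(...) returns four tensors; the log-probabilities evaluated in
    # embedding coordinates sit in the third slot
    samples, _, sample_logprobs, _ = pdf.sample(
        conditional_input=data_summary_repeated,
        samplesize=samplesize,
        device=used_device,
        dtype=used_dtype,
        force_embedding_coordinates=True,
    )

    # reshape to (batch, samplesize) and pick the most probable sample per event
    reshaped_log_pdfs = sample_logprobs.reshape(initial_batch_size, samplesize)
    index_mask = torch.argmax(reshaped_log_pdfs, dim=1)
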
