Skip to content

Commit

Permalink
main: Added option to disable special treatment of JAABA PFFs when n=…
Browse files Browse the repository at this point in the history
…=2. Refactored feat_augment.m to ease precomputing of PFF names.
  • Loading branch information
adamltaylor committed Oct 29, 2024
1 parent 4d21963 commit f9e2181
Show file tree
Hide file tree
Showing 9 changed files with 118 additions and 84 deletions.
13 changes: 8 additions & 5 deletions tracking/core_tracker_compute_features.m
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,16 @@ function core_tracker_compute_features(output_feature_file_name, output_features

% write feat file
trk = load_anonymous(input_tracking_file_name) ;
feat = feat_compute(trk, calibration) ;
if isfield(options, 'do_compute_relative_features') ,
do_compute_relative_features = options.do_compute_relative_features ;
else
do_compute_relative_features = true ; % for backwards-compatibility
end
feat = feat_compute(trk, calibration, do_compute_relative_features) ;
save(output_feature_file_name,'feat','-v7.3') ;

% save csv files
if options.save_xls
%output_xls_file_name = [input_tracking_file_name(1:end-10) '-trackfeat'];
names = [trk.names feat.names] ;
data = nan(size(trk.data,1),size(trk.data,2),numel(names)) ;
data(:,:,1:size(trk.data,3)) = trk.data ;
Expand All @@ -23,10 +27,9 @@ function core_tracker_compute_features(output_feature_file_name, output_features

% write JAABA folders
if options.save_JAABA
% JAABA_dir = [input_tracking_file_name(1:end-10) '-JAABA'];
% augment features (with log, norms, and derivatives)
feat = feat_augment(feat);
writeJAABA(input_tracking_file_name, input_video_file_name, trk, feat, calibration, output_jaaba_folder_name) ;
aug_feat = feat_augment(feat);
writeJAABA(input_tracking_file_name, input_video_file_name, trk, aug_feat, calibration, output_jaaba_folder_name) ;
end

end
68 changes: 4 additions & 64 deletions tracking/feat_augment.m
Original file line number Diff line number Diff line change
Expand Up @@ -14,68 +14,8 @@
% Normalizes features that are fly-variant (such as size)
% Adds 1st and 2nd derivatives of each feature to the data matrix
%
function feat = feat_augment(feat)
% take log of the following features
logfeat = {'vel','ang_vel','min_wing_ang','max_wing_ang','fg_body_ratio'};
epsi = 0.001;
% normalize the following features
normfeat = {'mean_wing_length','axis_ratio','contrast'};
% use floowing wavelets to take 1st and 2nd derivatives of features
wavelets = [0 0.5 -0.5 0 0; % tight 1st deriv gauss
0 0.25 -0.5 0.25 0]; % tight 2nd deriv gauss
% initialize data
data = feat.data;
names = feat.names;
n_flies = size(data,1);
n_frames = size(data,2);
n_feats = size(data,3);
learn_data = zeros(n_flies,n_frames,n_feats*3);
% augment each feature for all flies
for s=1:n_flies
for i=1:n_feats
% median normalize certain features
if ismember(names{i},normfeat)
denom = nanmedian(data(s,:,i));
if denom ~= 0 && ~isnan(denom)
data(s,:,i) = data(s,:,i)/denom;
end
end
% take the log of certain features
if ismember(names{i},logfeat)
if strcmp(names{i},'ang_vel')
data(s,:,i) = log10(data(s,:,i)+epsi*0.001);
else
data(s,:,i) = log10(data(s,:,i)+epsi);
end
end
learn_data(s,:,i) = data(s,:,i);
% apply wavelets to each feature vector
idx = n_feats + (i-1)*2;
responses = conv(learn_data(s,:,i),wavelets(1,:),'valid');
buff_left = floor((n_frames-numel(responses))/2);
buff_right = ceil((n_frames-numel(responses))/2);
learn_data(s,1:buff_left,idx+1) = responses(1);
learn_data(s,buff_left+1:end-buff_right,idx+1) = responses;
learn_data(s,end-buff_right+1:end,idx+1) = responses(end);
responses = conv(learn_data(s,:,i),wavelets(2,:),'valid');
buff_left = floor((n_frames-numel(responses))/2);
buff_right = ceil((n_frames-numel(responses))/2);
learn_data(s,1:buff_left,idx+2) = responses(1);
learn_data(s,buff_left+1:end-buff_right,idx+2) = responses;
learn_data(s,end-buff_right+1:end,idx+2) = responses(end);
end
end
% update names to match their augmentation
for i=1:numel(names)
if ismember(names{i},normfeat)
names{i} = ['norm_' names{i}];
end
if ismember(names{i},logfeat)
names{i} = ['log_' names{i}];
end
names{end+1} = [names{i} '_diff1'];
names{end+1} = [names{i} '_diff2'];
end
feat.names = names;
feat.data = learn_data;
function aug_feat = feat_augment(feat)
aug_data = feat_augment_data(feat.data, feat.names) ;
aug_names = feat_augment_names(feat.names) ;
aug_feat = struct('names', {aug_names}, 'units', {feat.units}, 'data', {aug_data}) ;
end
46 changes: 46 additions & 0 deletions tracking/feat_augment_data.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
function learn_data = feat_augment_data(data, names)
[logfeat, normfeat] = feat_augment_log_and_norm() ;
epsi = 0.001;
% use floowing wavelets to take 1st and 2nd derivatives of features
wavelets = [0 0.5 -0.5 0 0; % tight 1st deriv gauss
0 0.25 -0.5 0.25 0]; % tight 2nd deriv gauss
n_flies = size(data,1);
n_frames = size(data,2);
n_feats = size(data,3);
learn_data = zeros(n_flies,n_frames,n_feats*3);
% augment each feature for all flies
for s=1:n_flies
for i=1:n_feats
% median normalize certain features
if ismember(names{i},normfeat)
denom = nanmedian(data(s,:,i));
if denom ~= 0 && ~isnan(denom)
data(s,:,i) = data(s,:,i)/denom;
end
end
% take the log of certain features
if ismember(names{i},logfeat)
if strcmp(names{i},'ang_vel')
data(s,:,i) = log10(data(s,:,i)+epsi*0.001);
else
data(s,:,i) = log10(data(s,:,i)+epsi);
end
end
learn_data(s,:,i) = data(s,:,i);
% apply wavelets to each feature vector
idx = n_feats + (i-1)*2;
responses = conv(learn_data(s,:,i),wavelets(1,:),'valid');
buff_left = floor((n_frames-numel(responses))/2);
buff_right = ceil((n_frames-numel(responses))/2);
learn_data(s,1:buff_left,idx+1) = responses(1);
learn_data(s,buff_left+1:end-buff_right,idx+1) = responses;
learn_data(s,end-buff_right+1:end,idx+1) = responses(end);
responses = conv(learn_data(s,:,i),wavelets(2,:),'valid');
buff_left = floor((n_frames-numel(responses))/2);
buff_right = ceil((n_frames-numel(responses))/2);
learn_data(s,1:buff_left,idx+2) = responses(1);
learn_data(s,buff_left+1:end-buff_right,idx+2) = responses;
learn_data(s,end-buff_right+1:end,idx+2) = responses(end);
end
end
end
6 changes: 6 additions & 0 deletions tracking/feat_augment_log_and_norm.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
function [logfeat, normfeat] = feat_augment_log_and_norm()
% take log of the following features
logfeat = {'vel','ang_vel','min_wing_ang','max_wing_ang','fg_body_ratio'};
% normalize the following features
normfeat = {'mean_wing_length','axis_ratio','contrast'};
end
20 changes: 20 additions & 0 deletions tracking/feat_augment_names.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
function aug_names = feat_augment_names(names)
[logfeat, normfeat] = feat_augment_log_and_norm() ;
% update names to match their augmentation
simple_aug_names = cell(1,0) ;
diff_aug_names = cell(1,0) ;
for i=1:numel(names)
name = names{i} ;
if ismember(names{i},normfeat)
aug_name = ['norm_' name];
elseif ismember(names{i},logfeat)
aug_name = ['log_' name];
else
aug_name = names{i} ;
end
simple_aug_names{1,i} = aug_name ;
diff_aug_names{1,end+1} = [aug_name '_diff1']; %#ok<AGROW>
diff_aug_names{1,end+1} = [aug_name '_diff2']; %#ok<AGROW>
end
aug_names = horzcat(simple_aug_names, diff_aug_names) ;
end
26 changes: 11 additions & 15 deletions tracking/feat_compute.m
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,15 @@
%
% All features are computed to be independent of resolution (FPS,PPM)
%
function feat = feat_compute(trk, calib)
function feat = feat_compute(trk, calib, do_compute_relative)
if ~exist('do_compute_relative', 'var') || isempty(do_compute_relative) ,
do_compute_relative = true ; % true for backwards-compatibility
end
% store video resolution parameters for later normalization
pix_per_mm = calib.PPM;
FPS = calib.FPS;
% names of features to be computed
personal_feat = {'vel','ang_vel','min_wing_ang','max_wing_ang',...
'mean_wing_length','axis_ratio','fg_body_ratio','contrast'};
enviro_feat = {'dist_to_wall'};
relative_feat = {'dist_to_other','angle_between','facing_angle','leg_dist'};
% units of features to be computed
personal_units = {'mm/s','rad/s','rad','rad','mm','ratio','ratio',''};
enviro_units = {'mm'};
relative_units = {'mm','rad','rad','mm'};
% names, units of features to be computed
[personal_feat, enviro_feat, relative_feat, personal_units, enviro_units, relative_units] = feat_names_and_units() ;
% kernel for smoothing output
smooth_kernel = [1 2 1]/4;
% note which flies share a chamber ("buddies") to keep track of whether
Expand All @@ -46,7 +42,7 @@
obj_count(c) = numel(trk.flies_in_chamber{c});
end
n_objs = max(obj_count);
if n_objs == 2
if do_compute_relative && (n_objs == 2)
buddy = zeros(1,n_flies);
for i=1:numel(trk.flies_in_chamber)
flies = trk.flies_in_chamber{i};
Expand All @@ -59,15 +55,15 @@
end
else
n_objs = n_flies;
if n_objs == 2
if do_compute_relative && (n_objs == 2)
buddy = [2 1];
bud_complete = false(size(buddy));
end
end
% initialize features
n_frames = size(trk.data,2);
n_trkfeat = size(trk.data,3);
n_feats = numel(personal_feat) + (n_objs==2)*numel(relative_feat) + numel(enviro_feat);
n_feats = numel(personal_feat) + do_compute_relative*(n_objs==2)*numel(relative_feat) + numel(enviro_feat);
track = trk.data(:,:,:);
features = nan(n_flies,n_frames,n_feats);
% compute distance to chambers for all pixels
Expand Down Expand Up @@ -168,7 +164,7 @@
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
center = [track(s,:,1)' track(s,:,2)'];
vec_rot = [cos(ori)' -sin(ori)'];
if n_objs==2
if do_compute_relative && n_objs==2
if bud_complete(s), continue; end % only need to calculate these once
bud = buddy(s);
if bud==0, continue; end % this fly has no buddy
Expand Down Expand Up @@ -213,7 +209,7 @@
% store variables in feat structure
names = [personal_feat enviro_feat];
units = [personal_units enviro_units];
if n_objs==2
if do_compute_relative && n_objs==2
names = [names relative_feat];
units = [units relative_units];
end
Expand Down
1 change: 1 addition & 0 deletions tracking/tracker_default_options.m
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@
options.do_recompute_tracking = false ;
options.min_fractional_arena_size = 0.9 ;
options.max_fractional_arena_size = 1.1 ;
options.do_compute_relative_features = true ;
end
11 changes: 11 additions & 0 deletions tracking/utilities/feat_names_and_units.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
function [personal_feat, enviro_feat, relative_feat, personal_units, enviro_units, relative_units] = feat_names_and_units()
% names of features to be computed
personal_feat = {'vel','ang_vel','min_wing_ang','max_wing_ang',...
'mean_wing_length','axis_ratio','fg_body_ratio','contrast'};
enviro_feat = {'dist_to_wall'};
relative_feat = {'dist_to_other','angle_between','facing_angle','leg_dist'};
% units of features to be computed
personal_units = {'mm/s','rad/s','rad','rad','mm','ratio','ratio',''};
enviro_units = {'mm'};
relative_units = {'mm','rad','rad','mm'};
end
11 changes: 11 additions & 0 deletions tracking/utilities/flytracker_jaaba_features_and_units.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
function [personal_feat, enviro_feat, relative_feat, personal_units, enviro_units, relative_units] = flytracker_jaaba_features_and_units()
% names of features to be computed
personal_feat = {'vel','ang_vel','min_wing_ang','max_wing_ang',...
'mean_wing_length','axis_ratio','fg_body_ratio','contrast'};
enviro_feat = {'dist_to_wall'};
relative_feat = {'dist_to_other','angle_between','facing_angle','leg_dist'};
% units of features to be computed
personal_units = {'mm/s','rad/s','rad','rad','mm','ratio','ratio',''};
enviro_units = {'mm'};
relative_units = {'mm','rad','rad','mm'};
end

0 comments on commit f9e2181

Please sign in to comment.