-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathBaseDiversitySolver.py
70 lines (60 loc) · 2.56 KB
/
BaseDiversitySolver.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from recordclass import recordclass
from math import floor
# User record: `id` is the caller-supplied user identifier, `marg` is filled in
# by the solver with the summed weight of the user's groups, and `groups` lists
# the Group records the user belongs to.
User = recordclass('User', ['id', 'marg', 'groups'])
# Group record: `id` is the group's index in the input, `wei`/`cov` are the
# weight and remaining coverage quota assigned by the solver, and `users`
# lists the member User records.
Group = recordclass('Group', ['id', 'wei', 'cov', 'users'])
# Solution record: `users` collects the selected user ids, `groups` the ids of
# groups whose quota was consumed by a selection.
Solution = recordclass('Solution', ['users', 'groups'])
class BaseDiversitySolver(object):
    """Greedy selection of a bucket of users that covers weighted groups.

    Every group gets a weight (see :meth:`weight`) and a coverage quota
    (see :meth:`cover`). A user's marginal gain (``marg``) is the total weight
    of the groups it belongs to that still need coverage. :meth:`solve`
    repeatedly takes the remaining user with the largest marginal gain,
    decrements the quota of each group containing that user and, when a
    group's quota reaches zero, subtracts that group's weight from the
    marginal gain of all its members.
    """

    def __init__(self, users, groups):
        """Build the user/group membership structure.

        :param users: iterable of user identifiers.
        :param groups: sequence of containers supporting ``in``;
            ``groups[i]`` holds the identifiers of the members of group ``i``.
        """
        self.solution = Solution([], [])
        self.users = [User(user, None, []) for user in users]
        self.groups = [Group(i, None, None, []) for i in range(len(groups))]
        # Both membership lists are filled together here, so
        # ``group in user.groups`` <=> ``user in group.users``, and each
        # ``user.groups`` is ordered like ``self.groups``.
        for user in self.users:
            for group in self.groups:
                if user.id in groups[group.id]:
                    group.users.append(user)
                    user.groups.append(group)

    def solve(self, bucket_size, weight='LBS', cover='Single'):
        """Greedily select up to ``bucket_size`` users.

        The solution is computed once: if a previous call selected at least
        one user, that cached solution is returned regardless of arguments.

        :param bucket_size: maximum number of users to select.
        :param weight: group weighting scheme: ``'Iden'``, ``'LBS'``, ``'EBS'``.
        :param cover: coverage quota scheme: ``'Single'`` or ``'Prop'``.
        :returns: the solver's ``Solution``. ``users`` holds the selected user
            ids in selection order; ``groups`` receives a group id every time
            a selected user consumes one unit of that group's quota.
        :raises ValueError: if ``weight`` or ``cover`` names an unknown scheme
            (only checked when there is at least one group).
        """
        if self.solution.users:
            return self.solution
        for group in self.groups:
            group.wei, group.cov = (self.weight(group, weight, bucket_size),
                                    self.cover(group, cover, bucket_size))
        for user in self.users:
            user.marg = BaseDiversitySolver.marg(user)
        for _ in range(bucket_size):
            if not self.users:
                break
            # max() returns the first maximal index, so ties still go to the
            # earliest remaining user.
            best_idx = max(range(len(self.users)),
                           key=lambda i: self.users[i].marg)
            best = self.users.pop(best_idx)
            self.solution.users.append(best.id)
            # Walking best.groups visits exactly the groups that contain
            # `best`, in self.groups order, without an O(|group|) `in` scan.
            for group in best.groups:
                if group.cov > 0:
                    group.cov -= 1
                    self.solution.groups.append(group.id)
                    if group.cov == 0:
                        # Group fully covered: its weight no longer counts
                        # towards any member's marginal gain.
                        for member in group.users:
                            member.marg -= group.wei
        return self.solution

    def weight(self, group, type, bucket_size):
        """Return the weight of ``group`` under the scheme named ``type``.

        - ``'Iden'``: every group weighs 1.
        - ``'LBS'``: the number of users in the group.
        - ``'EBS'``: ``(bucket_size + 1) ** rank`` where ``rank`` is the
          1-based position of ``group`` when ``self.groups`` is sorted by
          ascending size (stable sort, so ties keep their original order).

        :raises ValueError: if ``type`` is not one of the names above.
        """
        if type == 'Iden':
            return 1
        if type == 'LBS':
            return len(group.users)
        if type == 'EBS':
            by_size = sorted(self.groups, key=lambda g: len(g.users))
            rank = by_size.index(group) + 1
            return (bucket_size + 1) ** rank
        raise ValueError('Unsupported weight function')

    def cover(self, group, type, bucket_size):
        """Return the coverage quota of ``group`` under the scheme ``type``.

        - ``'Single'``: every group needs 1 selected member.
        - ``'Prop'``: ``floor(bucket_size * |group| / |self.users|)``,
          clamped to at least 1.

        :raises ValueError: if ``type`` is not one of the names above.
        """
        if type == 'Single':
            return 1
        if type == 'Prop':
            if not self.users:
                # Nothing to take a proportion of: fall back to the minimum
                # quota instead of dividing by zero.
                return 1
            return max(floor(bucket_size * len(group.users) / len(self.users)), 1)
        raise ValueError('Unsupported cover function')

    @staticmethod
    def marg(user):
        """Return the summed ``wei`` of every group ``user`` belongs to."""
        return sum(group.wei for group in user.groups)