Coverage for agent_model/agents/base.py: 98%

219 statements  

« prev     ^ index     » next       coverage.py v7.2.3, created at 2023-05-04 13:14 +0700

1import math 

2from copy import deepcopy 

3from collections import defaultdict 

4 

5from ..util import evaluate_reference, evaluate_growth, recursively_clear_lists 

6 

class BaseAgent:
    """Agent that stores currencies, exchanges them with other agents via
    flows, and records its state every step.

    ``register_flow``, ``get_flow_value``, ``process_flow`` and ``kill`` are
    intended to be overloaded by subclasses.
    """

    # ------------- SETUP ------------- #
    def __init__(self, model, agent_id, amount=1, description=None,
                 agent_class=None, properties=None, capacity=None,
                 thresholds=None, flows=None, cause_of_death=None, active=None,
                 storage=None, attributes=None):
        """Create an agent with the given parameters.

        Args:
            model (AgentModel): AgentModel instance
            agent_id (str): A unique string
            amount (int): Starting/Maximum number alive; None means 1
            description (str): Plaintext description
            agent_class (str): Agent class name
            properties (dict): Static vars, 'volume'
            capacity (dict): Max storage per currency per individual
            thresholds (dict): Env. conditions to die
            flows (dict): Exchanges w/ other agents, keyed 'in' / 'out'
            cause_of_death (str): Reason for death
            active (int): Current number alive; defaults to ``amount``
            storage (dict): Currencies stored by total amount
            attributes (dict): Dynamic vars, 'te_factor'

        Dict arguments are deep-copied, so later mutation by the agent does
        not affect the caller's objects.
        """
        # -- STATIC
        self.agent_id = agent_id
        self.amount = 1 if amount is None else amount
        self.description = '' if description is None else description
        self.agent_class = '' if agent_class is None else agent_class
        self.properties = {} if properties is None else deepcopy(properties)
        self.capacity = {} if capacity is None else deepcopy(capacity)
        self.thresholds = {} if thresholds is None else deepcopy(thresholds)
        self.flows = {'in': {}, 'out': {}}
        if flows is not None:
            for direction in ('in', 'out'):
                if direction in flows:
                    self.flows[direction] = deepcopy(flows[direction])
        # -- DYNAMIC
        self.cause_of_death = cause_of_death
        # Default to the *normalized* amount so amount=None still yields an int
        self.active = self.amount if active is None else deepcopy(active)
        self.storage = {} if storage is None else deepcopy(storage)
        self.attributes = {} if attributes is None else deepcopy(attributes)
        # -- NON-SERIALIZED
        self.model = model
        self.registered = False
        self.records = {}

    def register(self, record_initial_state=False):
        """Check and setup agent after all agents have been added to Model.

        Validates storage against capacity, initializes flow attributes via
        ``register_flow`` and builds the ``records`` skeleton. Does nothing
        if the agent is already registered.

        Args:
            record_initial_state (bool): Whether to include a value for
                'step 0'; True for new simulations, false when loading

        Raises:
            ValueError: if a stored currency has no capacity or exceeds
                ``capacity * active``.
        """
        if self.registered:
            return
        if 'age' not in self.attributes:
            self.attributes['age'] = 0
        for currency in self.storage:
            if currency not in self.capacity:
                raise ValueError(f'Agent {self.agent_id} has storage for '
                                 f'{currency} but no capacity.')
            elif self.storage[currency] > self.capacity[currency] * self.active:
                raise ValueError(f'Agent {self.agent_id} has more storage '
                                 f'for {currency} than capacity.')
        # Initialize flow attributes and records, check connections.
        # Records are keyed by the concrete currencies each connection
        # exposes for this flow (a currency class expands to its members).
        flow_records = {'in': defaultdict(dict), 'out': defaultdict(dict)}
        for direction, flows in self.flows.items():
            for currency, flow in flows.items():
                self.register_flow(direction, currency, flow)
                for conn in flow['connections']:
                    agent = self.model.agents[conn]
                    for _currency in agent.view(currency):
                        record = [0] if record_initial_state else []
                        flow_records[direction][_currency][conn] = record

        # Initialize records skeleton
        self.records = {
            'active': [self.active] if record_initial_state else [],
            'cause_of_death': self.cause_of_death,
            'storage': {currency: ([self.storage.get(currency, 0)]
                                   if record_initial_state else [])
                        for currency in self.capacity},
            'attributes': {attr: ([self.attributes[attr]]
                                  if record_initial_state else [])
                           for attr in self.attributes},
            'flows': flow_records,
        }
        self.registered = True

    def register_flow(self, direction, currency, flow):
        """Check flow, setup attributes and records. Overloadable by subclasses.

        Creates the buffer, deprive and growth-factor attributes the flow
        needs, and verifies every connection exists in the model and stores
        the currency (or at least one currency of the class).

        Raises:
            ValueError: on an unknown flow field, an unregistered connection,
                or a connection that cannot store the currency.
        """
        # Check flow fields
        allowed_fields = {'value', 'flow_rate', 'criteria', 'connections',
                          'deprive', 'weighted', 'requires', 'growth'}
        for field in flow:
            if field not in allowed_fields:
                raise ValueError(f'Flow field {field} not allowed')
        # Initialize attributes
        if 'criteria' in flow:
            for i, criterion in enumerate(flow['criteria']):
                if 'buffer' in criterion:
                    buffer_attr = f'{direction}_{currency}_criteria_{i}_buffer'
                    self.attributes[buffer_attr] = criterion['buffer']
        if 'deprive' in flow:
            deprive_attr = f'{direction}_{currency}_deprive'
            self.attributes[deprive_attr] = flow['deprive']['value']
        if 'growth' in flow:
            for mode, params in flow['growth'].items():
                growth_attr = f'{direction}_{currency}_{mode}_growth_factor'
                self.attributes[growth_attr] = evaluate_growth(self, mode, params)
        # Check flow connections
        for agent in flow['connections']:
            if agent not in self.model.agents:
                raise ValueError(f'Agent {agent} not registered')
            currency_type = self.model.currencies[currency]['currency_type']
            if currency_type == 'currency':
                if currency not in self.model.agents[agent].capacity:
                    raise ValueError(f'Agent {agent} does not store {currency}')
            else:
                class_currencies = self.model.currencies[currency]['currencies']
                if not any(c in self.model.agents[agent].capacity
                           for c in class_currencies):
                    raise ValueError(f'Agent {agent} does not store any '
                                     f'currencies of class {currency}')

    # ------------- INSPECT ------------- #
    def view(self, view):
        """Return a dict with storage amount for single currency or all of a class.

        A single currency always yields one key (0 if not stored). A class
        yields only those member currencies this agent has capacity for.

        Raises:
            ValueError: if the model lists an unknown currency type.
        """
        currency_type = self.model.currencies[view]['currency_type']
        if currency_type == 'currency':
            return {view: self.storage.get(view, 0)}
        if currency_type == 'class':
            class_currencies = self.model.currencies[view]['currencies']
            return {currency: self.storage.get(currency, 0)
                    for currency in class_currencies
                    if currency in self.capacity}
        raise ValueError(f'Unknown currency_type {currency_type} for {view}')

    def serialize(self):
        """Return json-serializable dict of agent attributes (deep-copied)."""
        # Tuple rather than set so the output key order is deterministic
        serializable = ('agent_id', 'amount', 'description', 'agent_class',
                        'properties', 'capacity', 'thresholds', 'flows',
                        'cause_of_death', 'active', 'storage', 'attributes')
        return {k: deepcopy(getattr(self, k)) for k in serializable}

    def get_records(self, static=False, clear_cache=False):
        """Return a copy of the records dict and optionally clear the cache.

        Args:
            static (bool): Also include the non-dynamic serialized fields
                under the 'static' key.
            clear_cache (bool): Reset ``self.records`` via
                ``recursively_clear_lists`` after copying.
        """
        output = deepcopy(self.records)
        if static:
            static_records = self.serialize()
            for k in ('cause_of_death', 'active', 'storage', 'attributes'):
                del static_records[k]
            output['static'] = static_records
        if clear_cache:
            self.records = recursively_clear_lists(self.records)
        return output

    def save(self, records=False):
        """Return a serializable copy of the agent, optionally with records."""
        output = self.serialize()
        if records:
            output['records'] = self.get_records()
        return output

    # ------------- UPDATE ------------- #
    def increment(self, currency, value):
        """Increment currency in storage as available, return actual receipt.

        Args:
            currency (str): A currency, or (for value <= 0) a currency class.
            value (float): Positive to deposit, negative to withdraw.

        Returns:
            dict: Amount actually added (negative when withdrawn) per
            concrete currency. Withdrawals are split across a class in
            proportion to what is stored; deposits never exceed remaining
            capacity and are never negative.

        Raises:
            ValueError: when depositing to a currency class or to a currency
                this agent has no capacity for.
        """
        if value == 0:  # If currency_class, return dict of currencies
            return {k: 0 for k in self.view(currency)}
        elif value < 0:  # Can be currency or currency_class
            available = self.view(currency)
            total_available = sum(available.values())
            if total_available == 0:
                return available
            actual = -min(-value, total_available)
            increment = {_currency: actual * stored / total_available
                         for _currency, stored in available.items()}
            for _currency, amount in increment.items():
                if amount != 0:
                    self.storage[_currency] += amount
            return increment
        elif value > 0:  # Can only be currency
            if self.model.currencies[currency]['currency_type'] != 'currency':
                raise ValueError(f'Cannot increment agent by currency class ({currency})')
            if currency not in self.capacity:
                raise ValueError(f'Agent does not store {currency}')
            if currency not in self.storage:
                self.storage[currency] = 0
            total_capacity = self.capacity[currency] * self.active
            # Clamp at 0: storage can exceed capacity after individuals die,
            # and a deposit must never remove stock or report a negative receipt.
            remaining_capacity = max(0, total_capacity - self.storage[currency])
            actual = min(value, remaining_capacity)
            self.storage[currency] += actual
            return {currency: actual}

    def get_flow_value(self, dT, direction, currency, flow, influx):
        """Update flow state pre-exchange and return per-individual target.

        Overloadable by subclasses. Starting from ``flow['value'] * dT``, the
        value is, in order (each stage skipped once it is no longer > 0):
          * zeroed if any ``requires`` currency is missing from ``influx``,
            otherwise multiplied by each required currency's influx ratio;
          * zeroed when a ``criteria`` entry is unmet, or when a met entry
            still has buffer time remaining (which is then consumed by dT);
          * multiplied by each ``growth`` factor (also stored in attributes);
          * multiplied by each ``weighted`` field, looked up in storage (per
            active individual), then properties, then attributes.

        Raises:
            ValueError: if a weighted field cannot be found.
        """
        # Baseline
        step_value = flow['value'] * dT
        # Adjust
        requires = flow.get('requires')
        if step_value > 0 and requires:
            if any(_currency not in influx for _currency in requires):
                step_value = 0
            else:
                for _currency in requires:
                    step_value *= influx[_currency]  # Scale flows to requires
        criteria = flow.get('criteria')
        if step_value > 0 and criteria:
            for i, criterion in enumerate(criteria):
                buffer_attr = f'{direction}_{currency}_criteria_{i}_buffer'
                if evaluate_reference(self, criterion):
                    if 'buffer' in criterion and self.attributes[buffer_attr] > 0:
                        self.attributes[buffer_attr] -= dT
                        step_value = 0
                else:
                    step_value = 0
                    # Reset an exhausted buffer. '<= 0' rather than '== 0':
                    # subtracting a dT that doesn't divide the buffer evenly
                    # overshoots below zero and would never be reset.
                    if 'buffer' in criterion and self.attributes[buffer_attr] <= 0:
                        self.attributes[buffer_attr] = criterion['buffer']
        growth = flow.get('growth')
        if step_value > 0 and growth:
            for mode, params in growth.items():
                growth_attr = f'{direction}_{currency}_{mode}_growth_factor'
                growth_factor = evaluate_growth(self, mode, params)
                self.attributes[growth_attr] = growth_factor
                step_value *= growth_factor
        weighted = flow.get('weighted')
        if step_value > 0 and weighted:
            for field in weighted:
                if field in self.capacity:  # e.g. Biomass
                    weight = self.view(field)[field] / self.active
                elif field in self.properties:  # e.g. 'mass'
                    weight = self.properties[field]['value']
                elif field in self.attributes:  # e.g. 'te_factor'
                    weight = self.attributes[field]
                else:
                    raise ValueError(f'Weighted field {field} not found in '
                                     f'{self.agent_id} storage, properties, or attributes.')
                step_value *= weight
        return step_value

    def process_flow(self, dT, direction, currency, flow, influx, target, actual):
        """Update flow state post-exchange. Overloadable by subclasses.

        For inputs, stores the achieved fraction (actual/target, rounded to
        the model's floating_point_precision) in ``influx``. If the flow has
        'deprive', a shortfall consumes deprivation time; once it is
        exhausted, a proportional number of individuals are killed. A full
        supply resets the deprivation time.
        """
        available_ratio = round(0 if target == 0 else actual / target,
                                self.model.floating_point_precision)
        if direction == 'in':
            influx[currency] = available_ratio
        if 'deprive' in flow:
            deprive_attr = f'{direction}_{currency}_deprive'
            if available_ratio < 1:
                deprived_ratio = 1 - available_ratio
                remaining = self.attributes[deprive_attr] - (deprived_ratio * dT)
                self.attributes[deprive_attr] = max(0, remaining)
                if remaining < 0:
                    n_dead = math.ceil(-remaining * self.active)
                    self.kill(f'{self.agent_id} deprived of {currency}', n_dead=n_dead)
            else:
                self.attributes[deprive_attr] = flow['deprive']['value']

    def step(self, dT=1):
        """Update agent for given timedelta.

        Registers on first use, ages the agent, applies thresholds, executes
        all 'in' flows then all 'out' flows against connected agents in
        listed order, and appends this step's values to ``records``.
        """
        if not self.registered:
            self.register()
        if self.active:
            self.attributes['age'] += dT
            # Check thresholds
            for currency, threshold in self.thresholds.items():
                if evaluate_reference(self, threshold):
                    self.kill(f'{self.agent_id} passed {currency} threshold')

        # Execute flows
        influx = {}  # Which currencies were consumed, and what fraction of baseline
        for direction in ['in', 'out']:
            # Inputs withdraw from connections, outputs deposit into them
            multiplier = -1 if direction == 'in' else 1
            for currency, flow in self.flows[direction].items():

                # Calculate Target Value
                if self.active and 'value' in flow:
                    target = self.active * self.get_flow_value(dT, direction, currency, flow, influx)
                else:
                    target = 0

                # Process Flow: satisfy target from connections in order
                remaining = float(target)
                for connection in flow['connections']:
                    agent = self.model.agents[connection]
                    if remaining > 0:
                        exchange = agent.increment(currency, multiplier * remaining)
                        exchange_value = sum(exchange.values())
                        remaining -= abs(exchange_value)
                    else:
                        exchange = {k: 0 for k in agent.view(currency)}
                    # NOTE: This must be called regardless of whether the agent is active
                    for _currency, _value in exchange.items():
                        self.records['flows'][direction][_currency][connection].append(abs(_value))
                actual = target - remaining
                # TODO: Handle excess outputs; currently ignored

                # Respond to availability
                if self.active and 'value' in flow:
                    self.process_flow(dT, direction, currency, flow, influx, target, actual)

        # Update remaining records
        self.records['active'].append(self.active)
        for currency in self.capacity:
            self.records['storage'][currency].append(self.storage.get(currency, 0))
        for attribute in self.attributes:
            self.records['attributes'][attribute].append(self.attributes[attribute])
        self.records['cause_of_death'] = self.cause_of_death

    def kill(self, reason, n_dead=None):
        """Kill n_dead agents, or all if n_dead is None. Overloadable by subclasses.

        ``active`` never drops below 0; ``cause_of_death`` is set to
        ``reason`` only once no individuals remain.
        """
        n_dead = self.active if n_dead is None else n_dead
        self.active = max(0, self.active - n_dead)
        if self.active <= 0:
            self.cause_of_death = reason