Coverage for agent_model/Model.py: 98%

154 statements  

Report generated by coverage.py v7.2.3 on 2023-05-04 13:14 +0700.

1import random 

2from copy import deepcopy 

3import datetime 

4import numpy as np 

5from .util import get_default_currency_data, load_data_file, merge_json, recursively_clear_lists 

6from .agents import BaseAgent, PlantAgent, LampAgent, SunAgent, AtmosphereEqualizerAgent, ConcreteAgent 

7 

# Fallback values used by Model when a configuration omits a field.

# ISO-8601 string; Model parses it with datetime.datetime.fromisoformat.
DEFAULT_START_TIME = '1991-01-01T00:00:00'
# Keyword name handed to datetime.timedelta when the model advances a step.
DEFAULT_TIME_UNIT = 'hours'
DEFAULT_LOCATION = 'earth'
FLOATING_POINT_PRECISION = 6
# Default agent-class stepping order for the Scheduler; agents whose class
# is not listed are grouped under 'other' and stepped afterwards.
DEFAULT_PRIORITIES = [
    "structures",
    "storage",
    "power_generation",
    "inhabitants",
    "eclss",
    "plants",
]

14 

class Model:
    """Container for agents and currencies that advances a simulation in steps.

    Usually built with :meth:`from_config`, which merges user agent
    descriptions with the defaults from ``agent_desc.json``, instantiates the
    agents, keeps only currencies that are actually used, and registers the
    model. :meth:`step` / :meth:`run` then advance the clock and let the
    :class:`Scheduler` step each agent.
    """

    floating_point_precision = FLOATING_POINT_PRECISION
    time_unit = DEFAULT_TIME_UNIT

    def __init__(self, termination=None, location=None, priorities=None,
                 start_time=None, elapsed_time=None, step_num=None, seed=None,
                 is_terminated=None, termination_reason=None):
        """Initialize the model's data fields; every argument has a fallback.

        Args:
            termination (list, optional): termination conditions evaluated in
                :meth:`step`. Defaults to an empty list.
            location (str, optional): Defaults to ``DEFAULT_LOCATION``.
            priorities (list, optional): agent-class stepping order used by
                the Scheduler. Defaults to a copy of ``DEFAULT_PRIORITIES``.
            start_time (str, optional): ISO-8601 start timestamp. Defaults to
                ``DEFAULT_START_TIME``.
            elapsed_time (float, optional): elapsed time in seconds.
                Defaults to 0.
            step_num (int, optional): current step counter. Defaults to 0.
            seed (int, optional): RNG seed. Defaults to a random 32-bit int.
            is_terminated (optional): Defaults to None; :meth:`run` treats
                any falsy value as "not terminated".
            termination_reason (str, optional): Defaults to ''.
        """
        self.termination = [] if termination is None else termination
        self.location = DEFAULT_LOCATION if location is None else location
        # Copy the module-level default so mutating one model's priorities
        # cannot leak into DEFAULT_PRIORITIES (and thus into other models).
        self.priorities = list(DEFAULT_PRIORITIES) if priorities is None else priorities
        self.start_time = datetime.datetime.fromisoformat(
            DEFAULT_START_TIME if start_time is None else start_time)
        self.elapsed_time = datetime.timedelta(
            seconds=0 if elapsed_time is None else elapsed_time)
        self.step_num = 0 if step_num is None else step_num
        self.seed = seed if seed is not None else random.getrandbits(32)
        self.is_terminated = is_terminated
        self.termination_reason = '' if termination_reason is None else termination_reason
        self.agents = {}
        self.currencies = {}

        # NON-SERIALIZABLE: rebuilt by register(), not written by save().
        self.rng = None
        self.scheduler = None
        self.registered = False
        self.records = {'time': [], 'step_num': []}

    def add_agent(self, agent_id, agent):
        """Add ``agent`` under ``agent_id``.

        Raises:
            ValueError: if ``agent_id`` is already present.
        """
        if agent_id in self.agents:
            raise ValueError(f'Agent names must be unique ({agent_id})')
        self.agents[agent_id] = agent

    def add_currency(self, currency_id, currency_data):
        """Add a currency (or currency class) under ``currency_id``.

        Raises:
            ValueError: if ``currency_id`` is already present.
        """
        if currency_id in self.currencies:
            raise ValueError(f'Currency and currency class names must be unique ({currency_id})')
        self.currencies[currency_id] = currency_data

    def register(self, record_initial_state=False):
        """Create the RNG and scheduler, then register every agent.

        Args:
            record_initial_state (bool): if True, append the current time and
                step number to ``records`` and forward the flag to each
                agent's ``register``.
        """
        self.rng = np.random.RandomState(self.seed)
        self.scheduler = Scheduler(self)
        if record_initial_state:
            self.records['time'].append(self.time.isoformat())
            self.records['step_num'].append(self.step_num)
        for agent in self.agents.values():
            agent.register(record_initial_state)
        self.registered = True

    @staticmethod
    def _select_agent_class(agent_id, agent_data):
        """Pick the agent implementation class from the agent's class or id."""
        # TODO: Remove hard-coding somehow?
        if agent_data.get('agent_class') == 'plants':
            return PlantAgent
        if 'lamp' in agent_id:
            return LampAgent
        if 'sun' in agent_id:
            return SunAgent
        if 'atmosphere_equalizer' in agent_id:
            return AtmosphereEqualizerAgent
        if 'concrete' in agent_id:
            return ConcreteAgent
        return BaseAgent

    @classmethod
    def from_config(cls, agents=None, currencies=None, record_initial_state=None, **kwargs):
        """Build and register a model from agent and currency descriptions.

        The ``agents`` mapping and its nested agent descriptions are not
        mutated; each description is deep-copied before processing, so a
        config can safely be reused to build several models.

        Args:
            agents (dict, optional): agent_id -> agent description. A
                description may list ``prototypes`` whose entries in
                ``agent_desc.json`` are merged in. An agent whose id is itself
                a key of ``agent_desc.json`` also inherits that entry.
            currencies (dict, optional): currency_id -> currency data, layered
                over ``get_default_currency_data()``.
            record_initial_state (bool, optional): passed to
                :meth:`register`. Defaults to ``model.step_num == 0``.
            **kwargs: forwarded to the constructor.

        Returns:
            Model: the registered model.

        Raises:
            ValueError: if a prototype is missing from ``agent_desc.json``,
                or agent/currency ids collide.
        """
        agents = {} if agents is None else agents
        currencies = {} if currencies is None else currencies

        # Initialize an empty model
        model = cls(**kwargs)

        # Generic connection names map to the configured agent whose id
        # contains that word (the last match wins), or to None if none does.
        replacements = {'habitat': None, 'greenhouse': None}
        for agent_id in agents:
            if 'habitat' in agent_id:
                replacements['habitat'] = agent_id
            elif 'greenhouse' in agent_id:
                replacements['greenhouse'] = agent_id

        def replace_generic_connections(conns):
            """Replace generic names; drop connections to unconfigured agents."""
            replaced = [replacements.get(c, c) for c in conns]
            return [c for c in replaced if c is not None and c in agents]

        # Merge user agents with default agents
        default_agent_desc = load_data_file('agent_desc.json')
        currencies_in_use = set()
        for agent_id, user_agent_data in agents.items():
            # Work on a private copy: the steps below pop 'prototypes' and
            # rewrite flow connections, which previously corrupted the
            # caller's config (e.g. losing prototypes on a second build).
            agent_data = deepcopy(user_agent_data)

            # Load default agent data and/or prototypes
            prototypes = agent_data.pop('prototypes', [])
            if agent_id in default_agent_desc:
                prototypes.append(agent_id)
            while prototypes:
                prototype, *prototypes = prototypes
                if prototype not in default_agent_desc:
                    raise ValueError(f'Agent prototype not found ({prototype})')
                agent_data = merge_json(deepcopy(default_agent_desc[prototype]),
                                        deepcopy(agent_data))
                # A prototype may itself declare prototypes; queue them too.
                if 'prototypes' in agent_data:
                    prototypes += agent_data.pop('prototypes')
            agent_data['agent_id'] = agent_id

            if 'flows' in agent_data:
                for flows in agent_data['flows'].values():
                    for currency, flow_data in flows.items():
                        currencies_in_use.add(currency)
                        flow_data['connections'] = replace_generic_connections(
                            flow_data['connections'])
            if 'storage' in agent_data:
                currencies_in_use.update(agent_data['storage'].keys())

            build_from_class = cls._select_agent_class(agent_id, agent_data)
            model.add_agent(agent_id, build_from_class(model, **agent_data))

        # Merge user currencies with default currencies; keep only currencies
        # that some agent uses, plus classes with at least one used member.
        currencies = {**get_default_currency_data(), **currencies}
        for currency_id, currency_data in currencies.items():
            if (currency_id in currencies_in_use or
                    (currency_data.get('currency_type') == 'class' and
                     any(c in currencies_in_use for c in currency_data['currencies']))):
                model.add_currency(currency_id, currency_data)

        if record_initial_state is None:
            record_initial_state = model.step_num == 0
        model.register(record_initial_state)
        return model

    @property
    def time(self):
        """Current simulation time: ``start_time + elapsed_time``."""
        return self.start_time + self.elapsed_time

    def step(self, dT=1):
        """Advance the model by one step.

        The clock is advanced and termination conditions are evaluated before
        the scheduler runs, so the step that triggers termination still steps
        every agent. Only conditions with ``condition == 'time'`` are checked.

        Args:
            dT (int, optional): delta time in base time units. Defaults to 1.

        Raises:
            ValueError: if a time condition uses a unit other than
                day/days/hour/hours.
        """
        if not self.registered:
            self.register()
        self.step_num += 1
        self.elapsed_time += datetime.timedelta(**{self.time_unit: dT})
        for term in self.termination:
            if term['condition'] == 'time':
                if term['unit'] in ('day', 'days'):
                    reference = self.elapsed_time.days
                elif term['unit'] in ('hour', 'hours'):
                    # Whole hours elapsed.
                    reference = self.elapsed_time.total_seconds() // 3600
                else:
                    raise ValueError(f'Invalid termination time unit: '
                                     f'{term["unit"]}')
                if reference >= term['value']:
                    self.is_terminated = True
                    self.termination_reason = 'time'
        self.scheduler.step(dT)
        self.records['time'].append(self.time.isoformat())
        self.records['step_num'].append(self.step_num)

    def run(self, dT=1, max_steps=365*24*2):
        """Run the model until termination.

        Note that ``max_steps`` is compared with the absolute ``step_num``,
        not with the number of steps taken during this call.

        Args:
            dT (int, optional): delta time in base time units. Defaults to 1.
            max_steps (int, optional): maximum number of steps to run. Defaults to 365*24*2.
        """
        while not self.is_terminated and self.step_num < max_steps:
            self.step(dT)

    def get_records(self, static=False, clear_cache=False):
        """Return a deep copy of model records plus each agent's records.

        Args:
            static (bool): also include model configuration under ``'static'``
                (these values are references, not copies).
            clear_cache (bool): after copying, reset the model's record lists
                via ``recursively_clear_lists``; also forwarded to agents.

        Returns:
            dict: ``{'time', 'step_num', 'agents'[, 'static']}``.
        """
        output = deepcopy(self.records)
        output['agents'] = {name: agent.get_records(static, clear_cache)
                            for name, agent in self.agents.items()}
        if static:
            output['static'] = {
                'currencies': self.currencies,
                'termination': self.termination,
                'location': self.location,
                'priorities': self.priorities,
                'start_time': self.start_time.isoformat(),
                'seed': self.seed,
            }
        if clear_cache:
            self.records = recursively_clear_lists(self.records)
        return output

    def save(self, records=False):
        """Return the model state as a dict of constructor-style fields.

        Args:
            records (bool): include a deep copy of model records and forward
                the flag to each agent's ``save``.
        """
        output = {
            'agents': {name: agent.save(records) for name, agent in self.agents.items()},
            'currencies': self.currencies,
            'termination': self.termination,
            'location': self.location,
            'priorities': self.priorities,
            'start_time': self.start_time.isoformat(),
            'elapsed_time': self.elapsed_time.total_seconds(),
            'step_num': self.step_num,
            'seed': self.seed,
            'is_terminated': self.is_terminated,
            'termination_reason': self.termination_reason,
        }
        if records:
            output['records'] = deepcopy(self.records)
        return output

216 

class Scheduler:
    """Steps a model's agents grouped by agent class, in priority order.

    Groups follow ``model.priorities`` with ``'other'`` appended (unless
    already listed); agents whose ``agent_class`` is not a listed group go
    into ``'other'``. Within a group, agents are stepped in a random order
    drawn from ``model.rng``.
    """

    def __init__(self, model):
        """Group the ids of ``model.agents`` by their ``agent_class``."""
        self.model = model
        # Order-preserving dedup: a repeated class, or a user-supplied 'other',
        # would otherwise make step() run the same agents twice per tick.
        self.priorities = list(dict.fromkeys([*model.priorities, 'other']))
        self.class_agents = {p: [] for p in self.priorities}
        for agent_id, agent in model.agents.items():
            if agent.agent_class in self.class_agents:
                self.class_agents[agent.agent_class].append(agent_id)
            else:
                self.class_agents['other'].append(agent_id)

    def step(self, dT):
        """Step every agent exactly once, group by group in priority order.

        Args:
            dT: delta time forwarded to each agent's ``step``.
        """
        for agent_class in self.priorities:
            queue = self.model.rng.permutation(self.class_agents[agent_class])
            for agent_id in queue:
                self.model.agents[agent_id].step(dT)