Coverage for agent_model/Model.py: 98%
154 statements
Generated by coverage.py v7.2.3 at 2023-05-04 13:14 +0700
1import random
2from copy import deepcopy
3import datetime
4import numpy as np
5from .util import get_default_currency_data, load_data_file, merge_json, recursively_clear_lists
6from .agents import BaseAgent, PlantAgent, LampAgent, SunAgent, AtmosphereEqualizerAgent, ConcreteAgent
# Defaults used by Model when the caller leaves a field unset.
DEFAULT_START_TIME = '1991-01-01T00:00:00'  # ISO-8601, parsed via datetime.fromisoformat
DEFAULT_TIME_UNIT = 'hours'  # keyword name handed to datetime.timedelta on each step
DEFAULT_LOCATION = 'earth'
FLOATING_POINT_PRECISION = 6
# Agent classes in the order the Scheduler steps them; agents whose class is
# not listed here are stepped last, in the catch-all 'other' group.
DEFAULT_PRIORITIES = [
    "structures",
    "storage",
    "power_generation",
    "inhabitants",
    "eclss",
    "plants",
]
class Model:
    """Simulation container holding agents, currencies, the clock and termination state.

    Build one with ``Model.from_config`` (which merges user configuration with
    the default agent/currency data), or directly via the constructor plus
    ``add_agent`` / ``add_currency``. The keys produced by ``save()`` (other
    than the optional ``records``) mirror the parameters of ``from_config``.
    """

    floating_point_precision = FLOATING_POINT_PRECISION
    # Keyword passed to datetime.timedelta when advancing the clock in step().
    time_unit = DEFAULT_TIME_UNIT

    def __init__(self, termination=None, location=None, priorities=None,
                 start_time=None, elapsed_time=None, step_num=None, seed=None,
                 is_terminated=None, termination_reason=None):
        """Initialize model state; every argument falls back to a default when None.

        Args:
            termination (list, optional): termination conditions checked in
                ``step()``; each is a dict with ``condition``, ``unit`` and
                ``value`` keys. Defaults to no conditions.
            location (str, optional): defaults to ``DEFAULT_LOCATION``.
            priorities (list, optional): agent-class order used by the
                Scheduler. Defaults to a copy of ``DEFAULT_PRIORITIES``.
            start_time (str, optional): ISO-8601 start timestamp.
            elapsed_time (float, optional): elapsed simulated time in seconds.
            step_num (int, optional): number of steps already taken.
            seed (int, optional): RNG seed; a random 32-bit value if omitted.
            is_terminated (bool, optional): stored as given (None if omitted).
            termination_reason (str, optional): defaults to ''.
        """
        # Serializable model data (written out by save()).
        self.termination = [] if termination is None else termination
        self.location = DEFAULT_LOCATION if location is None else location
        # Copy the default so mutating one model's priorities cannot leak into
        # the module-level constant (and thus into every later model).
        self.priorities = list(DEFAULT_PRIORITIES) if priorities is None else priorities
        self.start_time = datetime.datetime.fromisoformat(
            DEFAULT_START_TIME if start_time is None else start_time)
        self.elapsed_time = datetime.timedelta(
            seconds=0 if elapsed_time is None else elapsed_time)
        self.step_num = 0 if step_num is None else step_num
        self.seed = seed if seed is not None else random.getrandbits(32)
        self.is_terminated = is_terminated
        self.termination_reason = '' if termination_reason is None else termination_reason
        self.agents = {}
        self.currencies = {}

        # Non-serializable runtime state; rng and scheduler are built in register().
        self.rng = None
        self.scheduler = None
        self.registered = False
        self.records = {'time': [], 'step_num': []}

    def add_agent(self, agent_id, agent):
        """Add ``agent`` under ``agent_id``.

        Raises:
            ValueError: if ``agent_id`` is already in use.
        """
        if agent_id in self.agents:
            raise ValueError(f'Agent names must be unique ({agent_id})')
        self.agents[agent_id] = agent

    def add_currency(self, currency_id, currency_data):
        """Add a currency (or currency class) definition under ``currency_id``.

        Raises:
            ValueError: if ``currency_id`` is already in use.
        """
        if currency_id in self.currencies:
            raise ValueError(f'Currency and currency class names must be unique ({currency_id})')
        self.currencies[currency_id] = currency_data

    def register(self, record_initial_state=False):
        """Prepare the model to run.

        Seeds ``self.rng`` from ``self.seed``, builds the Scheduler, optionally
        records the current time/step, then registers every agent.
        """
        self.rng = np.random.RandomState(self.seed)
        self.scheduler = Scheduler(self)
        if record_initial_state:
            self.records['time'].append(self.time.isoformat())
            self.records['step_num'].append(self.step_num)
        for agent in self.agents.values():
            agent.register(record_initial_state)
        self.registered = True

    @classmethod
    def from_config(cls, agents=None, currencies=None, record_initial_state=None, **kwargs):
        """Build and register a model from user configuration.

        The caller's ``agents`` and ``currencies`` mappings are not modified.

        Args:
            agents (dict, optional): agent_id -> agent config. A config may list
                ``prototypes`` (names in the default agent descriptions) to
                inherit from; an agent_id that is itself a default description
                also inherits from it.
            currencies (dict, optional): currency_id -> currency data, layered
                over the default currency data.
            record_initial_state (bool, optional): defaults to recording only
                when the model starts at step 0.
            **kwargs: forwarded to ``Model.__init__``.

        Returns:
            Model: the populated, registered model.

        Raises:
            ValueError: for an unknown prototype or a duplicate agent/currency.
        """
        agents = {} if agents is None else agents
        currencies = {} if currencies is None else currencies
        model = cls(**kwargs)

        # Generic 'habitat'/'greenhouse' connections are rewritten to the id of
        # an agent whose id contains that word (the last such agent wins).
        replacements = {'habitat': None, 'greenhouse': None}
        for agent_id in agents:
            if 'habitat' in agent_id:
                replacements['habitat'] = agent_id
            elif 'greenhouse' in agent_id:
                replacements['greenhouse'] = agent_id

        def replace_generic_connections(conns):
            """Replace if available, otherwise remove connection"""
            replaced = [replacements.get(c, c) for c in conns]
            # Drop unresolved generics and connections to agents not configured.
            return [c for c in replaced if c is not None and c in agents]

        # Merge user agents with default agents
        default_agent_desc = load_data_file('agent_desc.json')
        currencies_in_use = set()
        for agent_id, agent_data in agents.items():
            # Private copy: the steps below pop 'prototypes', set 'agent_id' and
            # rewrite connections, none of which may leak into the caller's config
            # (otherwise re-using a config silently loses its prototypes).
            agent_data = deepcopy(agent_data)

            # Load default agent data and/or prototypes (breadth-first).
            prototypes = agent_data.pop('prototypes', [])
            if agent_id in default_agent_desc:
                prototypes.append(agent_id)
            while prototypes:
                prototype = prototypes.pop(0)
                if prototype not in default_agent_desc:
                    raise ValueError(f'Agent prototype not found ({prototype})')
                agent_data = merge_json(deepcopy(default_agent_desc[prototype]),
                                        deepcopy(agent_data))
                # A prototype may itself declare further prototypes.
                if 'prototypes' in agent_data:
                    prototypes.extend(agent_data.pop('prototypes'))
            agent_data['agent_id'] = agent_id

            if 'flows' in agent_data:
                for flows in agent_data['flows'].values():
                    for currency, flow_data in flows.items():
                        currencies_in_use.add(currency)
                        flow_data['connections'] = replace_generic_connections(
                            flow_data['connections'])
            if 'storage' in agent_data:
                currencies_in_use.update(agent_data['storage'].keys())

            # Determine agent class. TODO: Remove hard-coding somehow?
            if agent_data.get('agent_class') == 'plants':
                build_from_class = PlantAgent
            elif 'lamp' in agent_id:
                build_from_class = LampAgent
            elif 'sun' in agent_id:
                build_from_class = SunAgent
            elif 'atmosphere_equalizer' in agent_id:
                build_from_class = AtmosphereEqualizerAgent
            elif 'concrete' in agent_id:
                build_from_class = ConcreteAgent
            else:
                build_from_class = BaseAgent

            agent = build_from_class(model, **agent_data)
            model.add_agent(agent_id, agent)

        # Merge user currencies with default currencies; keep only currencies
        # (and currency classes containing one) that some agent actually uses.
        all_currencies = {**get_default_currency_data(), **currencies}
        for currency_id, currency_data in all_currencies.items():
            is_used = currency_id in currencies_in_use
            is_used_class = (currency_data.get('currency_type') == 'class' and
                             any(c in currencies_in_use for c in currency_data['currencies']))
            if is_used or is_used_class:
                model.add_currency(currency_id, currency_data)

        if record_initial_state is None:
            record_initial_state = model.step_num == 0
        model.register(record_initial_state)
        return model

    @property
    def time(self):
        """Current simulated datetime (``start_time + elapsed_time``)."""
        return self.start_time + self.elapsed_time

    def step(self, dT=1):
        """Advance the model by one step.

        Registers the model first if needed, advances the clock, evaluates
        'time' termination conditions (the step still runs even when it
        triggers termination), steps all agents via the scheduler and records
        the new time/step number.

        Args:
            dT (int, optional): delta time in base time units. Defaults to 1.

        Raises:
            ValueError: if a 'time' termination uses an unsupported unit.
        """
        if not self.registered:
            self.register()
        self.step_num += 1
        self.elapsed_time += datetime.timedelta(**{self.time_unit: dT})
        for term in self.termination:
            if term['condition'] == 'time':
                if term['unit'] in ('day', 'days'):
                    reference = self.elapsed_time.days
                elif term['unit'] in ('hour', 'hours'):
                    reference = self.elapsed_time.total_seconds() // 3600
                else:
                    raise ValueError(f'Invalid termination time unit: '
                                     f'{term["unit"]}')
                if reference >= term['value']:
                    self.is_terminated = True
                    self.termination_reason = 'time'
        self.scheduler.step(dT)
        self.records['time'].append(self.time.isoformat())
        self.records['step_num'].append(self.step_num)

    def run(self, dT=1, max_steps=365*24*2):
        """Run the model until termination.

        Args:
            dT (int, optional): delta time in base time units. Defaults to 1.
            max_steps (int, optional): maximum number of steps to run. Defaults to 365*24*2.
        """
        while not self.is_terminated and self.step_num < max_steps:
            self.step(dT)

    def get_records(self, static=False, clear_cache=False):
        """Return a snapshot of model and per-agent records.

        Args:
            static (bool): also include the model's configuration under 'static'.
            clear_cache (bool): after copying, pass the model's records through
                ``recursively_clear_lists`` (presumably emptying them) and forward
                the flag to each agent.

        Returns:
            dict: deep copy of the model records plus an 'agents' mapping.
        """
        output = deepcopy(self.records)
        output['agents'] = {name: agent.get_records(static, clear_cache)
                            for name, agent in self.agents.items()}
        if static:
            output['static'] = {
                'currencies': self.currencies,
                'termination': self.termination,
                'location': self.location,
                'priorities': self.priorities,
                'start_time': self.start_time.isoformat(),
                'seed': self.seed,
            }
        if clear_cache:
            self.records = recursively_clear_lists(self.records)
        return output

    def save(self, records=False):
        """Serialize model state (and optionally a copy of its records) to a dict."""
        output = {
            'agents': {name: agent.save(records) for name, agent in self.agents.items()},
            'currencies': self.currencies,
            'termination': self.termination,
            'location': self.location,
            'priorities': self.priorities,
            'start_time': self.start_time.isoformat(),
            'elapsed_time': self.elapsed_time.total_seconds(),
            'step_num': self.step_num,
            'seed': self.seed,
            'is_terminated': self.is_terminated,
            'termination_reason': self.termination_reason,
        }
        if records:
            output['records'] = deepcopy(self.records)
        return output
217class Scheduler:
218 def __init__(self, model):
219 self.model = model
220 self.priorities = [*model.priorities, 'other']
221 self.class_agents = {p: [] for p in self.priorities}
222 for agent, agent_data in model.agents.items():
223 if agent_data.agent_class in self.priorities:
224 self.class_agents[agent_data.agent_class].append(agent)
225 else:
226 self.class_agents['other'].append(agent)
228 def step(self, dT):
229 for agent_class in self.priorities:
230 queue = self.model.rng.permutation(self.class_agents[agent_class])
231 for agent in queue:
232 self.model.agents[agent].step(dT)