Coverage for agent_model/util.py: 94%
157 statements
« prev ^ index » next coverage.py v7.2.3, created at 2023-05-04 13:14 +0700
« prev ^ index » next coverage.py v7.2.3, created at 2023-05-04 13:14 +0700
1import json, copy, operator, math
2from pathlib import Path
3import numpy as np
5# DATA HANDLING
def load_data_file(fname, data_dir=None):
    """Load and parse a JSON file from a data directory.

    Arguments:
        fname: file name, relative to ``data_dir``.
        data_dir: directory to read from. Defaults to the ``data_files``
            directory located in the parent of this module's directory.

    Returns:
        The decoded JSON content.

    Raises:
        AssertionError: if the data directory or the file does not exist
            (under ``python -O`` the asserts are skipped and ``open`` raises
            FileNotFoundError instead).
    """
    if data_dir is None:
        # Resolve relative to this module, not the current working directory
        data_dir = Path(__file__).resolve().parent.parent / 'data_files'
    else:
        data_dir = Path(data_dir)
    assert data_dir.exists(), f'Data directory does not exist: {data_dir}'
    data_file = data_dir / fname
    assert data_file.exists(), f'Data file does not exist: {data_file}'
    # JSON text is UTF-8; don't depend on the platform's locale encoding
    with open(data_file, 'r', encoding='utf-8') as f:
        return json.load(f)
def get_default_agent_data(agent):
    """Look up ``agent`` in the default ``agent_desc.json``.

    Returns a deep copy of that agent's entry, or ``None`` when the file has
    no entry under that name.
    """
    descriptions = load_data_file('agent_desc.json')
    if agent not in descriptions:
        return None
    return copy.deepcopy(descriptions[agent])
def get_default_currency_data():
    """Flatten the default ``currency_desc.json`` into one lookup dict.

    The file groups currencies by class (``{class: {currency: data}}``). The
    result holds one entry per currency (its loaded data, updated in place
    with ``currency_type='currency'`` and ``class=<class name>``) and one
    entry per class (``currency_type='class'`` plus the list of its currency
    names under ``currencies``).
    """
    flattened = {}
    for class_name, members in load_data_file('currency_desc.json').items():
        for name, data in members.items():
            flattened[name] = data
            data['currency_type'] = 'currency'
            data['class'] = class_name
        flattened[class_name] = {'currency_type': 'class',
                                 'currencies': list(members)}
    return flattened
def merge_json(default, to_merge):
    """Recursively merge ``to_merge`` into ``default`` and return the result.

    Rules, chosen by the type of ``to_merge``:
      - dict: merged key by key into ``default``, which is mutated in place
        and returned. Keys missing from ``default`` are added with
        ``to_merge``'s value (same object, not a copy).
      - list:
          - an empty list leaves ``default`` unchanged;
          - a list whose first element is a dict replaces ``default``;
          - a list whose first element is a str/int/float (bool counts as
            int) is unioned with ``default``. Duplicates are dropped and
            first-seen order is kept.
      - anything else (str, numbers, bool, None, ...): replaces ``default``.

    Raises:
        ValueError: if ``to_merge`` is a list whose first element is of any
            other type.
    """
    if isinstance(to_merge, dict):
        for key, value in to_merge.items():
            default[key] = merge_json(default[key], value) if key in default else value
        return default
    if isinstance(to_merge, list):
        if not to_merge:
            return default
        first = to_merge[0]
        if isinstance(first, dict):
            return to_merge
        if isinstance(first, (str, int, float)):
            # dict.fromkeys dedupes like a set but keeps insertion order, so
            # the merged list is reproducible (set order of str varies per run)
            return list(dict.fromkeys([*default, *to_merge]))
        raise ValueError(f'Cannot merge lists of type {type(first)}')
    # Scalars (and any other value, including None) override the default
    return to_merge
def recursively_clear_lists(r):
    """Return a copy of a nested structure with every list emptied.

    Dicts are rebuilt recursively, lists become ``[]``, and int/float/str
    values (bool included) are returned unchanged. Any other value (e.g.
    None) comes back as None.
    """
    if isinstance(r, list):
        return []
    if isinstance(r, dict):
        return {key: recursively_clear_lists(val) for key, val in r.items()}
    if isinstance(r, (int, float, str)):
        return r
    return None
def recursively_check_required_kwargs(given, required):
    """Check that every key in ``required`` is present in ``given``.

    When a value in ``required`` is a dict, the corresponding value in
    ``given`` is checked recursively against it. Other values in
    ``required`` are ignored; only the key's presence matters.

    Raises:
        ValueError: if a required key is missing, or if ``required`` expects
            a non-empty nested dict where ``given`` holds a non-dict value.
    """
    for key, sub_required in required.items():
        if key not in given:
            raise ValueError(f'{key} not found in {given}')
        if not isinstance(sub_required, dict):
            continue
        sub_given = given[key]
        # Without this check a str value would "pass" via substring matching
        # ('b' in 'abc') and other non-dicts would raise an opaque TypeError.
        if sub_required and not isinstance(sub_given, dict):
            raise ValueError(f'{key} must be a dict, got {sub_given!r}')
        recursively_check_required_kwargs(sub_given, sub_required)
74# LIMIT FUNCTIONS (THRESHOLD AND CRITERIA)
# Comparison symbols accepted in a reference's 'limit' field, each mapped to
# the matching binary predicate from the `operator` module.
operator_dict = {
    '>': operator.gt,
    '<': operator.lt,
    '>=': operator.ge,
    '<=': operator.le,
    '=': operator.eq,  # a single '=' means equality
    '!=': operator.ne,
}
def evaluate_reference(agent, reference):
    """Evaluate a reference dict against ``agent`` and return a boolean.

    ``reference`` contains:
      - 'path': what to look up (see below);
      - 'limit': a key of ``operator_dict`` ('>', '<', '>=', '<=', '=', '!=');
      - 'value': the value to compare against (it is rounded, so numeric);
      - optionally 'connections': 'all' (see connected paths below).

    Supported paths:
      - a key of ``agent.attributes`` (e.g. 'grown');
      - a key of ``agent.storage``;
      - '<currency>_ratio': that currency's amount in storage divided by the
        sum of ``agent.view(<currency class>)``, or 0 when that sum is 0;
      - 'in_<rest>' / 'out_<rest>': ``<rest>`` is evaluated on each agent
        listed in ``agent.flows[direction][currency]['connections']``. The
        currency is ``<rest>`` itself, or the part before '_ratio' for ratio
        paths. Returns True if any connected agent passes, or only if all of
        them pass when ``reference['connections'] == 'all'``.

    Both sides are rounded to ``agent.model.floating_point_precision`` before
    the comparison.

    Raises:
        ValueError: if the path matches none of the forms above.
    """
    path, limit, value = reference['path'], reference['limit'], reference['value']
    # Connected agents: recurse into each one with the prefix stripped
    if path.startswith(('in_', 'out_')):
        direction, remainder = path.split('_', 1)
        if path.endswith('_ratio'):
            currency = remainder[:-len('_ratio')]
        else:
            currency = remainder
        connections = agent.flows[direction][currency]['connections']
        sub_reference = {**reference, 'path': remainder}
        results = (evaluate_reference(agent.model.agents[c], sub_reference)
                   for c in connections)
        if reference.get('connections') == 'all':
            return all(results)
        return any(results)
    # Fields of the agent itself
    if path in agent.attributes:
        target = agent.attributes[path]
    elif path in agent.storage:
        target = agent.storage[path]
    elif path.endswith('_ratio'):
        currency = path[:-len('_ratio')]
        currency_data = agent.model.currencies[currency]
        total = sum(agent.view(currency_data['class']).values())
        target = 0 if not total else agent.storage[currency] / total
    else:
        # Previously this fell through to an UnboundLocalError on `target`
        raise ValueError(f'Unrecognized reference path: {path!r}')
    precision = agent.model.floating_point_precision
    return operator_dict[limit](round(target, precision), round(value, precision))
119# GROWTH FUNCTIONS
def pdf(_x, std, cache={}):
    """Return the density of a zero-mean normal distribution at ``_x``.

    ``std`` is the standard deviation. Results are memoized in ``cache``,
    keyed by ``(_x, std)``. The default dict is created once and is
    deliberately shared by every call that doesn't pass its own.
    """
    key = (_x, std)
    if key not in cache:
        variance = std ** 2
        density = math.exp(-(_x ** 2) / (2 * variance))
        cache[key] = density / (std * math.sqrt(2 * math.pi))
    return cache[key]
def pdf_mean(std, center, n_samples, cache={}):
    """Return the mean of ``pdf`` over ``n_samples`` evenly spaced points.

    The points are x = i / n_samples for i in range(n_samples), i.e. they
    cover [0, 1). Each is transformed to ``(x - center) / std`` before being
    passed to ``pdf(..., std)``.

    Results are memoized in ``cache`` (by default a dict shared across
    calls). The cache is keyed by all three inputs; keying on
    ``(std, center)`` alone would return a stale mean for a different
    ``n_samples``.
    """
    key = (std, center, n_samples)
    if key not in cache:
        x_vals = [i / n_samples for i in range(n_samples)]
        y_vals = [pdf((x - center) / std, std) for x in x_vals]
        cache[key] = sum(y_vals) / n_samples
    return cache[key]
def sample_norm(rate, std=math.pi/10, center=0.5, n_samples=100):
    """Sample a bell curve at ``rate``, normalized so its mean over [0, 1) is 1.

    Arguments:
        rate: x-value to sample at, in [0, 1]
        std: standard-deviation parameter of the curve, in (0, 1]
        center: x-value of the curve's peak, in [0, 1]
        n_samples: number of evenly spaced points used to compute the mean
            that the result is divided by

    Raises:
        ValueError: if rate, std or center is out of range (NaN included).

    NOTE(review): ``x`` is already divided by ``std`` and ``pdf`` divides by
    ``std ** 2`` again, so the effective standard deviation is ``std ** 2``.
    ``pdf_mean`` applies the same transform, so the normalization is
    consistent. Confirm the double scaling is intended.
    """
    # `not 0 <= v <= 1` (unlike `v < 0 or v > 1`) also rejects NaN
    if any(not 0 <= v <= 1 for v in (rate, std, center)):
        raise ValueError('rate, std, and center must be between 0 and 1.')
    if std == 0:
        # std is a divisor below; fail clearly instead of ZeroDivisionError
        raise ValueError('std must be greater than 0.')
    # Shift and scale so the peak sits at x = 0
    x = (rate - center) / std
    y = pdf(x, std)
    # Normalize so the curve's mean over the sampled points equals 1
    return y / pdf_mean(std, center, n_samples)
def sample_clipped_norm(rate, factor=2, **kwargs):
    """Sample a flat-topped bell curve at ``rate``, scaled to a peak of 1.

    The ``sample_norm`` curve is multiplied by ``factor`` and capped at the
    unmultiplied curve's peak (its value at ``center``). The result is then
    divided by that peak, so everywhere the scaled curve reaches the peak
    the result is exactly 1.

    Arguments:
        rate: x-value to sample at
        factor: multiplier applied to the curve before clipping
        **kwargs: forwarded to ``sample_norm`` (std, center, n_samples)
    """
    peak_at = kwargs.get('center', 0.5)
    scaled = factor * sample_norm(rate, **kwargs)
    peak = sample_norm(peak_at, **kwargs)
    return min(scaled, peak) / peak
def sample_sigmoid(rate, min_value=0, max_value=1, steepness=1, center=0.5):
    """Sample a logistic curve at ``rate``.

    The curve rises from ``min_value`` to ``max_value`` and passes the
    midpoint at ``rate == center``. Its slope is ``steepness * 20``.
    """
    exponent = -(steepness * 20 * (rate - center))
    logistic = 1 / (1 + np.exp(exponent))
    return min_value + logistic * (max_value - min_value)
def sample_switch(rate, min_value=0, max_value=1, center=0.5, duration=0.5):
    """Return ``max_value`` inside an open window around ``center``, else ``min_value``.

    The window has width ``duration`` and excludes its endpoints.
    """
    half_width = duration / 2
    inside = center - half_width < rate < center + half_width
    return max_value if inside else min_value
def evaluate_growth(agent, mode, params):
    """Sample one of the growth curves for ``agent`` and return its value.

    Arguments:
        agent: object providing ``agent.model.time.hour`` (mode 'daily') or
            ``agent.attributes['age']`` and
            ``agent.properties['lifetime']['value']`` (mode 'lifetime')
        mode: 'daily' uses rate = hour / 24; 'lifetime' uses
            rate = age / lifetime
        params: dict whose 'type' is one of 'norm', 'sigmoid', 'clipped' or
            'switch'. Every other entry is passed as a keyword argument to
            the selected curve function.

    Raises:
        ValueError: if ``mode`` is not 'daily' or 'lifetime'.
        KeyError: if ``params['type']`` is missing or not a known curve.
    """
    if mode == 'daily':
        rate = agent.model.time.hour / 24
    elif mode == 'lifetime':
        rate = agent.attributes['age'] / agent.properties['lifetime']['value']
    else:
        # Without this an unknown mode surfaces as UnboundLocalError on `rate`
        raise ValueError(f"Unknown growth mode: {mode!r} (expected 'daily' or 'lifetime')")
    growth_func = {
        'norm': sample_norm,
        'sigmoid': sample_sigmoid,
        'clipped': sample_clipped_norm,
        'switch': sample_switch,
    }[params['type']]
    kwargs = {k: v for k, v in params.items() if k != 'type'}
    return growth_func(rate, **kwargs)
200# WORKING WITH OUTPUTS
def parse_data(data, path):
    """Recursively extract the value(s) at ``path`` from nested lists/dicts.

    Each element of ``path`` selects from the current level.

    On a list:
      - '*': results for every item, with ``None`` results dropped;
      - int: the item at that index;
      - 'i:j': results for the slice ``data[i:j]``.

    On a dict:
      - '*': dict of every key's result;
      - 'SUM': sum of every key's result, element-wise when the results are
        lists;
      - a key of the dict: that key's result;
      - a comma-separated string of keys (spaces around each key are
        ignored): dict of those keys' results.

    ``None`` and empty containers/strings yield ``None``; 0 is kept. In the
    dict '*', 'SUM' and comma-key outputs, results that are ``None`` or empty
    are omitted, and an output with no entries is returned as ``None``. A
    scalar reached with path elements left over yields ``None``.
    """
    if not data and data != 0:
        return None
    if not path:
        return data
    # Consume the first path element and pass the rest down
    index, *remainder = path
    if isinstance(data, list):
        if index == '*':
            parsed = [parse_data(item, remainder) for item in data]
            return [p for p in parsed if p is not None]
        if isinstance(index, int):
            return parse_data(data[index], remainder)
        # Otherwise a 'start:end' range string
        start, end = [int(i) for i in index.split(':')]
        return [parse_data(item, remainder) for item in data[start:end]]
    if isinstance(data, dict):
        if index in {'*', 'SUM'}:
            parsed = [parse_data(v, remainder) for v in data.values()]
            output = {k: v for k, v in zip(data.keys(), parsed) if v or v == 0}
            if not output:
                return None
            if index == '*':
                return output
            values = list(output.values())
            if isinstance(values[0], list):
                # Element-wise sum across keys (zip stops at the shortest list)
                return [sum(x) for x in zip(*values)]
            return sum(values)
        if index in data:
            return parse_data(data[index], remainder)
        if isinstance(index, str):
            # Strip each key *before* the membership test so 'a, b' matches 'b'
            keys = [key.strip() for key in index.split(',')]
            indices = [key for key in keys if key in data]
            parsed = [parse_data(data[key], remainder) for key in indices]
            output = {k: v for k, v in zip(indices, parsed) if v or v == 0}
            return output if output else None
    return None