Coverage for test/test_util.py: 99%
179 statements
« prev ^ index » next coverage.py v7.2.3, created at 2023-05-04 13:14 +0700
1import pytest
2import datetime
3from copy import deepcopy
5from ..agent_model.util import (load_data_file,
6 get_default_agent_data,
7 get_default_currency_data,
8 merge_json,
9 recursively_clear_lists,
10 evaluate_reference,
11 pdf,
12 sample_norm,
13 sample_clipped_norm,
14 sample_sigmoid,
15 evaluate_growth,
16 parse_data)
class TestDataFilesHandling:
    """Tests for loading bundled data files and the JSON merge/clear helpers."""

    def test_load_data_files(self):
        agent_desc = load_data_file('agent_desc.json')
        assert 'wheat' in agent_desc, 'Failed to load agent_desc'
        # Both an unknown file name and an unknown data_dir must be rejected
        # with AssertionError.
        with pytest.raises(AssertionError):
            load_data_file('nonexistent_file.json')
        with pytest.raises(AssertionError):
            load_data_file('agent_desc.json', data_dir='nonexistent_dir')

    def test_get_default_agent_data(self):
        wheat_data = get_default_agent_data('wheat')
        # Each key listed once (the original list repeated 'storage').
        expected_keys = ('amount', 'storage', 'properties', 'flows')
        missing = [k for k in expected_keys if k not in wheat_data]
        assert not missing, f'wheat agent data is missing keys: {missing}'

    def test_get_default_currency_data(self):
        currency_data = get_default_currency_data()
        for name, currency in currency_data.items():
            assert 'currency_type' in currency, name
            if currency['currency_type'] == 'class':
                # A class entry must list its member currencies.
                assert 'currencies' in currency, name
            else:
                # A plain currency must name a class that is itself defined.
                assert 'class' in currency, name
                assert currency['class'] in currency_data, name

    def test_merge_json(self):
        default = {'a': 'red', 'b': 2, 'c': {'d': 3, 'e': 4}, 'f': [1, 2, 3]}
        to_merge = {'a': 'blue', 'c': {'d': 6}, 'f': [3, 4, 5]}
        merged = merge_json(default, to_merge)
        # Scalars and nested keys from to_merge win. Keys only in default are
        # kept. Lists are combined, and the shared element 3 appears once.
        assert merged == {'a': 'blue', 'b': 2, 'c': {'d': 6, 'e': 4}, 'f': [1, 2, 3, 4, 5]}

    def test_recursively_clear_lists(self):
        data = {
            'a': 'string',
            'b': 1,
            'c': 2.1,
            'd': ['e', 'f'],
            'g': {'h': 'string2',
                  'i': ['j', 'k']}}
        data = recursively_clear_lists(data)
        # Lists are emptied at every depth; non-list values are untouched.
        assert data == {
            'a': 'string',
            'b': 1,
            'c': 2.1,
            'd': [],
            'g': {'h': 'string2',
                  'i': []}}
class MockAgent:
    """Minimal agent double for the evaluate_reference tests.

    Holds one attribute, two currency balances, and one inbound and one
    outbound 'test_currency_1' flow, each connected to 'test_agent_2'.
    """

    def __init__(self, model):
        self.model = model
        self.attributes = {'test_attribute': 1}
        self.storage = {'test_currency_1': 1, 'test_currency_2': 2}
        # The literal is re-evaluated per direction, so 'in' and 'out' get
        # independent nested dicts.
        self.flows = {
            direction: {
                'test_currency_1': {'value': 1, 'connections': ['test_agent_2']},
            }
            for direction in ('in', 'out')
        }

    def view(self, view):
        """Return a snapshot of storage for a currency or currency class.

        A single currency gives a one-entry dict. 'test_currency_class' gives
        a deep copy of all of storage. Any other name gives None.
        """
        if view == 'test_currency_class':
            return deepcopy(self.storage)
        if view in ('test_currency_1', 'test_currency_2'):
            return {view: self.storage[view]}
        return None
class MockModel:
    """Model double exposing the attributes the util helpers read.

    The class-level dicts act as templates. Each instance gets its own
    ``agents`` dict and its own deep copy of ``currencies``, so one test
    mutating them cannot leak into another.
    """

    floating_point_precision = 6
    agents = {}
    currencies = {
        'test_currency_1': {
            'currency_type': 'currency',
            'class': 'test_currency_class'
        },
        'test_currency_2': {
            'currency_type': 'currency',
            'class': 'test_currency_class'
        },
        'test_currency_class': {
            'currency_type': 'class',
            'currencies': ['test_currency_1', 'test_currency_2']
        }
    }

    def __init__(self):
        # Mutable class attributes are shared by every instance; shadow them
        # with per-instance copies.
        self.agents = {}
        self.currencies = deepcopy(type(self).currencies)
@pytest.fixture(scope='function')
def test_model():
    """Build a fresh MockModel holding agents 'test_agent_1' and 'test_agent_2'.

    MockAgent's constructor already binds each agent to ``model``, so no
    separate ``agent.model = model`` assignment is needed.
    """
    model = MockModel()
    model.agents = {
        name: MockAgent(model) for name in ('test_agent_1', 'test_agent_2')
    }
    return model
class TestEvaluateReference:
    """Tests for evaluate_reference against the test_model fixture.

    Each case builds a reference with a 'path', a comparison 'limit'
    ('>' or '<') and a threshold 'value'. It checks the result before and
    after changing the state the path points at.
    """

    @staticmethod
    def _reference(path, limit, value):
        """Assemble a reference dict in the shape evaluate_reference consumes."""
        return {'path': path, 'limit': limit, 'value': value}

    def test_evaluate_reference_attribute(self, test_model):
        ref = self._reference('test_attribute', '>', 1)
        agent = test_model.agents['test_agent_1']
        # test_attribute starts at 1, which is not > 1.
        assert not evaluate_reference(agent, ref)
        agent.attributes['test_attribute'] = 2
        assert evaluate_reference(agent, ref)

    def test_evaluate_reference_storage(self, test_model):
        ref = self._reference('test_currency_1', '>', 1)
        agent = test_model.agents['test_agent_1']
        # The test_currency_1 balance starts at 1, which is not > 1.
        assert not evaluate_reference(agent, ref)
        agent.storage['test_currency_1'] = 2
        assert evaluate_reference(agent, ref)

    def test_evaluate_reference_ratio(self, test_model):
        ref = self._reference('test_currency_1_ratio', '>', 0.5)
        agent = test_model.agents['test_agent_1']
        # Presumably the ratio is 1 / (1 + 2) across the currency class;
        # either way it is below 0.5.
        assert not evaluate_reference(agent, ref)
        ref['limit'] = '<'
        assert evaluate_reference(agent, ref)

    def test_evaluate_reference_connected(self, test_model):
        ref = self._reference('in_test_currency_1', '>', 1)
        first = test_model.agents['test_agent_1']
        assert not evaluate_reference(first, ref)
        # The agent's inbound test_currency_1 flow is connected to
        # test_agent_2, so raising that agent's balance flips the result.
        second = test_model.agents['test_agent_2']
        second.storage['test_currency_1'] = 2
        assert evaluate_reference(first, ref)
class TestGrowthFuncs:
    """Shape checks for the growth-curve sampling helpers."""

    def test_growth_pdf(self):
        _cache = {}
        results = [pdf(x, 0.5, _cache) for x in range(-4, 5)]
        # results[4] is x=0: it should be the peak, symmetrical either side.
        assert results[4] == max(results)
        for i in range(4):
            assert results[i] == results[-i - 1]
        # The supplied cache ends up holding exactly the computed values, in order.
        assert list(_cache.values()) == results

    def test_growth_sample_norm(self):
        # Default center; inputs span 0 < x < 1.
        n_samples = 100
        results = [sample_norm(x / n_samples, n_samples=n_samples)
                   for x in range(1, n_samples)]
        assert sum(results) / len(results) == pytest.approx(1, abs=0.02)
        # results[i] is sampled at x = (i + 1) / n_samples, so x = 0.5 sits at
        # index n_samples // 2 - 1. It should be the peak, symmetrical either side.
        midpoint = n_samples // 2 - 1
        assert results[midpoint] == max(results)
        for i in range(midpoint):
            assert results[i] == pytest.approx(results[-i - 1])

        # Shift center
        x_center = 0.25
        n_points = 1000
        results = [sample_norm(x / n_points, center=x_center)
                   for x in range(n_points)]
        assert sum(results) / len(results) == pytest.approx(1, abs=0.01)
        # results[i] is sampled at x = i / n_points, so the peak index follows x_center.
        assert results[round(x_center * n_points)] == max(results)

        # TODO: Shift stdev

    def test_growth_sample_clipped_norm(self):
        results = [sample_clipped_norm(x / 10) for x in range(1, 10)]
        # Peak is exactly 1 and sits at results[4] (x = 0.5), symmetrical either side.
        assert max(results) == 1
        assert results[4] == max(results)
        for i in range(4):
            assert results[i] == pytest.approx(results[-i - 1])

    def test_growth_sample_sigmoid(self):
        n_points = 1000
        results = [sample_sigmoid(x / n_points) for x in range(n_points)]
        # Non-decreasing across the domain, so the last sample is the largest.
        assert results[-1] == max(results)
        assert all(results[i] <= results[i + 1] for i in range(len(results) - 1))
        # Derivative (slope) rises up to the center and falls after it.
        derivatives = [results[i + 1] - results[i] for i in range(len(results) - 1)]
        center = n_points // 2
        for i in range(center):
            assert derivatives[i] <= derivatives[i + 1]
        assert derivatives[center] == max(derivatives)
        for i in range(center, len(derivatives) - 1):
            assert derivatives[i] >= derivatives[i + 1]
@pytest.fixture
def mock_agent():
    """Agent stub for the evaluate_growth tests.

    The agent has age 10 and a lifetime of 20. Its model's clock starts at
    2019-01-01 12:00. The local classes carry distinct names so they do not
    shadow the module-level MockAgent/MockModel used elsewhere in this file.
    """

    class _GrowthModel:
        # Class-level default; tests overwrite it on the instance.
        time = datetime.datetime(2019, 1, 1, 12)

    class _GrowthAgent:
        def __init__(self, model):
            self.model = model
            self.attributes = {'age': 10}
            self.properties = {'lifetime': {'value': 20}}

    return _GrowthAgent(_GrowthModel())
class TestEvaluateGrowth:
    """Tests for evaluate_growth in 'daily' and 'lifetime' modes, using mock_agent."""

    def test_evaluate_growth_daily(self, mock_agent):
        params = {'type': 'norm'}

        def growth_at(hour):
            # Move the model clock to the given hour on a fixed day, then sample.
            mock_agent.model.time = datetime.datetime(2019, 1, 1, hour)
            return evaluate_growth(mock_agent, 'daily', params)

        daily_vals = [growth_at(hour) for hour in range(24)]
        peak, trough = max(daily_vals), min(daily_vals)
        # Max growth at noon
        assert daily_vals[12] == peak
        # Min growth at midnight
        assert daily_vals[0] == trough

    def test_evaluate_growth_lifetime(self, mock_agent):
        params = {'type': 'sigmoid'}
        # The fixture's age 10 of lifetime 20 is the halfway point.
        assert evaluate_growth(mock_agent, 'lifetime', params) == 0.5
        # At the end of life (20/20) growth approaches, but stays below, 1.
        mock_agent.attributes['age'] = 20
        late_growth = evaluate_growth(mock_agent, 'lifetime', params)
        assert 0.999 < late_growth < 1.0
@pytest.fixture
def mock_data():
    """Nested record for the parse_data tests.

    It holds two model-level scalars plus a 'test_agent' entry with string,
    int, list and dict-of-lists attributes.
    """
    agent_dict_attribute = {
        'dict_attr_1': [2, 3, 4],
        'dict_attr_2': [3, 4, 5],
        'dict_attr_3': [4, 5, 6],
    }
    test_agent = {
        'agent_string_attribute': 'test',
        'agent_int_attribute': 1,
        'agent_list_attribute': [1, 2, 3],
        'agent_dict_attribute': agent_dict_attribute,
    }
    return {
        'model_string_attribute': 'test',
        'model_int_attribute': 1,
        'test_agent': test_agent,
    }
class TestParseData:
    """Tests for parse_data path lookups against the mock_data fixture."""

    def test_parse_data_static_field(self, mock_data):
        model_string_attr = parse_data(mock_data, ['model_string_attribute'])
        assert model_string_attr == 'test'
        model_int_attr = parse_data(mock_data, ['model_int_attribute'])
        assert model_int_attr == 1
        agent_string_attr = parse_data(mock_data, ['test_agent', 'agent_string_attribute'])
        assert agent_string_attr == 'test'
        agent_int_attr = parse_data(mock_data, ['test_agent', 'agent_int_attribute'])
        assert agent_int_attr == 1

    def test_parse_data_missing_field(self, mock_data):
        missing_value = parse_data(mock_data, ['missing_value'])
        # Identity check: None is a singleton, and == can be overridden.
        assert missing_value is None
        # A falsy 0 must still propagate rather than collapse to None.
        mock_data['model_int_attribute'] = 0
        zero_value = parse_data(mock_data, ['model_int_attribute'])
        assert zero_value == 0

    def test_parse_data_dict_keys(self, mock_data):
        # A plain key selects one field.
        single_field = parse_data(mock_data, ['test_agent', 'agent_dict_attribute', 'dict_attr_1'])
        assert single_field == [2, 3, 4]
        # '*' returns the whole dict.
        all_fields = parse_data(mock_data, ['test_agent', 'agent_dict_attribute', '*'])
        assert all_fields == mock_data['test_agent']['agent_dict_attribute']
        # A comma-separated key list returns the matching sub-dict.
        selected_fields = parse_data(mock_data, ['test_agent', 'agent_dict_attribute', 'dict_attr_1,dict_attr_2'])
        assert selected_fields == {'dict_attr_1': [2, 3, 4], 'dict_attr_2': [3, 4, 5]}
        # 'SUM' adds the value lists element-wise.
        summed_fields = parse_data(mock_data, ['test_agent', 'agent_dict_attribute', 'SUM'])
        assert summed_fields == [9, 12, 15]

    def test_parse_data_list(self, mock_data):
        # '*' returns every item.
        all_items = parse_data(mock_data, ['test_agent', 'agent_list_attribute', '*'])
        assert all_items == [1, 2, 3]
        # An int indexes a single item.
        single_item = parse_data(mock_data, ['test_agent', 'agent_list_attribute', 1])
        assert single_item == 2
        # A 'start:stop' string slices.
        slice_item = parse_data(mock_data, ['test_agent', 'agent_list_attribute', '0:2'])
        assert slice_item == [1, 2]