Coverage for agent_model/util.py: 94%

157 statements  

« prev     ^ index     » next       coverage.py v7.2.3, created at 2023-05-04 13:14 +0700

1import json, copy, operator, math 

2from pathlib import Path 

3import numpy as np 

4 

5# DATA HANDLING 

def load_data_file(fname, data_dir=None):
    """Load and parse a JSON data file.

    Arguments:
        fname: name of the file, relative to the data directory
        data_dir: directory to read from; when None, the 'data_files'
            directory located next to this module's parent directory is used

    Returns:
        The object decoded from the file by json.load.

    Raises:
        AssertionError: if the data directory or the data file does not exist.
    """
    if data_dir is None:
        # Resolve relative to this file so the lookup does not depend on the
        # current working directory.
        data_dir = Path(__file__).resolve().parent.parent / 'data_files'
    else:
        data_dir = Path(data_dir)
    # Raise explicitly rather than using assert statements so the checks are
    # not stripped when Python runs with -O; the exception type is unchanged.
    if not data_dir.exists():
        raise AssertionError(f'Data directory does not exist: {data_dir}')
    data_file = data_dir / fname
    if not data_file.exists():
        raise AssertionError(f'Data file does not exist: {data_file}')
    # Read as UTF-8 explicitly instead of relying on the locale default.
    with open(data_file, 'r', encoding='utf-8') as f:
        return json.load(f)

20 

def get_default_agent_data(agent):
    """Look up one agent in the default agent_desc.json.

    Returns a deep copy of the agent's entry, so callers can modify it
    freely, or None when the file has no entry for `agent`.
    """
    catalog = load_data_file('agent_desc.json')
    if agent not in catalog:
        return None
    return copy.deepcopy(catalog[agent])

27 

def get_default_currency_data():
    """Flatten the default currency_desc.json into a single-level dict.

    The file is read as {class: {currency: data}}. The result holds one
    entry per currency (its data plus 'currency_type' = 'currency' and
    'class' = the owning class) and one entry per class
    ('currency_type' = 'class' and 'currencies' = list of member names).
    """
    flattened = {}
    for class_name, members in load_data_file('currency_desc.json').items():
        for name, record in members.items():
            # Tag the loaded record in place, then register it by name
            record['currency_type'] = 'currency'
            record['class'] = class_name
            flattened[name] = record
        # The class entry is written after its members
        flattened[class_name] = {
            'currency_type': 'class',
            'currencies': list(members),
        }
    return flattened

40 

def merge_json(default, to_merge):
    """Merge two JSON-like objects of arbitrary depth/elements.

    Rules, by the type of `to_merge`:
      - dict: merged key by key into `default`, which is modified in place
        and returned. If `default` is not a dict, `to_merge` replaces it.
      - empty list: `default` is kept.
      - list of dicts: `to_merge` replaces `default`.
      - list of str/int/float: union of both lists without duplicates,
        keeping first-seen order (items of `default` first). If `default`
        is not a list, `to_merge` replaces it.
      - anything else (str, number, bool, None, ...): `to_merge` replaces
        `default`.

    Arguments:
        default: the base object
        to_merge: the object whose values take precedence

    Returns:
        The merged object.

    Raises:
        ValueError: if `to_merge` is a non-empty list whose first element is
            neither a dict nor a str/int/float.
    """
    if isinstance(to_merge, dict):
        if not isinstance(default, dict):
            # Nothing to merge key by key; the override wins
            return to_merge
        for key, value in to_merge.items():
            default[key] = merge_json(default[key], value) if key in default else value
        return default
    if isinstance(to_merge, list):
        if not to_merge:
            return default
        first = to_merge[0]
        if isinstance(first, dict):
            return to_merge
        if isinstance(first, (str, int, float)):
            if not isinstance(default, list):
                return to_merge
            # dict.fromkeys de-duplicates like a set but keeps insertion
            # order, so the result does not depend on hash ordering.
            return list(dict.fromkeys([*default, *to_merge]))
        raise ValueError(f'Cannot merge lists of type {type(first)}')
    # Scalars and None simply override the default
    return to_merge

58 

def recursively_clear_lists(r):
    """Return a copy of `r` in which every list is replaced by an empty list.

    Dicts are rebuilt recursively, ints/floats/strings are returned as they
    are, and any other type yields None.
    """
    if isinstance(r, list):
        return []
    if isinstance(r, dict):
        return {key: recursively_clear_lists(val) for key, val in r.items()}
    return r if isinstance(r, (int, float, str)) else None

66 

def recursively_check_required_kwargs(given, required):
    """Verify that every key of `required` is present in `given`.

    Where a value in `required` is itself a dict, the matching value of
    `given` is checked against it recursively.

    Raises:
        ValueError: on the first required key that is missing.
    """
    for name in required:
        if name not in given:
            raise ValueError(f'{name} not found in {given}')
        nested = required[name]
        if isinstance(nested, dict):
            recursively_check_required_kwargs(given[name], nested)

73 

74# LIMIT FUNCTIONS (THRESHOLD AND CRITERIA) 

# Comparison symbols mapped to the function from the operator module that
# implements each comparison.
operator_dict = {
    '>': operator.gt,
    '<': operator.lt,
    '>=': operator.ge,
    '<=': operator.le,
    '=': operator.eq,
    '!=': operator.ne,
}

def evaluate_reference(agent, reference):
    """Evaluate a reference dict and return a boolean.

    Arguments:
        agent: the agent the reference is evaluated against
        reference: dict with the keys 'path', 'limit' (a key of
            operator_dict) and 'value', plus an optional 'connections'

    Supported paths:
      - a key of the agent's attributes (e.g. 'grown')
      - a key of the agent's storage
      - '<currency>_ratio': the agent's stored amount of the currency
        divided by the total of the view of that currency's class
        (0 when the total is 0)
      - 'in_<path>' / 'out_<path>': evaluate <path> on the agents connected
        through the matching flow, e.g. 'in_co2_ratio'. The result is True
        if any connected agent passes, or only if all pass when
        reference['connections'] == 'all'.

    Both the target and the reference value are rounded to the model's
    floating_point_precision before being compared.

    Raises:
        ValueError: if the path matches none of the supported forms.
    """
    path, limit, value = reference['path'], reference['limit'], reference['value']

    # Connected agents: strip the direction prefix and recurse on each one
    if path.startswith(('in_', 'out_')):
        direction, remainder = path.split('_', 1)
        if path.endswith('_ratio'):
            # The flow is keyed by currency, so drop the trailing '_ratio'
            currency = remainder.rpartition('_')[0]
        else:
            currency = remainder
        conns = agent.flows[direction][currency]['connections']
        updated_reference = {**reference, 'path': remainder}
        results = (evaluate_reference(agent.model.agents[c], updated_reference)
                   for c in conns)
        if reference.get('connections') == 'all':
            return all(results)
        return any(results)

    # Resolve the field on this agent
    if path in agent.attributes:
        target = agent.attributes[path]
    elif path in agent.storage:
        target = agent.storage[path]
    elif path.endswith('_ratio'):
        currency = path[:-len('_ratio')]
        currency_class = agent.model.currencies[currency]['class']
        total = sum(agent.view(currency_class).values())
        target = agent.storage[currency] / total if total else 0
    else:
        # Previously this fell through and failed with an unbound 'target'
        raise ValueError(f"Cannot evaluate reference: unrecognized path '{path}'")

    precision = agent.model.floating_point_precision
    return operator_dict[limit](round(target, precision), round(value, precision))

118 

119# GROWTH FUNCTIONS 

120 

def pdf(_x, std, cache={}):
    """Return the normal-distribution density at _x for mean 0 and the given std.

    Results are memoized in `cache` under the key (_x, std). The default
    dict is shared between calls on purpose so values are computed once.
    """
    key = (_x, std)
    if key in cache:
        return cache[key]
    bell = math.exp(-(_x ** 2) / (2 * (std ** 2)))
    scale = math.sqrt(2 * math.pi) * std
    cache[key] = bell / scale
    return cache[key]

128 

def pdf_mean(std, center, n_samples, cache={}):
    """Calculate the mean y-value of the pdf over n_samples points.

    The pdf is sampled at x = i / n_samples for i in range(n_samples), each
    x shifted by `center` and divided by `std` before being passed to pdf().

    Arguments:
        std: standard deviation passed to pdf()
        center: x-value the distribution is centered at
        n_samples: number of evenly spaced samples
        cache: memo dict; the shared default keeps results between calls

    Returns:
        The mean of the sampled y-values.
    """
    # n_samples is part of the key: the mean depends on it, so a cached
    # value for one sample count must not be returned for another.
    key = (std, center, n_samples)
    if key not in cache:
        total = sum(pdf((i / n_samples - center) / std, std)
                    for i in range(n_samples))
        cache[key] = total / n_samples
    return cache[key]

136 

def sample_norm(rate, std=math.pi/10, center=0.5, n_samples=100):
    """Return y-value of normal distribution at x-value, such that mean(y) = 1

    Arguments:
        rate: x-value to sample at
        std: standard deviation of normal distribution
        center: x-value to center the distribution at
        n_samples: number of samples to use for mean calculation

    Raises:
        ValueError: if rate, std or center is outside [0, 1], or std is 0.
    """
    if any(v < 0 or v > 1 for v in (rate, std, center)):
        raise ValueError('rate, std, and center must be between 0 and 1.')
    if std == 0:
        # std is used as a divisor below; reject it with a clear message
        # instead of failing with ZeroDivisionError.
        raise ValueError('std must be greater than 0.')
    # Shift x-value to center at 0
    x = (rate - center) / std
    # Calculate y-value
    y = pdf(x, std)
    # Normalize y-value to mean of 1
    y_mean = pdf_mean(std, center, n_samples)
    return y / y_mean

155 

def sample_clipped_norm(rate, factor=2, **kwargs):
    """Sample a normal curve whose top has been flattened, scaled to max=1

    The sample_norm value at `rate` is multiplied by `factor`, capped at the
    curve's value at its center, and divided by that same value.

    Arguments:
        rate: x-value to sample at
        factor: factor to multiply the normal distribution by
        **kwargs: forwarded to sample_norm; 'center' defaults to 0.5
    """
    scaled = sample_norm(rate, **kwargs) * factor
    # Value of the unscaled curve at its center
    peak = sample_norm(kwargs.get('center', 0.5), **kwargs)
    return min(scaled, peak) / peak

172 

def sample_sigmoid(rate, min_value=0, max_value=1, steepness=1, center=0.5):
    """Return the value of a logistic curve at `rate`.

    The curve rises from min_value to max_value and reaches the midpoint at
    `center`; `steepness` multiplies the slope (on top of a fixed factor of 20).
    """
    exponent = steepness * 20 * (rate - center)
    span = max_value - min_value
    return (1 / (1 + np.exp(-exponent))) * span + min_value

180 

def sample_switch(rate, min_value=0, max_value=1, center=0.5, duration=0.5):
    """Return max_value inside the window around `center`, min_value outside.

    The window is `duration` wide and excludes both of its endpoints.
    """
    half = duration / 2
    inside = center - half < rate < center + half
    return max_value if inside else min_value

186 

def evaluate_growth(agent, mode, params):
    """Evaluate a growth function for an agent.

    Arguments:
        agent: the agent being evaluated
        mode: 'daily' (rate = model hour / 24) or 'lifetime'
            (rate = the agent's age / its lifetime value)
        params: dict with a 'type' key ('norm', 'sigmoid', 'clipped' or
            'switch'); every other item is passed to the growth function
            as a keyword argument

    Returns:
        The value of the selected growth function at the computed rate.

    Raises:
        ValueError: if mode is not 'daily' or 'lifetime'.
        KeyError: if params['type'] is missing or not a known growth type.
    """
    if mode == 'daily':
        rate = agent.model.time.hour / 24
    elif mode == 'lifetime':
        rate = agent.attributes['age'] / agent.properties['lifetime']['value']
    else:
        # Previously an unknown mode left 'rate' unbound
        raise ValueError(
            f"Unknown growth mode '{mode}'; expected 'daily' or 'lifetime'.")
    growth_func = {
        'norm': sample_norm,
        'sigmoid': sample_sigmoid,
        'clipped': sample_clipped_norm,
        'switch': sample_switch,
    }[params['type']]
    kwargs = {k: v for k, v in params.items() if k != 'type'}
    return growth_func(rate, **kwargs)

199 

200# WORKING WITH OUTPUTS 

201 

def parse_data(data, path):
    """Recursive function to extract data at path from arbitrary object

    `path` is a sequence of indices consumed from left to right.

    For a list, an index may be:
      - '*': every item; items that resolve to None are dropped
      - an int: a single item
      - 'i:j': the slice data[i:j]

    For a dict, an index may be:
      - '*': every key, returned as a dict of the non-empty results
      - 'SUM': as '*', but the results are summed (element-wise when they
        are lists)
      - a key of the dict
      - a comma-separated string of keys, returned as a dict; whitespace
        around each key is ignored and unknown keys are skipped

    Returns:
        The extracted value, or None when `data` is empty (0 counts as a
        value) or nothing matches.
    """
    # Empty containers, empty strings and None carry no data; 0 does
    if not data and data != 0:
        return None
    if len(path) == 0:
        return data
    # Shift the first element of path, pass on the rest of the path
    index, *remainder = path

    # LISTS
    if isinstance(data, list):
        # All items
        if index == '*':
            parsed = (parse_data(d, remainder) for d in data)
            return [d for d in parsed if d is not None]
        # Single index
        if isinstance(index, int):
            return parse_data(data[index], remainder)
        # Range i:j (string)
        start, end = (int(i) for i in index.split(':'))
        return [parse_data(d, remainder) for d in data[start:end]]

    # DICTS
    if isinstance(data, dict):
        # All items, either a dict ('*') or a number ('SUM')
        if index in {'*', 'SUM'}:
            parsed = {k: parse_data(v, remainder) for k, v in data.items()}
            output = {k: v for k, v in parsed.items() if v or v == 0}
            if not output:
                return None
            if index == '*':
                return output
            if isinstance(next(iter(output.values())), list):
                return [sum(x) for x in zip(*output.values())]
            return sum(output.values())
        # Single key
        if index in data:
            return parse_data(data[index], remainder)
        # Comma-separated list of keys. Return an object with all.
        if isinstance(index, str):
            # Strip before the membership test so 'a, b' finds both keys
            stripped = (part.strip() for part in index.split(','))
            keys = [k for k in stripped if k in data]
            parsed = {k: parse_data(data[k], remainder) for k in keys}
            output = {k: v for k, v in parsed.items() if v or v == 0}
            return output or None

    # Scalar with path left over, or an index type that matches nothing
    return None