"""power_cogs.utils.torch_utils: helpers for building ``torch.nn`` linear layers and networks."""
from typing import Callable, Iterable, List, Optional

import torch
def create_linear_layer(
    input_dims: int, output_dims: int, bias: bool = True, activation=None
):
    """Create a single fully connected layer, optionally followed by an activation.

    Args:
        input_dims: Number of input features of the ``torch.nn.Linear``.
        output_dims: Number of output features of the ``torch.nn.Linear``.
        bias: Whether the linear layer has a learnable bias.
        activation: Optional module placed after the linear layer. The same
            instance is stored (not copied).

    Returns:
        The bare ``torch.nn.Linear`` when ``activation`` is ``None``; otherwise
        ``torch.nn.Sequential(linear, activation)``.
    """
    layer = torch.nn.Linear(input_dims, output_dims, bias=bias)
    if activation is None:
        return layer
    return torch.nn.Sequential(layer, activation)
def create_linear_network(
    input_dims: int,
    hidden_dims: Optional[Iterable[int]],
    output_dims: int,
    output_activation=None,
    hidden_activation: Callable[[], torch.nn.Module] = torch.nn.ReLU,
):
    """Build a fully connected feed-forward network.

    Args:
        input_dims: Number of input features.
        hidden_dims: Widths of the hidden layers. Any iterable of ints is
            accepted (it is consumed once). ``None`` or an empty iterable
            produces a single linear layer from ``input_dims`` to
            ``output_dims``.
        output_dims: Number of output features.
        output_activation: Optional module appended after the output layer.
        hidden_activation: Zero-argument factory called once per hidden layer
            to create a fresh activation module (default ``torch.nn.ReLU``).

    Returns:
        With no hidden widths: the result of ``create_linear_layer`` from
        ``input_dims`` to ``output_dims`` (a ``Linear``, or
        ``Sequential(Linear, output_activation)``).
        Otherwise, for hidden widths ``h0 .. h(k-1)``, a ``torch.nn.Sequential``
        of, in order:

        - ``input_dims -> h0`` plus activation,
        - ``h0 -> h0`` plus activation,
        - ``h(i-1) -> h(i)`` plus activation, for ``i = 1 .. k-1``,
        - ``h(k-1) -> output_dims``, plus ``output_activation`` if given.

    NOTE(review): the ``h0 -> h0`` layer duplicates the input layer's width.
    As a result, ``k`` hidden widths yield ``k + 2`` linear layers instead of
    the ``k + 1`` a reader would likely expect. The layer is preserved because
    removing it would change the ``Sequential`` indices and parameter count,
    which would break existing checkpoints. Confirm whether it is intended.
    """
    dims = list(hidden_dims) if hidden_dims is not None else []
    if not dims:
        return create_linear_layer(
            input_dims, output_dims, bias=True, activation=output_activation
        )

    def _hidden(in_dims: int, out_dims: int) -> torch.nn.Module:
        # A fresh activation per layer, so no module instance is shared.
        return create_linear_layer(
            in_dims, out_dims, bias=True, activation=hidden_activation()
        )

    # Layers are constructed in the same order as before (input, hidden...,
    # output), so seeded weight initialization is unchanged.
    input_layer = _hidden(input_dims, dims[0])
    # Pair each width with its predecessor. The first width is paired with
    # itself, which reproduces the h0 -> h0 layer described in the NOTE.
    previous_dims = [dims[0]] + dims[:-1]
    hidden_layers = [_hidden(prev, cur) for prev, cur in zip(previous_dims, dims)]
    output_layer = create_linear_layer(
        dims[-1], output_dims, bias=True, activation=output_activation
    )
    return torch.nn.Sequential(input_layer, *hidden_layers, output_layer)