1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51
def t2v(tau, f, out_features, w, b, w0, b0, arg=None):
    """Compute a Time2Vec embedding of ``tau``.

    The result is the concatenation, along the last dimension, of a periodic
    part ``f(tau @ w + b)`` and a linear part ``tau @ w0 + b0``.

    Args:
        tau: input tensor; its last dimension must match ``w.shape[0]`` and
            ``w0.shape[0]`` for the matrix products.
        f: activation applied to the periodic part (e.g. ``torch.sin``).
        out_features: unused here; kept so existing call sites still work.
        w, b: weight and bias of the periodic part.
        w0, b0: weight and bias of the linear part.
        arg: optional extra argument, forwarded as ``f(x, arg)``. Only
            ``None`` means "no extra argument"; falsy values such as ``0``
            are forwarded like any other value.

    Returns:
        Tensor whose last dimension has size ``w.shape[1] + w0.shape[1]``
        (periodic features first, then the linear feature(s)).
    """
    # `is not None` rather than truthiness: 0 is a legitimate argument, and
    # bool() of a multi-element tensor raises instead of meaning "present".
    if arg is not None:
        v1 = f(torch.matmul(tau, w) + b, arg)
    else:
        v1 = f(torch.matmul(tau, w) + b)
    v2 = torch.matmul(tau, w0) + b0
    # Join on the feature (last) axis so inputs with extra leading dims,
    # e.g. (batch, seq, in_features), work; for 2-D input this equals dim=1.
    return torch.cat([v1, v2], dim=-1)
class SineActivation(nn.Module):
    """Time2Vec layer whose periodic components use ``torch.sin``.

    ``forward`` delegates to ``t2v`` with ``out_features - 1`` periodic
    components (weights ``w``/``b``) and one linear component (``w0``/``b0``).

    Args:
        in_features: size of the last dimension of the input ``tau``.
        out_features: total number of output features (periodic + linear).
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.out_features = out_features
        # Parameters are drawn in the same order as before (w0, b0, w, b).
        self.w0 = nn.parameter.Parameter(torch.randn(in_features, 1))
        # Biases are added *after* the matmul, so they must broadcast over the
        # leading (batch) dims: shape (1, k), not (in_features, k). For
        # in_features == 1 the shapes are unchanged.
        self.b0 = nn.parameter.Parameter(torch.randn(1, 1))
        self.w = nn.parameter.Parameter(torch.randn(in_features, out_features - 1))
        self.b = nn.parameter.Parameter(torch.randn(1, out_features - 1))
        self.f = torch.sin

    def forward(self, tau):
        """Return the sine Time2Vec embedding of ``tau`` (see ``t2v``)."""
        return t2v(tau, self.f, self.out_features, self.w, self.b, self.w0, self.b0)
class CosineActivation(nn.Module):
    """Time2Vec layer whose periodic components use ``torch.cos``.

    ``forward`` delegates to ``t2v`` with ``out_features - 1`` periodic
    components (weights ``w``/``b``) and one linear component (``w0``/``b0``).

    Args:
        in_features: size of the last dimension of the input ``tau``.
        out_features: total number of output features (periodic + linear).
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.out_features = out_features
        # Parameters are drawn in the same order as before (w0, b0, w, b).
        self.w0 = nn.parameter.Parameter(torch.randn(in_features, 1))
        # Biases are added *after* the matmul, so they must broadcast over the
        # leading (batch) dims: shape (1, k), not (in_features, k). For
        # in_features == 1 the shapes are unchanged.
        self.b0 = nn.parameter.Parameter(torch.randn(1, 1))
        self.w = nn.parameter.Parameter(torch.randn(in_features, out_features - 1))
        self.b = nn.parameter.Parameter(torch.randn(1, out_features - 1))
        self.f = torch.cos

    def forward(self, tau):
        """Return the cosine Time2Vec embedding of ``tau`` (see ``t2v``)."""
        return t2v(tau, self.f, self.out_features, self.w, self.b, self.w0, self.b0)
class Time2Vec(nn.Module):
    """Time2Vec embedding followed by a linear projection.

    Args:
        activation: ``"sin"`` or ``"cos"``; selects the periodic activation.
        hiddem_dim: size of the Time2Vec embedding (name kept as-is for
            callers passing it by keyword).
        out_features: output size of the final linear layer (default 2,
            the previously hard-coded value).

    Raises:
        ValueError: if ``activation`` is not ``"sin"`` or ``"cos"``.
    """

    def __init__(self, activation, hiddem_dim, out_features=2):
        super().__init__()
        # The embedding layer is built with in_features=1, so the input is
        # expected to carry a single time feature in its last dim
        # (presumably shape (batch, 1) — confirm against callers).
        if activation == "sin":
            self.l1 = SineActivation(1, hiddem_dim)
        elif activation == "cos":
            self.l1 = CosineActivation(1, hiddem_dim)
        else:
            # Fail at construction instead of an AttributeError in forward().
            raise ValueError(
                f"activation must be 'sin' or 'cos', got {activation!r}"
            )
        self.fc1 = nn.Linear(hiddem_dim, out_features)

    def forward(self, x):
        """Embed ``x`` with Time2Vec, then project with ``fc1``."""
        x = self.l1(x)
        x = self.fc1(x)
        return x
|