add some code
This commit is contained in:
29
managed_components/78__esp-opus/dnn/torch/lossgen/lossgen.py
Normal file
29
managed_components/78__esp-opus/dnn/torch/lossgen/lossgen.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class LossGen(nn.Module):
    """Stacked-GRU sequence model producing one scalar per timestep.

    Input pairs (loss, perc) of shape (batch, seq, 1) each are projected
    to 8 features, passed through two GRU layers, and mapped to a single
    output channel. The recurrent states are returned so callers can run
    the model incrementally across consecutive chunks.
    """

    def __init__(self, gru1_size=16, gru2_size=16):
        """Build the layers.

        Args:
            gru1_size: hidden width of the first GRU layer.
            gru2_size: hidden width of the second GRU layer.
        """
        super().__init__()

        # Remember hidden widths so forward() can size fresh zero states.
        self.gru1_size = gru1_size
        self.gru2_size = gru2_size

        self.dense_in = nn.Linear(2, 8)
        self.gru1 = nn.GRU(8, self.gru1_size, batch_first=True)
        self.gru2 = nn.GRU(self.gru1_size, self.gru2_size, batch_first=True)
        self.dense_out = nn.Linear(self.gru2_size, 1)

    def forward(self, loss, perc, states=None):
        """Run one chunk through the network.

        Args:
            loss: tensor of shape (batch, seq, 1).
            perc: tensor of shape (batch, seq, 1), concatenated with `loss`.
            states: optional [h1, h2] pair from a previous call; zeros of
                the appropriate shape are used when omitted.

        Returns:
            A tuple of the per-timestep output, shape (batch, seq, 1),
            and the list [h1, h2] of updated GRU hidden states.
        """
        if states is not None:
            h1, h2 = states[0], states[1]
        else:
            # First chunk: start both recurrent layers from zero.
            n_batch = loss.size(0)
            h1 = torch.zeros((1, n_batch, self.gru1_size), device=loss.device)
            h2 = torch.zeros((1, n_batch, self.gru2_size), device=loss.device)

        feat = torch.tanh(self.dense_in(torch.cat([loss, perc], dim=-1)))
        y1, h1 = self.gru1(feat, h1)
        y2, h2 = self.gru2(y1, h2)

        return self.dense_out(y2), [h1, h2]
|
||||
Reference in New Issue
Block a user