"Fossies" - the Fresh Open Source Software Archive

Member "rspamd-1.7.3/contrib/torch/nn/AddConstant.lua" (10 Apr 2018, 1531 Bytes) of package /linux/misc/rspamd-1.7.3.tar.gz:


local AddConstant, parent = torch.class('nn.AddConstant', 'nn.Module')

-- Adds a constant to its input: either a scalar number, or a 1D tensor that is
-- broadcast across the rows of a 2D (batched) input.
function AddConstant:__init(constant_scalar, ip)
   parent.__init(self)
   self.constant_scalar = constant_scalar

   -- default for inplace is false
   self.inplace = ip or false
   if (ip and type(ip) ~= 'boolean') then
      error('in-place flag must be boolean')
   end
end

function AddConstant:updateOutput(input)
   assert(type(self.constant_scalar) == 'number' or
      (torch.isTensor(self.constant_scalar) and input:nDimension() <= 2 and
      input:size(input:nDimension()) == self.constant_scalar:size(1)),
      'constant is not a scalar or does not match the last dimension of the input!')
   local tmp
   if torch.isTensor(self.constant_scalar) and input:nDimension() == 2 then
      -- expand the 1D constant to match the batched 2D input without copying data
      local nOutput = self.constant_scalar:size(1)
      tmp = self.constant_scalar.new()
      tmp:resize(1, nOutput)
      tmp:copy(self.constant_scalar)
      tmp = tmp:expand(input:size(1), nOutput)
   else
      tmp = self.constant_scalar
   end
   if self.inplace then
      -- modify the input tensor directly and let the output share its storage
      input:add(tmp)
      self.output:set(input)
   else
      self.output:resizeAs(input)
      self.output:copy(input)
      self.output:add(tmp)
   end
   return self.output
end

function AddConstant:updateGradInput(input, gradOutput)
   -- adding a constant does not change the gradient, so it is passed through unchanged
   if self.inplace then
      self.gradInput:set(gradOutput)
      -- restore previous input value
      input:add(-self.constant_scalar)
   else
      self.gradInput:resizeAs(gradOutput)
      self.gradInput:copy(gradOutput)
   end
   return self.gradInput
end