[cherrypick-beta-2.0] Fix kl_div, conv and summary api bug (#27195)
* fix some bug

LielinJiang authored Sep 8, 2020
1 parent ed52b00 commit 264e76c
Showing 8 changed files with 115 additions and 32 deletions.
6 changes: 5 additions & 1 deletion paddle/fluid/operators/kldiv_loss_op.h
@@ -72,7 +72,11 @@ class KLDivLossKernel : public framework::OpKernel<T> {
      loss_t.device(place) = output;
    } else if ("batchmean" == reduction) {
      auto output_sum = output.sum();
-     loss_t.device(place) = output_sum / output_sum.constant(n);
+     if (n > 0) {
+       loss_t.device(place) = output_sum / output_sum.constant(n);
+     } else {
+       loss_t.device(place) = output_sum;
+     }
    } else if ("mean" == reduction) {
      loss_t.device(place) = output.mean();
    } else if ("sum" == reduction) {
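The new branch skips the division when the batch dimension `n` is not positive, which happens for 0-D inputs. A minimal NumPy sketch of the same 'batchmean' logic (`kl_div_batchmean` is illustrative, not Paddle's kernel):

```python
import numpy as np

def kl_div_batchmean(x, target):
    # Elementwise KL term: target * (log(target) - x); negative-target
    # entries are zeroed, mirroring the reference implementation in the tests.
    output = target * (np.log(target) - x)
    loss = np.where(target >= 0, output, np.zeros_like(x))
    # The fix above: divide by the batch size only when one exists.
    if loss.ndim > 0:
        return loss.sum() / x.shape[0]
    return loss.sum()

x = np.random.uniform(0.1, 1.0, (5, 20)).astype('float32')
t = np.random.uniform(0.1, 1.0, (5, 20)).astype('float32')
print(kl_div_batchmean(x, t))              # sum / batch size 5
print(kl_div_batchmean(x[0, 0], t[0, 0]))  # 0-D input: plain sum, no division
```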
8 changes: 7 additions & 1 deletion python/paddle/fluid/tests/unittests/test_kldiv_loss_op.py
@@ -24,7 +24,10 @@ def kldiv_loss(x, target, reduction):
    loss = np.where(target >= 0, output, np.zeros_like(x))

    if reduction == "batchmean":
-       return loss.sum() / x.shape[0]
+       if len(x.shape) > 0:
+           return loss.sum() / x.shape[0]
+       else:
+           return loss.sum()
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
@@ -93,6 +96,9 @@ def run_kl_loss(self, reduction, shape=(5, 20)):
    def test_kl_loss_batchmean(self):
        self.run_kl_loss('batchmean')

+   def test_kl_loss_batchmean_shape(self):
+       self.run_kl_loss('batchmean', ())
+
    def test_kl_loss_mean(self):
        self.run_kl_loss('mean')

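The new `test_kl_loss_batchmean_shape` case passes an empty shape. The reason it needs the guard above is that a 0-D array's shape tuple cannot be indexed, as this plain-NumPy repro shows:

```python
import numpy as np

scalar = np.array(0.25, dtype='float32')  # 0-D array: shape == ()
print(scalar.shape)                       # ()
try:
    scalar.shape[0]                       # no batch dimension to index
except IndexError as exc:
    print("0-D input has no shape[0]:", exc)
```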
9 changes: 7 additions & 2 deletions python/paddle/hapi/model.py
@@ -1868,8 +1868,13 @@ def summary(self, input_size=None, batch_size=None, dtype=None):
                print(params_info)

        """

-       return summary(self.network, self._inputs, batch_size, dtype)
+       assert (input_size is not None or self._inputs is not None
+               ), "'input_size' or 'self._inputs' must be set"
+       if input_size is not None:
+           _input_size = input_size
+       else:
+           _input_size = self._inputs
+       return summary(self.network, _input_size, batch_size, dtype)

def _verify_spec(self, specs, is_input=False):
out_specs = []
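With this change, `summary` resolves shapes either from the `InputSpec`s the `Model` was built with or from an explicit `input_size`, which takes precedence. A usage sketch, assuming a standard Paddle 2.0 install (`LeNet` is just a convenient built-in):

```python
import paddle
from paddle.static import InputSpec

net = paddle.vision.models.LeNet()

# Shapes declared up front: summary() can be called with no arguments.
spec = [InputSpec([None, 1, 28, 28], 'float32', 'image')]
model = paddle.Model(net, spec)
model.summary()

# An explicit input_size overrides the stored InputSpecs.
model.summary(input_size=(1, 28, 28), batch_size=64)
```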
79 changes: 61 additions & 18 deletions python/paddle/hapi/model_summary.py
@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+import warnings
import numpy as np
+import numbers

import paddle
import paddle.nn as nn
@@ -86,8 +88,10 @@ def forward(self, inputs):
    elif isinstance(input_size, list):
        _input_size = []
        for item in input_size:
+           if isinstance(item, int):
+               item = (item, )
            assert isinstance(item,
-                             (list, InputSpec)), 'When input_size is list, \
+                             (tuple, InputSpec)), 'When input_size is list, \
expect item in input_size is a tuple or InputSpec, but got {}'.format(
                type(item))
@@ -97,12 +101,19 @@ def forward(self, inputs):
                batch_size = item.shape[0]
            else:
                _input_size.append(item)
+   elif isinstance(input_size, int):
+       _input_size = (input_size, )
    else:
        _input_size = input_size

    if batch_size is None:
        batch_size = -1

+   if not paddle.in_dynamic_mode():
+       warnings.warn(
+           "Your model was created in static mode, this may not get correct summary information!"
+       )
+
    result, params_info = summary_string(net, _input_size, batch_size, dtypes)
    print(result)

@@ -117,16 +128,16 @@ def summary_string(model, input_size, batch_size=-1, dtypes=None):

    depth = len(list(model.sublayers()))

-   def register_hook(module):
-       def hook(module, input, output):
-           class_name = str(module.__class__).split(".")[-1].split("'")[0]
+   def register_hook(layer):
+       def hook(layer, input, output):
+           class_name = str(layer.__class__).split(".")[-1].split("'")[0]

            try:
-               module_idx = int(module._full_name.split('_')[-1])
+               layer_idx = int(layer._full_name.split('_')[-1])
            except:
-               module_idx = len(summary)
+               layer_idx = len(summary)

-           m_key = "%s-%i" % (class_name, module_idx + 1)
+           m_key = "%s-%i" % (class_name, layer_idx + 1)
            summary[m_key] = OrderedDict()
            summary[m_key]["input_shape"] = list(input[0].shape)
            summary[m_key]["input_shape"][0] = batch_size
@@ -138,23 +149,50 @@ def hook(module, input, output):
            summary[m_key]["output_shape"][0] = batch_size

            params = 0
-           if hasattr(module, "weight"):
-               params += np.prod(module.weight.shape)
-               summary[m_key]["trainable"] = module.weight.trainable or (
-                   not module.weight.stop_gradient)
-           if hasattr(module, "bias"):
-               params += np.prod(module.bias.shape)
+
+           if paddle.in_dynamic_mode():
+               layer_state_dict = layer._parameters
+           else:
+               layer_state_dict = layer.state_dict()
+
+           for k, v in layer_state_dict.items():
+               params += np.prod(v.shape)
+
+               try:
+                   if (getattr(getattr(layer, k), 'trainable')) and (
+                           not getattr(getattr(layer, k), 'stop_gradient')):
+                       summary[m_key]["trainable"] = True
+                   else:
+                       summary[m_key]["trainable"] = False
+               except:
+                   summary[m_key]["trainable"] = True

            summary[m_key]["nb_params"] = params

-       if (not isinstance(module, nn.Sequential) and
-               not isinstance(module, nn.LayerList) and
-               (not (module == model) or depth < 1)):
+       if (not isinstance(layer, nn.Sequential) and
+               not isinstance(layer, nn.LayerList) and
+               (not (layer == model) or depth < 1)):

-           hooks.append(module.register_forward_post_hook(hook))
+           hooks.append(layer.register_forward_post_hook(hook))
+
+   def _check_input_size(input_sizes):
+       for input_size in input_sizes:
+           for item in input_size:
+               if not isinstance(item, numbers.Number):
+                   raise TypeError(
+                       "Expected item in input size to be a number, but got {}".
+                       format(type(item)))
+
+               if item <= 0:
+                   raise ValueError(
+                       "Expected item in input size to be greater than zero, but got {}".
+                       format(item))
+
+   if isinstance(input_size, tuple):
+       input_size = [input_size]
+
+   _check_input_size(input_size)

    x = [
        paddle.rand(
            [2] + list(in_size), dtype=dtype)
@@ -193,7 +231,12 @@ def hook(module, input, output):
            "{0:,}".format(summary[layer]["nb_params"]), )
        total_params += summary[layer]["nb_params"]

-       total_output += np.prod(summary[layer]["output_shape"])
+       try:
+           total_output += np.prod(summary[layer]["output_shape"])
+       except:
+           for output_shape in summary[layer]["output_shape"]:
+               total_output += np.prod(output_shape)
+
        if "trainable" in summary[layer]:
            if summary[layer]["trainable"] == True:
                trainable_params += summary[layer]["nb_params"]
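The rewritten hook iterates every parameter a layer owns instead of assuming `weight`/`bias` attributes, so layers with differently named parameters are counted too. A rough sketch of that counting idea (it leans on `Layer._parameters`, the same private mapping the patch reads in dynamic mode):

```python
import numpy as np
import paddle.nn as nn

def count_own_params(layer):
    # _parameters holds only the layer's own parameters, so sublayers
    # (which get their own hooks) are not double-counted.
    return int(sum(np.prod(p.shape) for p in layer._parameters.values()))

linear = nn.Linear(20, 10)
print(count_own_params(linear))  # 20 * 10 weights + 10 biases = 210
```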
4 changes: 2 additions & 2 deletions python/paddle/nn/functional/loss.py
@@ -780,10 +780,10 @@ def kl_div(input, label, reduction='mean', name=None):
            input = np.random.uniform(-10, 10, shape).astype('float32')
            target = np.random.uniform(-10, 10, shape).astype('float32')

-           # 'batchmean' reduction, loss shape will be [N]
+           # 'batchmean' reduction, loss shape will be [1]
            pred_loss = F.kl_div(paddle.to_tensor(input),
                                 paddle.to_tensor(target), reduction='batchmean')
-           # shape=[5]
+           # shape=[1]

            # 'mean' reduction, loss shape will be [1]
            pred_loss = F.kl_div(paddle.to_tensor(input),
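A quick check of the corrected comments — under Paddle 2.0 semantics every reduced variant collapses to shape [1], while 'none' keeps the elementwise shape (sample values are arbitrary):

```python
import numpy as np
import paddle
import paddle.nn.functional as F

x = paddle.to_tensor(np.random.uniform(0, 1, (5, 20)).astype('float32'))
t = paddle.to_tensor(np.random.uniform(0, 1, (5, 20)).astype('float32'))

print(F.kl_div(x, t, reduction='batchmean').shape)  # [1], not [5]
print(F.kl_div(x, t, reduction='mean').shape)       # [1]
print(F.kl_div(x, t, reduction='none').shape)       # [5, 20]
```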
2 changes: 1 addition & 1 deletion python/paddle/nn/layer/conv.py
@@ -1084,7 +1084,7 @@ def __init__(self,
            bias_attr=bias_attr,
            data_format=data_format)

-   def forward(self, x, output_size):
+   def forward(self, x, output_size=None):
        if output_size is None:
            output_padding = self.output_padding
        else:
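With the default in place, a transposed-convolution layer can be invoked through the ordinary `layer(x)` call path instead of `layer(x, None)`. A sketch, assuming the Paddle 2.0 `Conv2DTranspose` layer name (channel and kernel sizes are arbitrary):

```python
import paddle
import paddle.nn as nn

deconv = nn.Conv2DTranspose(in_channels=4, out_channels=6, kernel_size=3)
x = paddle.rand([2, 4, 8, 8])
y = deconv(x)     # previously failed: forward() required output_size
print(y.shape)    # [2, 6, 10, 10] with stride 1 and no padding
```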
17 changes: 10 additions & 7 deletions python/paddle/nn/layer/loss.py
@@ -627,10 +627,13 @@ class KLDivLoss(fluid.dygraph.Layer):

    $$l(x, y) = y * (\log(y) - x)$$

    Parameters:
-       reduction (str, optional): Indicate how to average the loss,
-           the candicates are ``'none'`` | ``'mean'`` | ``'sum'``.
-           If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
-           Default is ``'mean'``.
+       reduction (str, optional): Indicate how to average the loss,
+           the candidates are ``'none'`` | ``'batchmean'`` | ``'mean'`` | ``'sum'``.
+           If `reduction` is ``'mean'``, the reduced mean loss is returned;
+           if `reduction` is ``'batchmean'``, the sum loss divided by batch size is returned;
+           if `reduction` is ``'sum'``, the reduced sum loss is returned;
+           if `reduction` is ``'none'``, no reduction will be applied.
+           Default is ``'mean'``.

    Shape:
@@ -654,11 +657,11 @@ class KLDivLoss(fluid.dygraph.Layer):
            x = np.random.uniform(-10, 10, shape).astype('float32')
            target = np.random.uniform(-10, 10, shape).astype('float32')

-           # 'batchmean' reduction, loss shape will be [N]
+           # 'batchmean' reduction, loss shape will be [1]
            kldiv_criterion = nn.KLDivLoss(reduction='batchmean')
            pred_loss = kldiv_criterion(paddle.to_tensor(x),
                                        paddle.to_tensor(target))
-           # shape=[5]
+           # shape=[1]

            # 'mean' reduction, loss shape will be [1]
            kldiv_criterion = nn.KLDivLoss(reduction='mean')
@@ -684,7 +687,7 @@ def __init__(self, reduction='mean'):
        self.reduction = reduction

    def forward(self, input, label):
-       out = paddle.nn.functional.kl_div(input, label, self.reduction)
+       out = F.kl_div(input, label, self.reduction)
        return out

22 changes: 22 additions & 0 deletions python/paddle/tests/test_model.py
@@ -519,6 +519,28 @@ def _get_param_from_state_dict(state_dict):
        np.testing.assert_allclose(params_info['total_params'], gt_params)
        print(params_info)

+       model.summary(input_size=(20))
+       model.summary(input_size=[(20)])
+       model.summary(input_size=(20), batch_size=2)
+
+   def test_summary_nlp(self):
+       paddle.enable_static()
+       nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3)
+       paddle.summary(nlp_net, (1, 2))
+
+   def test_summary_error(self):
+       with self.assertRaises(TypeError):
+           nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3)
+           paddle.summary(nlp_net, (1, '2'))
+
+       with self.assertRaises(ValueError):
+           nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3)
+           paddle.summary(nlp_net, (-1, -1))
+
+       paddle.disable_static()
+       nlp_net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3)
+       paddle.summary(nlp_net, (1, 2))
+
    def test_export_deploy_model(self):
        for dynamic in [True, False]:
            fluid.enable_dygraph() if dynamic else None
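The same validation can be triggered outside the test harness — a sketch assuming the patched `paddle.summary` above (the GRU dimensions are arbitrary):

```python
import paddle

net = paddle.nn.GRU(input_size=2, hidden_size=3, num_layers=3)
paddle.summary(net, (1, 2))        # a single tuple is wrapped into a list

try:
    paddle.summary(net, (1, '2'))  # non-numeric entry -> TypeError
except TypeError as e:
    print(e)

try:
    paddle.summary(net, (-1, -1))  # non-positive entry -> ValueError
except ValueError as e:
    print(e)
```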

1 comment on commit 264e76c

@paddle-bot-old (bot) commented on 264e76c, Sep 9, 2020

🕵️ CI failures summary

🔍 Commit ID: 264e76c contains failed CI.

  • Failed: MAC_Python3_Build (Paddle Mac Build)
