Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 7th No.39】为 Paddle 代码转换工具新增 API 转换规则(第 6 组) #477

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 198 additions & 0 deletions paconvert/api_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -9277,6 +9277,20 @@
"output_size"
]
},
"torch.nn.AdaptiveLogSoftmaxWithLoss": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.AdaptiveLogSoftmaxWithLoss",
"min_input_args": 3,
"args_list": [
"in_features",
"n_classes",
"cutoffs",
"div_value",
"head_bias",
"device",
"dtype"
]
},
"torch.nn.AdaptiveMaxPool1d": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.AdaptiveMaxPool1D",
Expand Down Expand Up @@ -9490,6 +9504,17 @@
"groups"
]
},
"torch.nn.CircularPad3d": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.Pad3D",
"min_input_args": 1,
"args_list": [
"padding"
],
"paddle_default_kwargs": {
"mode": "'circular'"
}
},
"torch.nn.ConstantPad1d": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.Pad1D",
Expand Down Expand Up @@ -10077,6 +10102,28 @@
],
"min_input_args": 0
},
"torch.nn.LPPool1d": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.LPPool1D",
"min_input_args": 2,
"args_list": [
"norm_type",
"kernel_size",
"stride",
"ceil_mode"
]
},
"torch.nn.LPPool2d": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.LPPool2D",
"min_input_args": 2,
"args_list": [
"norm_type",
"kernel_size",
"stride",
"ceil_mode"
]
},
"torch.nn.LSTM": {
"Matcher": "RNNMatcher",
"paddle_api": "paddle.nn.LSTM",
Expand Down Expand Up @@ -11045,6 +11092,20 @@
},
"min_input_args": 0
},
"torch.nn.Softmin": {
"Matcher": "SoftminMatcher",
"paddle_api": "paddle.nn.Softmax",
"args_list": [
"dim"
],
"kwargs_change": {
"dim": "axis"
},
"paddle_default_kwargs": {
"axis": 0
},
"min_input_args": 0
},
"torch.nn.Softplus": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.Softplus",
Expand Down Expand Up @@ -11827,6 +11888,20 @@
],
"min_input_args": 2
},
"torch.nn.functional.feature_alpha_dropout": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.functional.feature_alpha_dropout",
"min_input_args": 1,
"args_list": [
"input",
"p",
"training",
"inplace"
],
"kwargs_change": {
"input": "x"
}
},
"torch.nn.functional.fold": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.functional.fold",
Expand Down Expand Up @@ -12206,6 +12281,36 @@
"input": "x"
}
},
"torch.nn.functional.lp_pool1d": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.functional.lp_pool1d",
"min_input_args": 2,
"args_list": [
"input",
"norm_type",
"kernel_size",
"stride",
"ceil_mode"
],
"kwargs_change": {
"input": "x"
}
},
"torch.nn.functional.lp_pool2d": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.functional.lp_pool2d",
"min_input_args": 2,
"args_list": [
"input",
"norm_type",
"kernel_size",
"stride",
"ceil_mode"
],
"kwargs_change": {
"input": "x"
}
},
"torch.nn.functional.margin_ranking_loss": {
"Matcher": "SizeAverageMatcher",
"paddle_api": "paddle.nn.functional.margin_ranking_loss",
Expand Down Expand Up @@ -12795,6 +12900,19 @@
"input": "x"
}
},
"torch.nn.functional.threshold_": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.functional.thresholded_relu_",
"min_input_args": 3,
"args_list": [
"input",
"threshold",
"value"
],
"kwargs_change": {
"input": "x"
}
},
"torch.nn.functional.triplet_margin_loss": {
"Matcher": "SizeAverageMatcher",
"paddle_api": "paddle.nn.functional.triplet_margin_loss",
Expand Down Expand Up @@ -13186,6 +13304,19 @@
},
"min_input_args": 1
},
"torch.nn.utils.parametrizations.weight_norm": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.utils.weight_norm",
"args_list": [
"module",
"name",
"dim"
],
"kwargs_change": {
"module": "layer"
},
"min_input_args": 1
},
"torch.nn.utils.remove_weight_norm": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.utils.remove_weight_norm",
Expand Down Expand Up @@ -13535,6 +13666,40 @@
"lr": "learning_rate"
}
},
"torch.optim.NAdam": {
"Matcher": "OptimAdamMatcher",
"paddle_api": "paddle.optimizer.NAdam",
"min_input_args": 1,
"args_list": [
"params",
"lr",
"betas",
"eps",
"weight_decay",
"momentum_decay",
"decoupled_weight_decay",
"*",
"foreach",
"maximize",
"capturable",
"differentiable"
],
"unsupport_args": [
Asthestarsfalll marked this conversation as resolved.
Show resolved Hide resolved
"decoupled_weight_decay",
"foreach",
"maximize",
"capturable",
"differentiable"
],
"kwargs_change": {
"params": "parameters",
"lr": "learning_rate",
"eps": "epsilon"
},
"paddle_default_kwargs": {
"weight_decay": 0.0
}
},
"torch.optim.Optimizer": {
"Matcher": "OptimOptimizerMatcher",
"paddle_api": "paddle.optimizer.Optimizer",
Expand Down Expand Up @@ -13575,6 +13740,39 @@
"torch.optim.Optimizer.zero_grad": {
"min_input_args": 0
},
"torch.optim.RAdam": {
"Matcher": "OptimAdamMatcher",
"paddle_api": "paddle.optimizer.RAdam",
"min_input_args": 1,
"args_list": [
"params",
"lr",
"betas",
"eps",
"weight_decay",
"decoupled_weight_decay",
"*",
"foreach",
"maximize",
"capturable",
"differentiable"
],
"unsupport_args": [
Asthestarsfalll marked this conversation as resolved.
Show resolved Hide resolved
"decoupled_weight_decay",
"foreach",
"maximize",
"capturable",
"differentiable"
],
"kwargs_change": {
"params": "parameters",
"lr": "learning_rate",
"eps": "epsilon"
},
"paddle_default_kwargs": {
"weight_decay": 0.0
}
},
"torch.optim.RMSprop": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.optimizer.RMSProp",
Expand Down
30 changes: 30 additions & 0 deletions paconvert/api_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3921,6 +3921,36 @@ def generate_code(self, kwargs):
return GenericMatcher.generate_code(self, kwargs)


class SoftminMatcher(SoftmaxMatcher):
def generate_code(self, kwargs):
self.paddle_api = "paddle.nn.Softmin"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个在json里不能配吗

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

因为要修改forward,似乎不能通过json配吧

return super().generate_code(kwargs)

def generate_aux_code(self):
CODE_TEMPLATE = textwrap.dedent(
"""
def _get_softmax_dim(axis: int) -> int:
if axis == 0 or axis == 1 or axis == 3:
ret = 0
else:
ret = 1
return ret

def forward(self,x):
if self._axis is None:
return paddle.nn.functional.softmax(x, _get_softmax_dim(x.ndim))
return paddle.nn.functional.softmax(x, self._axis)
setattr(paddle.nn.Softmax, 'forward', forward)

class Softmin(paddle.nn.Softmax):
def forward(self, x):
return super().forward(-x)
setattr(paddle.nn, 'Softmin', Softmin)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好像没有办法获取forward的输入,所以只能这样

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个就不用setattr paddle了,直接调用paddle_aux.Softmin吧,避免别人认为paddle也有paddle.nn.Softmin

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已更新,本地测试截图:
241025_15h32m45s_screenshot

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

本地第一次运行会出现同样的错误,但是第二次运行就正常了,测试了其他使用aux code的文件,也有同样的问题,比如Softmax

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

本地第一次运行会出现同样的错误,但是第二次运行就正常了,测试了其他使用aux code的文件,也有同样的问题,比如Softmax

@zhwesky2010

"""
)
return CODE_TEMPLATE


class OptimOptimizerMatcher(BaseMatcher):
def generate_code(self, kwargs):
code = "paddle.optimizer.Optimizer(parameters={}, **{})".format(
Expand Down
83 changes: 83 additions & 0 deletions tests/test_nn_AdaptiveLogSoftmaxWithLoss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.nn.AdaptiveLogSoftmaxWithLoss")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
input = torch.tensor([[ 0.9368637 , -0.0361056 , -0.98917043, 0.06605113, 1.5254455 ],
[-1.0518035 , -1.0024613 , 0.18699688, -0.35807893, 0.25628588],
[-0.900478 , -0.41495147, 0.84707606, -1.7883497 , 1.3243382 ]])
target = torch.tensor([1, 1, 1])
asfm = torch.nn.AdaptiveLogSoftmaxWithLoss(5, 4, [2])
out, loss = asfm(input,target)
"""
)
obj.run(pytorch_code, ["out", "loss"], check_value=False)


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
input = torch.tensor([[ 0.9368637 , -0.0361056 , -0.98917043, 0.06605113, 1.5254455 ],
[-1.0518035 , -1.0024613 , 0.18699688, -0.35807893, 0.25628588],
[-0.900478 , -0.41495147, 0.84707606, -1.7883497 , 1.3243382 ]])
target = torch.tensor([1, 1, 1])
asfm = torch.nn.AdaptiveLogSoftmaxWithLoss(5, 4, [3], div_value=2.0)
out, loss = asfm(input,target)
"""
)
obj.run(pytorch_code, ["out", "loss"], check_value=False)


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
input = torch.tensor([[ 0.9368637 , -0.0361056 , -0.98917043, 0.06605113, 1.5254455 ],
[-1.0518035 , -1.0024613 , 0.18699688, -0.35807893, 0.25628588],
[-0.900478 , -0.41495147, 0.84707606, -1.7883497 , 1.3243382 ]])
target = torch.tensor([1, 1, 1])
asfm = torch.nn.AdaptiveLogSoftmaxWithLoss(5, 4, [1], div_value=3.8, head_bias=True)
out, loss = asfm(input,target)
"""
)
obj.run(pytorch_code, ["out", "loss"], check_value=False)


def test_case_4():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
input = torch.tensor([[ 0.9368637 , -0.0361056 , -0.98917043, 0.06605113, 1.5254455 ],
[-1.0518035 , -1.0024613 , 0.18699688, -0.35807893, 0.25628588],
[-0.900478 , -0.41495147, 0.84707606, -1.7883497 , 1.3243382 ]])
target = torch.tensor([1, 1, 1])
asfm = torch.nn.AdaptiveLogSoftmaxWithLoss(in_features=5, n_classes=8, cutoffs=[5], div_value=3.8, head_bias=True)
out, loss = asfm(input,target)
"""
)
obj.run(pytorch_code, ["out", "loss"], check_value=False)
Loading