Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CodeStyle][Ruff][BUAA][D-[7-13]] Fix ruff RUF015 diagnostic for 6 files in python/paddle/ #67359

Merged
merged 11 commits into the base branch from the author's branch
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 17 additions & 11 deletions python/paddle/distributed/transpiler/distribute_transpiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1826,11 +1826,14 @@ def _update_dist_lookup_table_vars(
for grad in grad_list
if grad.name != grad_var_name(self.table_name)
]
self.table_param_grad = [
param_grad
for param_grad in params_grads
if param_grad[0].name == self.table_name
][0]
self.table_param_grad = next(
(
param_grad
for param_grad in params_grads
if param_grad[0].name == self.table_name
),
None, # Default value if no matching param_grad is found
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为什么要修改原来的逻辑?原来如果找不到会直接报错,而现在会隐藏

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

那我是直接把 None 去掉吗?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不要修改原有逻辑,去掉

)
table_grad_var = self.table_param_grad[1]
if self.sync_mode:
self.trainer_side_table_grad_list = [
Expand Down Expand Up @@ -2132,12 +2135,15 @@ def _create_table_optimize_block(
table_opt_block = pserver_program._create_block(pre_block_idx)
# create table param and grad var in pserver program
# create table optimize block in pserver program
table_opt_op = [
op
for op in self.optimize_ops
if 'Param' in op.input_names
and op.input("Param")[0] == self.table_name
][0]
table_opt_op = next(
(
op
for op in self.optimize_ops
if 'Param' in op.input_names
and op.input("Param")[0] == self.table_name
),
None,
)

origin_param_var = self.origin_program.global_block().vars[
self.table_name
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/pir/math_op_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,7 +734,7 @@ def to(self, *args, **kwargs):
if len(invalid_keys) != 0:
raise TypeError(
"to() got an unexpected keyword argument "
+ list(invalid_keys)[0]
+ next(iter(invalid_keys), "unknown key")
)

def dtype_first_sig(dtype, blocking=None): ...
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/static/amp/function_overload.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def get(self, *args, **kwargs):
satisfied_function_keys.remove(func_key)
break
if len(satisfied_function_keys) == 1:
key = list(satisfied_function_keys)[0]
key = next(iter(satisfied_function_keys), None)
elif len(args) >= 3 and isinstance(args[2], float):
key = FunctionType.FP16_ONLY
else:
Expand Down
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,9 +1023,9 @@ def get_paddle_extra_install_requirements():
output = subprocess.check_output(['nvcc', '--version']).decode(
'utf-8'
)
version_line = [
line for line in output.split('\n') if 'release' in line
][0]
version_line = next(
(line for line in output.split('\n') if 'release' in line), None
)
version = version_line.split(' ')[-1].split(',')[0]
cuda_major_version = version.split('.')[0]
except Exception as e:
Expand Down
6 changes: 3 additions & 3 deletions test/sot/test_model_switch_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ def setUp(self):

def check_mode(self, is_train):
self.assertEqual(len(self.compile_cache.cache), 1)
mode = list(self.compile_cache.cache.values())[
0
].partial_program.training
mode = next(
iter(self.compile_cache.cache.values())
).partial_program.training
self.assertEqual(mode, is_train)

def get_dygraph_out(self, input):
Expand Down
4 changes: 2 additions & 2 deletions test/xpu/test_tril_triu_op_xpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def test_errors1(self):
errmsg = {
"diagonal: TypeError": f"diagonal in {op_type} must be a python Int",
}
expected = list(errmsg.keys())[0]
expected = next(iter(errmsg.keys()))
with self.assertRaisesRegex(
eval(expected.split(':')[-1]), errmsg[expected]
):
Expand All @@ -155,7 +155,7 @@ def test_errors2(self):
errmsg = {
"input: ValueError": f"x shape in {op_type} must be at least 2-D",
}
expected = list(errmsg.keys())[0]
expected = next(iter(errmsg.keys()))
with self.assertRaisesRegex(
eval(expected.split(':')[-1]), errmsg[expected]
):
Expand Down