Skip to content

Commit

Permalink
[CodeStyle][Ruff][BUAA][D-[7-13]] Fix ruff RUF015 diagnostic for 6 …
Browse files Browse the repository at this point in the history
…files in `python/paddle/` (#67359)
  • Loading branch information
MufanColin authored Aug 14, 2024
1 parent 1d42b84 commit b4ae8dd
Show file tree
Hide file tree
Showing 6 changed files with 13 additions and 13 deletions.
8 changes: 4 additions & 4 deletions python/paddle/distributed/transpiler/distribute_transpiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1826,11 +1826,11 @@ def _update_dist_lookup_table_vars(
for grad in grad_list
if grad.name != grad_var_name(self.table_name)
]
self.table_param_grad = [
self.table_param_grad = next(
param_grad
for param_grad in params_grads
if param_grad[0].name == self.table_name
][0]
)
table_grad_var = self.table_param_grad[1]
if self.sync_mode:
self.trainer_side_table_grad_list = [
Expand Down Expand Up @@ -2132,12 +2132,12 @@ def _create_table_optimize_block(
table_opt_block = pserver_program._create_block(pre_block_idx)
# create table param and grad var in pserver program
# create table optimize block in pserver program
table_opt_op = [
table_opt_op = next(
op
for op in self.optimize_ops
if 'Param' in op.input_names
and op.input("Param")[0] == self.table_name
][0]
)

origin_param_var = self.origin_program.global_block().vars[
self.table_name
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/pir/math_op_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,7 +734,7 @@ def to(self, *args, **kwargs):
if len(invalid_keys) != 0:
raise TypeError(
"to() got an unexpected keyword argument "
+ list(invalid_keys)[0]
+ next(iter(invalid_keys))
)

def dtype_first_sig(dtype, blocking=None): ...
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/static/amp/function_overload.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def get(self, *args, **kwargs):
satisfied_function_keys.remove(func_key)
break
if len(satisfied_function_keys) == 1:
key = list(satisfied_function_keys)[0]
key = next(iter(satisfied_function_keys))
elif len(args) >= 3 and isinstance(args[2], float):
key = FunctionType.FP16_ONLY
else:
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,9 +1023,9 @@ def get_paddle_extra_install_requirements():
output = subprocess.check_output(['nvcc', '--version']).decode(
'utf-8'
)
version_line = [
version_line = next(
line for line in output.split('\n') if 'release' in line
][0]
)
version = version_line.split(' ')[-1].split(',')[0]
cuda_major_version = version.split('.')[0]
except Exception as e:
Expand Down
6 changes: 3 additions & 3 deletions test/sot/test_model_switch_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ def setUp(self):

def check_mode(self, is_train):
self.assertEqual(len(self.compile_cache.cache), 1)
mode = list(self.compile_cache.cache.values())[
0
].partial_program.training
mode = next(
iter(self.compile_cache.cache.values())
).partial_program.training
self.assertEqual(mode, is_train)

def get_dygraph_out(self, input):
Expand Down
4 changes: 2 additions & 2 deletions test/xpu/test_tril_triu_op_xpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def test_errors1(self):
errmsg = {
"diagonal: TypeError": f"diagonal in {op_type} must be a python Int",
}
expected = list(errmsg.keys())[0]
expected = next(iter(errmsg.keys()))
with self.assertRaisesRegex(
eval(expected.split(':')[-1]), errmsg[expected]
):
Expand All @@ -155,7 +155,7 @@ def test_errors2(self):
errmsg = {
"input: ValueError": f"x shape in {op_type} must be at least 2-D",
}
expected = list(errmsg.keys())[0]
expected = next(iter(errmsg.keys()))
with self.assertRaisesRegex(
eval(expected.split(':')[-1]), errmsg[expected]
):
Expand Down

0 comments on commit b4ae8dd

Please sign in to comment.