Skip to content

Commit

Permalink
Enable mypy lintrunner, Part 3 (profiler/*)
Browse files Browse the repository at this point in the history
Differential Revision: D67807621

Pull Request resolved: #7494
  • Loading branch information
mergennachin authored Jan 5, 2025
1 parent ae3d558 commit 7010a11
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 19 deletions.
3 changes: 2 additions & 1 deletion .lintrunner.toml
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ include_patterns = [
# 'exir/**/*.py',
# 'extension/**/*.py',
'kernels/**/*.py',
# 'profiler/**/*.py',
'profiler/**/*.py',
'runtime/**/*.py',
'scripts/**/*.py',
# 'test/**/*.py',
Expand All @@ -310,6 +310,7 @@ exclude_patterns = [
'third-party/**',
'**/third-party/**',
'scripts/check_binary_dependencies.py',
'profiler/test/test_profiler_e2e.py',
]
command = [
'python',
Expand Down
3 changes: 3 additions & 0 deletions .mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ follow_untyped_imports = True
[mypy-executorch.kernels.*]
follow_untyped_imports = True

[mypy-executorch.profiler.*]
follow_untyped_imports = True

[mypy-executorch.runtime.*]
follow_untyped_imports = True

Expand Down
34 changes: 17 additions & 17 deletions profiler/parse_profiler_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
from collections import OrderedDict
from enum import Enum

from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple

from prettytable import PrettyTable
from prettytable import PrettyTable # type: ignore[import-not-found]

# This version number should match the one defined in profiler.h
ET_PROF_VER = 0x00000001
Expand Down Expand Up @@ -89,8 +89,7 @@ class ProfileEvent:
duration: List[float]
chain_idx: int = -1
instruction_idx: int = -1
# pyre-ignore[8]: Incompatible attribute type
stacktrace: str = None
stacktrace: Optional[str] = None


@dataclasses.dataclass
Expand Down Expand Up @@ -134,8 +133,8 @@ def parse_prof_blocks(

# Iterate through all the profiling blocks data that have been grouped by name.
for name, data_list in prof_blocks.items():
prof_data_list = []
mem_prof_data_list = []
prof_data_list: List[ProfileEvent] = []
mem_prof_data_list: List[MemAllocation] = []
# Each entry in data_list is a tuple in which the first entry is profiling data
# and the second entry is memory allocation data, also each entry in data_list
# represents one iteration of a code block.
Expand Down Expand Up @@ -168,13 +167,13 @@ def parse_prof_blocks(

# Group all the memory allocation events based on the allocator they were
# allocated from.
alloc_sum_dict = OrderedDict()
alloc_sum_dict: OrderedDict[int, int] = OrderedDict()
for alloc in mem_prof_data_list:
alloc_sum_dict[alloc.allocator_id] = (
alloc_sum_dict.get(alloc.allocator_id, 0) + alloc.allocation_size
)

mem_prof_sum_list = []
mem_prof_sum_list: List[MemEvent] = []
for allocator_id, allocation_size in alloc_sum_dict.items():
mem_prof_sum_list.append(
MemEvent(allocator_dict[allocator_id], allocation_size)
Expand Down Expand Up @@ -243,7 +242,9 @@ def deserialize_profile_results(
prof_allocator_struct_size = struct.calcsize(ALLOCATOR_STRUCT_FMT)
prof_allocation_struct_size = struct.calcsize(ALLOCATION_STRUCT_FMT)
prof_result_struct_size = struct.calcsize(PROF_RESULT_STRUCT_FMT)
prof_blocks = OrderedDict()
prof_blocks: OrderedDict[
str, List[Tuple[List[ProfileData], List[MemAllocation]]]
] = OrderedDict()
allocator_dict = {}
base_offset = 0

Expand Down Expand Up @@ -375,19 +376,19 @@ def profile_aggregate_framework_tax(
prof_framework_tax = OrderedDict()

for name, prof_data_list in prof_data.items():
execute_max = []
kernel_and_delegate_sum = []
execute_max: List[int] = []
kernel_and_delegate_sum: List[int] = []

for d in prof_data_list:
if "Method::execute" in d.name:
execute_max = max(execute_max, d.duration)
execute_max = max(execute_max, d.duration) # type: ignore[arg-type]

if "native_call" in d.name or "delegate_execute" in d.name:
for idx in range(len(d.duration)):
if idx < len(kernel_and_delegate_sum):
kernel_and_delegate_sum[idx] += d.duration[idx]
kernel_and_delegate_sum[idx] += d.duration[idx] # type: ignore[call-overload]
else:
kernel_and_delegate_sum.append(d.duration[idx])
kernel_and_delegate_sum.append(d.duration[idx]) # type: ignore[arg-type]

if len(execute_max) == 0 or len(kernel_and_delegate_sum) == 0:
continue
Expand All @@ -408,10 +409,9 @@ def profile_aggregate_framework_tax(

def profile_framework_tax_table(
prof_framework_tax_data: Dict[str, ProfileEventFrameworkTax]
):
tables = []
) -> List[PrettyTable]:
tables: List[PrettyTable] = []
for name, prof_data_list in prof_framework_tax_data.items():
tables = []
table_agg = PrettyTable()
table_agg.title = name + " framework tax calculations"

Expand Down
2 changes: 1 addition & 1 deletion profiler/profiler_results_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import argparse
import sys

from executorch.profiler.parse_profiler_results import (
from executorch.profiler.parse_profiler_results import ( # type: ignore[import-not-found]
deserialize_profile_results,
mem_profile_table,
profile_aggregate_framework_tax,
Expand Down

0 comments on commit 7010a11

Please sign in to comment.