diff --git a/.lintrunner.toml b/.lintrunner.toml index 254e287f98..cd8a8d535e 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -299,7 +299,7 @@ include_patterns = [ # 'exir/**/*.py', # 'extension/**/*.py', 'kernels/**/*.py', - # 'profiler/**/*.py', + 'profiler/**/*.py', 'runtime/**/*.py', 'scripts/**/*.py', # 'test/**/*.py', @@ -310,6 +310,7 @@ exclude_patterns = [ 'third-party/**', '**/third-party/**', 'scripts/check_binary_dependencies.py', + 'profiler/test/test_profiler_e2e.py', ] command = [ 'python', diff --git a/.mypy.ini b/.mypy.ini index 171f594716..bb1d574ab5 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -40,6 +40,9 @@ follow_untyped_imports = True [mypy-executorch.kernels.*] follow_untyped_imports = True +[mypy-executorch.profiler.*] +follow_untyped_imports = True + [mypy-executorch.runtime.*] follow_untyped_imports = True diff --git a/profiler/parse_profiler_results.py b/profiler/parse_profiler_results.py index 3fc1a69176..d88191d567 100644 --- a/profiler/parse_profiler_results.py +++ b/profiler/parse_profiler_results.py @@ -9,9 +9,9 @@ from collections import OrderedDict from enum import Enum -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple -from prettytable import PrettyTable +from prettytable import PrettyTable # type: ignore[import-not-found] # This version number should match the one defined in profiler.h ET_PROF_VER = 0x00000001 @@ -89,8 +89,7 @@ class ProfileEvent: duration: List[float] chain_idx: int = -1 instruction_idx: int = -1 - # pyre-ignore[8]: Incompatible attribute type - stacktrace: str = None + stacktrace: Optional[str] = None @dataclasses.dataclass @@ -134,8 +133,8 @@ def parse_prof_blocks( # Iterate through all the profiling blocks data that have been grouped by name. for name, data_list in prof_blocks.items(): - prof_data_list = [] - mem_prof_data_list = [] + prof_data_list: List[ProfileEvent] = [] + mem_prof_data_list: List[MemAllocation] = [] # Each entry in data_list is a tuple in which the first entry is profiling data # and the second entry is memory allocation data, also each entry in data_list # represents one iteration of a code block. @@ -168,13 +167,13 @@ def parse_prof_blocks( # Group all the memory allocation events based on the allocator they were # allocated from. - alloc_sum_dict = OrderedDict() + alloc_sum_dict: OrderedDict[int, int] = OrderedDict() for alloc in mem_prof_data_list: alloc_sum_dict[alloc.allocator_id] = ( alloc_sum_dict.get(alloc.allocator_id, 0) + alloc.allocation_size ) - mem_prof_sum_list = [] + mem_prof_sum_list: List[MemEvent] = [] for allocator_id, allocation_size in alloc_sum_dict.items(): mem_prof_sum_list.append( MemEvent(allocator_dict[allocator_id], allocation_size) @@ -243,7 +242,9 @@ def deserialize_profile_results( prof_allocator_struct_size = struct.calcsize(ALLOCATOR_STRUCT_FMT) prof_allocation_struct_size = struct.calcsize(ALLOCATION_STRUCT_FMT) prof_result_struct_size = struct.calcsize(PROF_RESULT_STRUCT_FMT) - prof_blocks = OrderedDict() + prof_blocks: OrderedDict[ + str, List[Tuple[List[ProfileData], List[MemAllocation]]] + ] = OrderedDict() allocator_dict = {} base_offset = 0 @@ -375,19 +376,19 @@ def profile_aggregate_framework_tax( prof_framework_tax = OrderedDict() for name, prof_data_list in prof_data.items(): - execute_max = [] - kernel_and_delegate_sum = [] + execute_max: List[int] = [] + kernel_and_delegate_sum: List[int] = [] for d in prof_data_list: if "Method::execute" in d.name: - execute_max = max(execute_max, d.duration) + execute_max = max(execute_max, d.duration) # type: ignore[arg-type] if "native_call" in d.name or "delegate_execute" in d.name: for idx in range(len(d.duration)): if idx < len(kernel_and_delegate_sum): - kernel_and_delegate_sum[idx] += d.duration[idx] + kernel_and_delegate_sum[idx] += d.duration[idx] # type: ignore[call-overload] else: - kernel_and_delegate_sum.append(d.duration[idx]) + kernel_and_delegate_sum.append(d.duration[idx]) # type: ignore[arg-type] if len(execute_max) == 0 or len(kernel_and_delegate_sum) == 0: continue @@ -408,10 +409,9 @@ def profile_aggregate_framework_tax( def profile_framework_tax_table( prof_framework_tax_data: Dict[str, ProfileEventFrameworkTax] -): - tables = [] +) -> List[PrettyTable]: + tables: List[PrettyTable] = [] for name, prof_data_list in prof_framework_tax_data.items(): - tables = [] table_agg = PrettyTable() table_agg.title = name + " framework tax calculations" diff --git a/profiler/profiler_results_cli.py b/profiler/profiler_results_cli.py index e02f516637..84d279ed59 100644 --- a/profiler/profiler_results_cli.py +++ b/profiler/profiler_results_cli.py @@ -7,7 +7,7 @@ import argparse import sys -from executorch.profiler.parse_profiler_results import ( +from executorch.profiler.parse_profiler_results import ( # type: ignore[import-not-found] deserialize_profile_results, mem_profile_table, profile_aggregate_framework_tax,