From f1f9465e473d1a2b234f0a193609ce3e1dce3437 Mon Sep 17 00:00:00 2001 From: Carles Pey Date: Sat, 23 Nov 2024 14:59:35 -0500 Subject: [PATCH] Refactor variable usage and statistics collection Signed-off-by: Carles Pey --- chipsec/modules/tools/smm/smm_ptr.py | 118 +++++++++++++++------------ 1 file changed, 64 insertions(+), 54 deletions(-) diff --git a/chipsec/modules/tools/smm/smm_ptr.py b/chipsec/modules/tools/smm/smm_ptr.py index 2185e0a1f1..03a35e0181 100644 --- a/chipsec/modules/tools/smm/smm_ptr.py +++ b/chipsec/modules/tools/smm/smm_ptr.py @@ -199,21 +199,54 @@ def get_info(self): return f'duration {self.duration} code {self.code:02X} data {self.data:02X} ({gprs_info(self.gprs)})' +class smi_stats: + def __init__(self): + self.clear() + + def clear(self): + self.count = 0 + self.mean = 0 + self.m2 = 0 + self.stdev = 0 + self.outliers = 0 + + # + # Computes the standard deviation using the Welford's online algorithm + # + def update_stats(self, duration): + self.count += 1 + difference = duration - self.mean + self.mean += difference / self.count + self.m2 += difference * (duration - self.mean) + variance = self.m2 / self.count + self.stdev = math.sqrt(variance) + + def get_info(self): + info = f'average {round(self.mean)} stddev {self.stdev:.2f} checked {self.count}' + return info + + # + # Combines the statistics of the two data sets using parallel variance computation + # + def combine(self, partial): + self.outliers += partial.outliers + total_count = self.count + partial.count + difference = partial.mean - self.mean + self.mean = (self.mean * self.count + partial.mean * partial.count) / total_count + self.m2 += partial.m2 + difference**2 * self.count * partial.count / total_count + self.count = total_count + variance = self.m2 / self.count + self.stdev = math.sqrt(variance) + + class scan_track: def __init__(self): + self.current_smi_stats = smi_stats() + self.history_smi_stats = smi_stats() self.clear() - self.hist_smi_duration = 0 - self.hist_smi_num = 0 - self.outliers_hist = 0 self.helper = OsHelper().get_default_helper() self.helper.init() self.smi_count = self.get_smi_count() - self.needs_calibration = True - self.calib_samples = 0 - self.stdev = 0 - self.m2 = 0 - self.stdev_hist = 0 - self.m2_hist = 0 def __del__(self): self.helper.close() @@ -251,73 +284,47 @@ def find_address_in_regs(self, gprs): return key def clear(self): - self.max = smi_info(0) - self.min = smi_info(2**32 - 1) self.outlier = smi_info(0) - self.avg_smi_duration = 0 - self.avg_smi_num = 0 - self.outliers = 0 self.code = None - self.confirmed = False + self.contents_changed = False self.needs_calibration = True self.calib_samples = 0 - self.stdev = 0 - self.m2 = 0 + self.current_smi_stats.clear() - def add(self, duration, code, data, gprs, confirmed=False): + def add(self, duration, code, data, gprs, contents_changed=False): if not self.code: self.code = code outlier = self.is_outlier(duration) if not outlier: - self.update_stdev(duration) - if duration > self.max.duration: - self.max.update(duration, code, data, gprs.copy()) - elif duration < self.min.duration: - self.min.update(duration, code, data, gprs.copy()) + self.current_smi_stats.update_stats(duration) elif self.is_slow_outlier(duration): - self.outliers += 1 - self.outliers_hist += 1 + self.current_smi_stats.outliers += 1 self.outlier.update(duration, code, data, gprs.copy()) - self.confirmed = confirmed - - # - # Computes the standard deviation using the Welford's online algorithm - # - def update_stdev(self, value): - self.avg_smi_num += 1 - self.hist_smi_num += 1 - difference = value - self.avg_smi_duration - difference_hist = value - self.hist_smi_duration - self.avg_smi_duration += difference / self.avg_smi_num - self.hist_smi_duration += difference_hist / self.hist_smi_num - self.m2 += difference * (value - self.avg_smi_duration) - self.m2_hist += difference_hist * (value - self.hist_smi_duration) - variance = self.m2 / self.avg_smi_num - variance_hist = self.m2_hist / self.hist_smi_num - self.stdev = math.sqrt(variance) - self.stdev_hist = math.sqrt(variance_hist) + self.contents_changed = contents_changed def update_calibration(self, duration): if not self.needs_calibration: return - self.update_stdev(duration) + self.current_smi_stats.update_stats(duration) self.calib_samples += 1 if self.calib_samples >= SCAN_CALIB_SAMPLES: self.needs_calibration = False def is_slow_outlier(self, value): ret = False - if value > self.avg_smi_duration + OUTLIER_STD_DEV * self.stdev: + if value > self.current_smi_stats.mean + OUTLIER_STD_DEV * self.current_smi_stats.stdev: ret = True - if value > self.hist_smi_duration + OUTLIER_STD_DEV * self.stdev_hist: + if self.history_smi_stats.count and \ + value > self.history_smi_stats.mean + OUTLIER_STD_DEV * self.history_smi_stats.stdev: ret = True return ret def is_fast_outlier(self, value): ret = False - if value < self.avg_smi_duration - OUTLIER_STD_DEV * self.stdev: + if value < self.current_smi_stats.mean - OUTLIER_STD_DEV * self.current_smi_stats.stdev: ret = True - if value < self.hist_smi_duration - OUTLIER_STD_DEV * self.stdev_hist: + if self.history_smi_stats.count and \ + value < self.history_smi_stats.mean - OUTLIER_STD_DEV * self.history_smi_stats.stdev: ret = True return ret @@ -332,18 +339,17 @@ def is_outlier(self, value): return ret def skip(self): - return self.outliers or self.confirmed + return self.current_smi_stats.outliers or self.contents_changed def found_outlier(self): - return bool(self.outliers) + return bool(self.current_smi_stats.outliers) def get_total_outliers(self): - return self.outliers_hist + return self.history_smi_stats.outliers def get_info(self): - avg = self.avg_smi_duration or self.hist_smi_duration - info = f'average {round(avg)} stddev {self.stdev:.2f} checked {self.avg_smi_num + self.outliers}' - if self.outliers: + info = self.current_smi_stats.get_info() + if self.current_smi_stats.outliers: info += f'\n Identified outlier: {self.outlier.get_info()}' return info @@ -354,6 +360,9 @@ def log_smi_result(self, logger): else: logger.log(f'[*] {msg}') + def update_history_stats(self): + self.history_smi_stats.combine(self.current_smi_stats) + class smi_desc: def __init__(self): @@ -699,6 +708,7 @@ def test_fuzz(self, thread_id, smic_start, smic_end, _addr, _addr1, scan_mode=Fa break if scan_mode: scan.log_smi_result(self.logger) + scan.update_history_stats() scan.clear() return bad_ptr_cnt, scan