Skip to content

Commit

Permalink
Add auto-breakpoint utility. (#732)
Browse files Browse the repository at this point in the history
* Simplify debug utils.

* Add auto breakpoint.

* Fix stuff.
  • Loading branch information
azteca1998 authored Jul 22, 2024
1 parent 72d10b7 commit fa2fa40
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 96 deletions.
15 changes: 15 additions & 0 deletions src/libfuncs/enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,21 @@ pub fn build_init<'ctx, 'this>(
metadata: &mut MetadataStorage,
info: &EnumInitConcreteLibfunc,
) -> Result<()> {
#[cfg(feature = "with-debug-utils")]
if let Some(auto_breakpoint) =
metadata.get::<crate::metadata::auto_breakpoint::AutoBreakpoint>()
{
auto_breakpoint.maybe_breakpoint(
entry,
location,
metadata,
&crate::metadata::auto_breakpoint::BreakpointEvent::EnumInit {
type_id: info.signature.branch_signatures[0].vars[0].ty.clone(),
variant_idx: info.index,
},
)
}

let val = build_enum_value(
context,
registry,
Expand Down
1 change: 1 addition & 0 deletions src/metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use std::{
collections::{hash_map::Entry, HashMap},
};

pub mod auto_breakpoint;
pub mod debug_utils;
pub mod enum_snapshot_variants;
pub mod gas;
Expand Down
45 changes: 45 additions & 0 deletions src/metadata/auto_breakpoint.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#![cfg(feature = "with-debug-utils")]

use super::{debug_utils::DebugUtils, MetadataStorage};
use cairo_lang_sierra::ids::ConcreteTypeId;
use melior::ir::{Block, Location};
use std::collections::HashSet;

#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum BreakpointEvent {
EnumInit {
type_id: ConcreteTypeId,
variant_idx: usize,
},
}

#[derive(Clone, Debug, Default)]
pub struct AutoBreakpoint {
events: HashSet<BreakpointEvent>,
}

impl AutoBreakpoint {
pub fn add_event(&mut self, event: BreakpointEvent) {
self.events.insert(event);
}

pub fn has_event(&self, event: &BreakpointEvent) -> bool {
self.events.contains(event)
}

pub fn maybe_breakpoint(
&self,
block: &Block,
location: Location,
metadata: &MetadataStorage,
event: &BreakpointEvent,
) {
if self.has_event(event) {
metadata
.get::<DebugUtils>()
.unwrap()
.debug_breakpoint_trap(block, location)
.unwrap();
}
}
}
155 changes: 59 additions & 96 deletions src/metadata/debug_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,16 +121,13 @@ pub struct DebugUtils {
}

impl DebugUtils {
pub fn breakpoint_marker<'c, 'a>(
pub fn breakpoint_marker(
&mut self,
context: &'c Context,
context: &Context,
module: &Module,
block: &'a Block<'c>,
location: Location<'c>,
) -> Result<()>
where
'c: 'a,
{
block: &Block,
location: Location,
) -> Result<()> {
if self.active_map.insert(DebugBinding::BreakpointMarker) {
module.body().append_operation(func::func(
context,
Expand All @@ -157,17 +154,14 @@ impl DebugUtils {
}

/// Prints the given &str.
pub fn debug_print<'c, 'a>(
pub fn debug_print(
&mut self,
context: &'c Context,
context: &Context,
module: &Module,
block: &'a Block<'c>,
block: &Block,
message: &str,
location: Location<'c>,
) -> Result<()>
where
'c: 'a,
{
location: Location,
) -> Result<()> {
if self.active_map.insert(DebugBinding::DebugPrint) {
module.body().append_operation(func::func(
context,
Expand Down Expand Up @@ -236,29 +230,19 @@ impl DebugUtils {
Ok(())
}

pub fn debug_breakpoint_trap<'c, 'a>(
&mut self,
block: &'a Block<'c>,
location: Location<'c>,
) -> Result<()>
where
'c: 'a,
{
pub fn debug_breakpoint_trap(&self, block: &Block, location: Location) -> Result<()> {
block.append_operation(OperationBuilder::new("llvm.intr.debugtrap", location).build()?);
Ok(())
}

pub fn print_pointer<'c, 'a>(
pub fn print_pointer(
&mut self,
context: &'c Context,
context: &Context,
module: &Module,
block: &'a Block<'c>,
value: Value<'c, '_>,
location: Location<'c>,
) -> Result<()>
where
'c: 'a,
{
block: &Block,
value: Value,
location: Location,
) -> Result<()> {
if self.active_map.insert(DebugBinding::PrintPointer) {
module.body().append_operation(func::func(
context,
Expand All @@ -284,17 +268,14 @@ impl DebugUtils {
Ok(())
}

pub fn print_i1<'c, 'a>(
pub fn print_i1(
&mut self,
context: &'c Context,
context: &Context,
module: &Module,
block: &'a Block<'c>,
value: Value<'c, '_>,
location: Location<'c>,
) -> Result<()>
where
'c: 'a,
{
block: &Block,
value: Value,
location: Location,
) -> Result<()> {
if self.active_map.insert(DebugBinding::PrintI1) {
module.body().append_operation(func::func(
context,
Expand Down Expand Up @@ -322,17 +303,14 @@ impl DebugUtils {
Ok(())
}

pub fn print_felt252<'c, 'a>(
pub fn print_felt252(
&mut self,
context: &'c Context,
context: &Context,
module: &Module,
block: &'a Block<'c>,
value: Value<'c, '_>,
location: Location<'c>,
) -> Result<()>
where
'c: 'a,
{
block: &Block,
value: Value,
location: Location,
) -> Result<()> {
if self.active_map.insert(DebugBinding::PrintFelt252) {
module.body().append_operation(func::func(
context,
Expand Down Expand Up @@ -424,17 +402,14 @@ impl DebugUtils {
Ok(())
}

pub fn print_i8<'c, 'a>(
pub fn print_i8(
&mut self,
context: &'c Context,
context: &Context,
module: &Module,
block: &'a Block<'c>,
value: Value<'c, '_>,
location: Location<'c>,
) -> Result<()>
where
'c: 'a,
{
block: &Block,
value: Value,
location: Location,
) -> Result<()> {
if self.active_map.insert(DebugBinding::PrintI8) {
module.body().append_operation(func::func(
context,
Expand Down Expand Up @@ -462,17 +437,14 @@ impl DebugUtils {
Ok(())
}

pub fn print_i32<'c, 'a>(
pub fn print_i32(
&mut self,
context: &'c Context,
context: &Context,
module: &Module,
block: &'a Block<'c>,
value: Value<'c, '_>,
location: Location<'c>,
) -> Result<()>
where
'c: 'a,
{
block: &Block,
value: Value,
location: Location,
) -> Result<()> {
if self.active_map.insert(DebugBinding::PrintI32) {
module.body().append_operation(func::func(
context,
Expand Down Expand Up @@ -500,17 +472,14 @@ impl DebugUtils {
Ok(())
}

pub fn print_i64<'c, 'a>(
pub fn print_i64(
&mut self,
context: &'c Context,
context: &Context,
module: &Module,
block: &'a Block<'c>,
value: Value<'c, '_>,
location: Location<'c>,
) -> Result<()>
where
'c: 'a,
{
block: &Block,
value: Value,
location: Location,
) -> Result<()> {
if self.active_map.insert(DebugBinding::PrintI64) {
module.body().append_operation(func::func(
context,
Expand Down Expand Up @@ -538,17 +507,14 @@ impl DebugUtils {
Ok(())
}

pub fn print_i128<'c, 'a>(
pub fn print_i128(
&mut self,
context: &'c Context,
context: &Context,
module: &Module,
block: &'a Block<'c>,
value: Value<'c, '_>,
location: Location<'c>,
) -> Result<()>
where
'c: 'a,
{
block: &Block,
value: Value,
location: Location,
) -> Result<()> {
if self.active_map.insert(DebugBinding::PrintI128) {
module.body().append_operation(func::func(
context,
Expand Down Expand Up @@ -610,18 +576,15 @@ impl DebugUtils {
/// Dump a memory region at runtime.
///
/// Requires the pointer (at runtime) and its length in bytes (at compile-time).
pub fn dump_mem<'c, 'a>(
pub fn dump_mem(
&mut self,
context: &'c Context,
context: &Context,
module: &Module,
block: &'a Block<'c>,
ptr: Value<'c, '_>,
block: &Block,
ptr: Value,
len: usize,
location: Location<'c>,
) -> Result<()>
where
'c: 'a,
{
location: Location,
) -> Result<()> {
if self.active_map.insert(DebugBinding::DumpMemRegion) {
module.body().append_operation(func::func(
context,
Expand Down

0 comments on commit fa2fa40

Please sign in to comment.