diff --git a/README.md b/README.md index 3b0b14c3..b62e1646 100644 --- a/README.md +++ b/README.md @@ -1548,7 +1548,7 @@ func (vm *VirtualMachine) UpdateAp(instruction *Instruction, operands *Operands) } ``` -### CairoRunner +#### CairoRunner Now that can can execute cairo steps, lets look at the VM's initialization step. We will begin by creating our `CairoRunner`: @@ -1589,7 +1589,7 @@ func (r *CairoRunner) Initialize() (memory.Relocatable, error) { } ``` -#### InitializeSegments +##### InitializeSegments This method will create our program and execution segments @@ -1602,7 +1602,7 @@ func (r *CairoRunner) initializeSegments() { } ``` -#### initializeMainEntrypoint +##### initializeMainEntrypoint This method will initialize the memory and initial register values to begin execution from the main entrypoint, and return the final pc @@ -1614,7 +1614,7 @@ func (r *CairoRunner) initializeMainEntrypoint() (memory.Relocatable, error) { } ``` -#### initializeFunctionEntrypoint +##### initializeFunctionEntrypoint This method will initialize the memory and initial register values to execute a cairo function given its offset within the program segment (aka entrypoint) and return the final pc. In our case, this function will be the main entrypoint, but later on we will be able to use this method to run starknet contract entrypoints. The stack will then be loaded into the execution segment in the next method. For now, the stack will be empty, but later on it will contain the builtin bases (which are the arguments for the main function), and the function arguments when running a function from a starknet contract. @@ -1631,7 +1631,7 @@ func (r *CairoRunner) initializeFunctionEntrypoint(entrypoint uint, stack *[]mem } ``` -#### InitializeState +##### InitializeState This method will be in charge of loading the program data into the program segment and the stack into the execution segment @@ -1648,7 +1648,7 @@ func (r *CairoRunner) initializeState(entrypoint uint, stack *[]memory.MaybeRelo } ``` -#### initializeVm +##### initializeVm This method will set the values of the VM's `RunContext` with our `CairoRunner`'s initial values @@ -1662,7 +1662,7 @@ func (r *CairoRunner) initializeVM() { With `CairoRunner.Initialize()` now complete we can move on to the execution step: -#### RunUntilPc +##### RunUntilPc This method will continuously execute cairo steps until the end pc, returned by 'CairoRunner.Initialize()' is reached @@ -2438,7 +2438,7 @@ func (vm *VirtualMachine) ComputeOperands(instruction Instruction) (Operands, er With all of our builtin logic integrated into the codebase, we can implement any builtin and use it in our cairo programs while worrying only about implementing the `BuiltinRunner` interface and creating the builtin in the `NewCairoRunner` function. -##### RangeCheck +#### RangeCheck The `RangeCheck` builtin does a very simple thing: it asserts that a given number is in the range $[0, 2^{128})$, i.e., that it's greater than zero and less than $2^{128}$. This might seem superficial but it is used for a lot of different things in Cairo, including comparing numbers. Whenever a program asserts that some number is less than other, the range check builtin is being called underneath. @@ -2517,7 +2517,7 @@ func (r *RangeCheckBuiltinRunner) AddValidationRule(mem *memory.Memory) { } `````` -##### Output +#### Output TODO @@ -2968,5 +2968,881 @@ TODO #### Hints +So far we have been thinking about the VM mostly abstracted from the prover and verifier it's meant to feed its results to. The last main feature we need to talk about, however, requires keeping this proving/verifying logic in mind. +As a reminder, the whole point of the Cairo VM is to output a trace/memory file so that a `prover` can then create a cryptographic proof that the execution of the program was done correctly. A `verifier` can then take that proof and verify it in much less time than it would have taken to re-execute the entire program. +In this model, the one actually using the VM to run a cairo program is *always the prover*. The verifier does not use the VM in any way, as that would defeat the entire purpose of validity proofs; they just get the program being run and the proof generated by the prover and run some cryptographic algorithm to check it. + +While the verifier does not execute the code, they do *check it*. As an example, if a cairo program computes a fibonacci number like this: + +``` +func main() { + // Call fib(1, 1, 10). + let result: felt = fib(1, 1, 10); +} +``` + +the verifier won't *run* this, but they will reject any incorrect execution of the call to `fib`. The correct value for `result` in this case is `144` (it's the 10th fibonacci number); any attempt by the prover to convince the verifier that `result` is not `144` will fail, because the call to the `fib` function is *being proven* and thus *seen* by the verifier. + +A `Hint` is a piece of code that is not proven, and therefore not seen by the verifier. If `fib` above were a hint, then the prover could convince the verifier that `result` is $144$, $0$, $1000$ or any other number. + +In cairo 0, hints are code written in `Python` and are surrounded by curly brackets. Here's an example from the `alloc` function, provided by the Cairo common library + +``` +func alloc() -> (ptr: felt*) { + %{ memory[ap] = segments.add() %} + ap += 1; + return (ptr=cast([ap - 1], felt*)); +} +``` + +The first line of the function, + +``` +%{ memory[ap] = segments.add() %} +``` + +is a hint called `ADD_SEGMENT`. All it does is create a new memory segment, then write its base to the current value of `ap`. This is python code that is being run in the context of the VM's execution; thus `memory` refers to the VM's current memory and `segments.add()` is just a function provided by the VM to allocate a new segment. + +At this point you might be wondering: why run code that's not being proven? Isn't the whole point of Cairo to prove correct execution? There are (at least) two reasons for hints to exist. + +##### Nothing to prove + +For some operations there's simply nothing to prove, as they are just convenient things one wants to do during execution. The `ADD_SEGMENT` hint shown above is a good example of that. When proving execution, the program's memory is presented as one relocated continuous segment, it does not matter at all which segment a cell was in, or when that segment was added. The verifier doesn't care. + +Because of this, there's no reason to make `ADD_SEGMENT` a part of the cairo language and have an instruction for it. + +##### Optimization + +Certain operations can be very expensive, in the sense that they might involve a huge amount of instructions or memory usage, and therefore contribute heavily to the proving time. For certain calculations, there are two ways to convince the verifier that it was done correctly: + +- Write the entire calculation in Cairo/Cairo Assembly. This makes it show up in the trace and therefore get proven. +- *Present the result of the calculation to the verifier through a hint*, then show said result indeed satisfies the relevant condition that makes it the actual result. + +To make this less abstract, let's show two examples. + +##### Square root + +Let's say the calculation in question is to compute the square root of a number `x`. The two ways to do it then become: + +- Write the usual square root algorithm in Cairo to compute `sqrt(x)`. +- Write a hint that computes `sqrt(x)`, then immediately after calling the hint show __in Cairo__ that `(sqrt(x))^2 = x`. + +The second approach is exactly what the `sqrt` function in the Cairo common library does: + +``` +// Returns the floor value of the square root of the given value. +// Assumptions: 0 <= value < 2**250. +func sqrt{range_check_ptr}(value) -> felt { + alloc_locals; + local root: felt; + + %{ + from starkware.python.math_utils import isqrt + value = ids.value % PRIME + assert value < 2 ** 250, f"value={value} is outside of the range [0, 2**250)." + assert 2 ** 250 < PRIME + ids.root = isqrt(value) + %} + + assert_nn_le(root, 2 ** 125 - 1); + tempvar root_plus_one = root + 1; + assert_in_range(value, root * root, root_plus_one * root_plus_one); + + return root; +} +``` + +If you read it carefully, you'll see that the hint in this function computes the square root in python, then this line + +``` +assert_in_range(value, root * root, root_plus_one * root_plus_one); +``` + +asserts __in Cairo__ that `(sqrt(x))^2 = x`. + +This is done this way because it is much cheaper, in terms of the generated trace (and thus proving time), to square a number than compute its square root. + +Notice that the last assert is absolutely mandatory to make this safe. If you forget to write it, the square root calculation does not get proven, and anyone could convince the verifier that the result of `sqrt(x)` is any number they like. + +##### Linear search turned into an O(1) lookup + +This example is taken from the [Cairo documentation](https://docs.cairo-lang.org/0.12.0/hello_cairo/program_input.html). + +Given a list of `(key, value)` pairs, if we want to write a `get_value_by_key` function that returns the value associated to a given key, there are two ways to do it: + +- Write a linear search in Cairo, iterating over each key until you find the requested one. +- Do that exact same linear search *inside a hint*, find the result, then show that the result's key is the one requested. + +Again, the second approach makes the resulting trace and proving much faster, because it's just a lookup; there's no linear search. Notice this only applies to proving, the VM has to execute the hint, so there's still a linear search when executing to generate the trace. In fact, the second approach is more expensive for the VM than the first one. It has to do both a linear search and a lookup. This is a tradeoff in favor of proving time. + +Also note that, as in the square root example, when writing this logic you need to remember to show the hint's result is the correct one in Cairo. If you don't, your code is not being proven. + +##### Non-determinism + +The Cairo paper and documentation refers to this second approach to calculating things through hints as *non-determinism*. The reason for this is that sometimes there is more than one result that satisfies a certain condition. This means that cairo execution becomes non deterministic; a hint could output multiple values, and in principle there is no way to know which one it's going to be. Running the same code multiple times could give different results. + +The square root is an easy example of this. The condition `(sqrt(x))^2 = x` is not unique, there are two solutions to it. Without the hint, this is non-deterministic, `x` could have multiple values; the hint resolves that by choosing a specific value when being run. + +##### Common Library and Hints + +As explained above, using hints in your code is highly unsafe. Forgetting to add a check after calling them can make your code vulnerable to any sorts of attacks, as your program will not prove what you think it proves. + +Because of this, most hints in Cairo 0 are wrapped around or used by functions in the Cairo common library that do the checks for you, thus making them safe to use. Ideally, Cairo developers should not be using hints on their own; only transparently through Cairo library functions they call. + +##### Whitelisted Hints + +In Cairo, a hint could be any Python code you like. In the context of it as just another language someone might want to use, this is fine. In the context of Cairo as a programming language used to write smart contracts deployed on a blockchain, it's not. Users could deploy contracts with hints that simply do + +```python +while true: + pass +``` + +and grind the network down to a halt, as nodes get stuck executing an infinite loop when calling the contract. + +To address this, the starknet network maintains a list of *whitelisted* hints, which are the only ones that can be used in starknet contracts. These are the ones implemented in this VM. + +#### Implementing Hints + +Hints are essentially logic that is executed in each cairo step, before the next instruction, and which may interact with and modify the vm. We will first look into the broad execution loop and the dive into the different types of interaction hints can have with the vm. +While the original cairo-lang implementation executes these hints in python, we will instead be implementing their logic in go and matching each string of python code to a function in the vm's code. We will also be using an interface to abstract the hint processing part of the vm and allow greater flexibility when using the vm in other contexts. + +##### The HintProcessor interface + +This `HintProcessor` interface will consist of two methods: `CompileHint`, which receives hint data from the compiled program and transforms it into whatever format is more convenient for hint execution, and `ExecuteHint`, which will receive this data and use it to execute the hint. + +```go +type HintProcessor interface { + // Transforms hint data outputted by the VM into whichever format will be later used by ExecuteHint + CompileHint(hintParams *parser.HintParams, referenceManager *parser.ReferenceManager) (any, error) + // Executes the hint which's data is provided by a dynamic structure previously created by CompileHint + ExecuteHint(vm *VirtualMachine, hintData *any, constants *map[string]lambdaworks.Felt, execScopes *types.ExecutionScopes) error +} +``` + +We will first look at how hint processing ties into the core vm execution loop, and then look into how this vm's implementation of the `HintProcessor` interface works: + +##### VM execution loop + +Before we begin executing steps, we will feed the hint-related information from the compiled program to the `HintProcessor`, and obtain what we call `HintData`, which will be later on used to execute the hint. As we can see, the compiled json stores the hint information in a map which connects pc offsets (at which pc offset the hint should be executed) to a list of hints (yes, more than one hint can be executed as a given pc), and we will use a similar structure to hold the compiled `HintData`. +```go +func (r *CairoRunner) BuildHintDataMap(hintProcessor vm.HintProcessor) (map[uint][]any, error) { + hintDataMap := make(map[uint][]any) + for pc, hintsParams := range r.Program.Hints { + hintDatas := make([]any, 0, len(hintsParams)) + for _, hintParam := range hintsParams { + data, err := hintProcessor.CompileHint(&hintParam) + if err != nil { + return nil, err + } + hintDatas = append(hintDatas, data) + } + hintDataMap[pc] = hintDatas + } + + return hintDataMap, nil +} +``` + +Once we have our map of `HintData`s we can start executing cairo steps. Before fetching the next instruction, we will check if we have hints to run for the current pc, and if we do, the `HintProcessor` will execute each hint using the corresponding `HintData`. + +```go +func (v *VirtualMachine) Step(hintProcessor HintProcessor, hintDataMap *map[uint][]any) error { + // Run Hint + hintDatas, ok := (*hintDataMap)[v.RunContext.Pc.Offset] + if ok { + for i := 0; i < len(hintDatas); i++ { + err := hintProcessor.ExecuteHint(v, &hintDatas[i]) + if err != nil { + return err + } + } + } + + // Run Instruction + encoded_instruction, err := v.Segments.Memory.Get(v.RunContext.Pc) +``` + +##### Implementing a HintProcessor: ExecuteHint + +This method will receive a `HintData`, and match its `Code` field, which contains the python code as a string, to a go function that implements its logic: + +```go +func (p *CairoVmHintProcessor) ExecuteHint(vm *vm.VirtualMachine, hintData *any) error { + data, ok := (*hintData).(HintData) + if !ok { + return errors.New("Wrong Hint Data") + } + switch data.Code { + case ADD_SEGMENT: + return addSegment(vm) + default: + return errors.Errorf("Unknown Hint: %s", data.Code) + } +} +``` + +Where `ADD_SEGMENT` is a constant with the python code of the hint + +```go +const ADD_SEGMENT = "memory[ap] = segments.add()" +``` + +And the function `addSegment` implements its logic, which is to add a segment to the vm's memory: + +```go +func addSegment(vm *VirtualMachine) error { + newSegmentBase := vm.Segments.AddSegment() + return vm.Segments.Memory.Insert(vm.RunContext.Ap, NewMaybeRelocatableRelocatable(newSegmentBase)) +} +``` + +Before we implement the `CompileHint` method, lets look at this crucial part of hint interaction with the vm: + +##### Hint Interaction: Ids + +Ids are hints' way to interact with variables in a cairo program. For example, if I declare a variable `n` in my cairo code, I can access that `n` variable inside a hint using `ids.n`. + +The following cairo snippet would print the number 17 + +```py + let n = 17 + %{ print(ids.n) %} +``` + +To access these variables when implementing our hints in go we will implement the `IdsManager`, which will allow us to read and write cairo variables. +But interacting with cairo variables is not as easy as it sounds. In order to access them, we must first compute their address from a `Reference` + +###### References + +As cairo variables are created during the vm's execution, we can't know their value beforehand. In order to solve this, the compiled program provides us with references for cairo variables available to hints. These references are instructions on where we can find a specific cairo variable in memory. For example, they might tell us to take the current value of the fp register, substract 1 from it, and access the memory value at that new address. + +As these references come in string format, we need to parse them into a struct that we can efficiently use to compute addresses: + +```go +type HintReference struct { + Offset1 OffsetValue + Offset2 OffsetValue + Dereference bool + ApTrackingData parser.ApTrackingData + ValueType string +} +``` + +This struct matches the canonical string format for references: `"cast(Offset1 + Offset2, ValueType)"` (or `"[cast(Offset1 + Offset2, ValueType)]"`, in the case of Dereference being true ). +The first two fields: Offset1 and Offset2 will lead us to a particular memory value, the Dereference field will tell us if the value of the ids is that memory value we found (in case of false), or if we should use that value as an address to fetch the ids value from memory (in case of true), and the ValueType tells us what type the variable has (be it a felt, felt*, struct, etc). As we already know the context of the hints, we can ignore the ValueType. + +Now lets look at what an `OffsetValue` is: + +```go +type OffsetValue struct { + ValueType offsetValueType + Immediate Felt + Value int + Register vm.Register + Dereference bool +} + +type offsetValueType uint + +const ( + Value offsetValueType = 0 + Immediate offsetValueType = 1 + Reference offsetValueType = 2 +) +``` + +There are three types of `OffsetValue`: + +* Inmediate: Contains the value of the ids as a literal, for example `"cast(17, felt)"` is a reference to a felt with literal value 17. Only Offset1 can be of Immediate type, and the reference can't have Dereference = true + +* Reference: It is made up of a Register (AP or FP) and a Value, it will tell us the location of an ids in memory by pointing to a memory cell relative to a register. For example `"cast(fp + (-1), felt*)"` is a reference with Offset1 of type Reference, with register FP and Value -1, and it leads us to an felt* value obtained from subtracting 1 from the current fp value. OffsetValues of type Reference can also have Dereference, for example: `"cast([fp + (-1)], felt)"` will lead us to a felt value located one cell before the one at the current register value. Both OffsetValues can be of type Reference in the same Reference + +* Value: Only Offset2 can be of type value, it consists of a single field value and acts as a modifier to the first OffsetValue (which will always be of type Reference for this case). For example, we can add second OffsetValue of Value type with Value = 1 to the first Reference type example: `"cast(fp + (-1) + 2), felt*)"`, this will tell us to subtract 1 from fp, and then add 2 to it, and that will be our ids value. + +When an offset doesn't exist in the reference, we use an OffsetValue of type Value with Value 0, which essentially does nothing, to represent it. This allows us to use go's zero value by default to make our code (and life) a bit simpler. + +This can be a bit hard to grasp at first so lets look at some examples: + +* Immediate Reference + String Reference: `cast(17, felt)` + Struct Reference: {Offset1: {ValueType: Immediate, Immedate: 17}, ValueType: "felt"} + Reference in words: The value of the ids is 17 + +* Dereference with one offset of Type Reference + String Reference: `[cast(ap + 1, felt)]` + Struct Reference: {Offset1: {ValueType: Reference, Register: AP, Value: 1}, Dereference: true, ValueType: "felt"} + Reference in words: Take the current value of ap, add 1 to it and then fetch the memory value at that address + +* Two offsets of type Reference, Value + String Reference: `"cast(ap + 1 + (-2), felt*)"` + Struct Reference: {Offset1: {ValueType: Reference, Register: AP, Value: 1}, Offset2: {ValueType: Value, Value: -2}, ValueType: "felt*"} + Reference in words: Take the current value of ap, add 1 to it and then subtract 2 from it + +* Two offsets of type Reference (with Dereference), Value + String Reference: `"cast([ap + (-1)] + (-2), felt*)"` + Struct Reference: {Offset1: {ValueType: Reference, Register: AP, Value: 1, Dereference: true}, Offset2: {ValueType: Value, Value: -2}, ValueType: "felt*"} + Reference in words: Take the current value of ap, add 1 to it, fetch the memory value at that address and then subtract 2 from it + +* Two offsets of type Reference (with Dereference) + String Reference: `"cast([ap + (-1)] + [ap], felt)"` + Struct Reference: {Offset1: {ValueType: Reference, Register: AP, Value: 1, Dereference: true}, Offset2: {ValueType: Reference, Register: AP, Value: 0, Dereference: true}, ValueType: "felt*"} + Reference in words: Take the current value of ap, subtract 1 to it, fetch the memory value at that address. Take the current value of ap, fetch the memory value at that address. Add the two values we obtained. + +Now all thats left to analyze is in the reference is the `ApTracking`: + +```go +type ApTrackingData struct { + Group int `json:"group"` + Offset int `json:"offset"` +} +``` + +As the value of AP is constantly changing with each instruction executed, its not that simple to track variables who's references are based on ap. ApTracking is used to calculate the difference between the value of ap at the moment the variable was created/ enterted the scope of the function (and hence, the hint) and the value of ap at the moment the hint is executed. Each hint and each reference has its own ApTracking. + +###### Computing addresses using References + +The function used to fetch the value from an ids variable using a reference works as follows: +1. Check if the refeference has type Immediate, if this is true, return the Immediate field +2. Calculate the address of the ids variable using the reference (we will see how this works soon) +3. Check the Dereference field of the reference, if false, return the address we obtained in 2, if true, fetch the memory value at that address and return it. + +```go +func getValueFromReference(reference *HintReference, apTracking parser.ApTrackingData, vm *VirtualMachine) (*MaybeRelocatable, bool) { + // Handle the case of immediate + if reference.Offset1.ValueType == Immediate { + return NewMaybeRelocatableFelt(reference.Offset1.Immediate), true + } + addr, ok := getAddressFromReference(reference, apTracking, vm) + if ok { + if reference.Dereference { + val, err := vm.Segments.Memory.Get(addr) + if err == nil { + return val, true + } + } else { + return NewMaybeRelocatableRelocatable(addr), true + } + } + return nil, false +} +``` + +In order to extract the value of an ids variable, we will first compute its address, this works as follows: +1. Check that the Offset1 is a Reference +2. Compute the value of Offset1 +3. Add the value of Offet2. By either calculating it in the case of a Reference type, or just using the Value field in the case of a Value type. +4. Return the result obtained in step 3. +```go +func getAddressFromReference(reference *HintReference, apTracking parser.ApTrackingData, vm *VirtualMachine) (Relocatable, bool) { + if reference.Offset1.ValueType != Reference { + return Relocatable{}, false + } + offset1 := getOffsetValueReference(reference.Offset1, reference.ApTrackingData, apTracking, vm) + if offset1 != nil { + offset1_rel, is_rel := offset1.GetRelocatable() + if is_rel { + switch reference.Offset2.ValueType { + case Reference: + offset2 := getOffsetValueReference(reference.Offset2, reference.ApTrackingData, apTracking, vm) + if offset2 != nil { + res, err := offset1_rel.AddMaybeRelocatable(*offset2) + if err == nil { + return res, true + } + } + case Value: + res, err := offset1_rel.AddInt(reference.Offset2.Value) + if err == nil { + return res, true + } + } + } + } + return Relocatable{}, false + +} +``` + +Now lets see how computing the value of an OffsetValue of type Reference works: +1. Determine a base address by checking the Register field of the OffsetValue. If the register is FP, use the current value of fp. If the register is AP, apply the necessary ap tracking corrections to ap and use it as base address. +2. Add the field Value of the OffsetValue to the base address +3. Check the Dereference field of the OffsetValue. If its false, return the address we obtained in 2. If is true, fetch the memory value at that address and return it + +```go +func getOffsetValueReference(offsetValue OffsetValue, refApTracking parser.ApTrackingData, hintApTracking parser.ApTrackingData, vm *VirtualMachine) *MaybeRelocatable { + var baseAddr Relocatable + ok := true + switch offsetValue.Register { + case FP: + baseAddr = vm.RunContext.Fp + case AP: + baseAddr, ok = applyApTrackingCorrection(vm.RunContext.Ap, refApTracking, hintApTracking) + } + if ok { + baseAddr, err := baseAddr.AddInt(offsetValue.Value) + if err == nil { + if offsetValue.Dereference { + // val will be nil if err is not nil, so we can ignore it + val, _ := vm.Segments.Memory.Get(baseAddr) + return val + } else { + return NewMaybeRelocatableRelocatable(baseAddr) + } + } + } + return nil +} +``` + +Finally, the last thing we need is to know how ap tracking corrections work. +This function will receive an address (the current value of ap), the ap tracking data of the reference (unique to each reference) and the hint's ap tracking data (unique to each hint) and perform the following steps: +1. Assert that both ap tracking datas belong to the same group (aka their Group fields match) +2. Subtract the difference between the hint's ap tracking data's Offset field and the reference's ap tracking data's Offset field from the address (ap) +3. Return the value obtained in 2 + +```go +func applyApTrackingCorrection(addr Relocatable, refApTracking parser.ApTrackingData, hintApTracking parser.ApTrackingData) (Relocatable, bool) { + // Reference & Hint ApTracking must belong to the same group + if refApTracking.Group == hintApTracking.Group { + addr, err := addr.SubUint(uint(hintApTracking.Offset - refApTracking.Offset)) + if err == nil { + return addr, true + } + } + return Relocatable{}, false +} +``` + +###### Implement the IdsManager + +Now that we have tackled reference management, we can implement the `IdsManager`, which will allow us to "forget" what references are when implementing hints. + +The IdsManager has the following structure: + +* References: A map of all the ids variables the hint has access to, it maps the name of the cairo varaible to a HintReference (the parsed version of the compiled program's Reference) +* HintAptracking: The ap tracking data unique to the hint + +```go +type IdsManager struct { + References map[string]HintReference + HintApTracking parser.ApTrackingData +} +``` + +And we can also implement friendlier versions of the functions we implemented in the previous section, that take the name of the ids variable, instead of the reference and hint ap tracking data: + +```go +// Returns the value of an identifier as a MaybeRelocatable +func (ids *IdsManager) Get(name string, vm *VirtualMachine) (*MaybeRelocatable, error) { + reference, ok := ids.References[name] + if ok { + val, ok := getValueFromReference(&reference, ids.HintApTracking, vm) + if ok { + return val, nil + } + } + return nil, ErrUnknownIdentifier(name) +} + +// Returns the address of an identifier given its name +func (ids *IdsManager) GetAddr(name string, vm *VirtualMachine) (Relocatable, error) { + reference, ok := ids.References[name] + if ok { + addr, ok := getAddressFromReference(&reference, ids.HintApTracking, vm) + if ok { + return addr, nil + } + } + return Relocatable{}, ErrUnknownIdentifier(name) +} +``` + +We can also make more specialized versions of the Get method, that will also handle conversions to Felt or Relocatable, as we will almost always know which type of value we are expecting when implementing hints: + +```go +// Returns the value of an identifier as a Felt +func (ids *IdsManager) GetFelt(name string, vm *VirtualMachine) (lambdaworks.Felt, error) { + val, err := ids.Get(name, vm) + if err != nil { + return lambdaworks.Felt{}, err + } + felt, is_felt := val.GetFelt() + if !is_felt { + return lambdaworks.Felt{}, ErrIdentifierNotFelt(name) + } + return felt, nil +} + +// Returns the value of an identifier as a Relocatable +func (ids *IdsManager) GetRelocatable(name string, vm *VirtualMachine) (Relocatable, error) { + val, err := ids.Get(name, vm) + if err != nil { + return Relocatable{}, err + } + relocatable, is_relocatable := val.GetRelocatable() + if !is_relocatable { + return Relocatable{}, errors.Errorf("Identifier %s is not a Relocatable", name) + } + return relocatable, nil +} +``` + +Lastly, we can also implement a method to insert a value into an ids variable (as we already know how to calculate their address) + +```go +// Inserts value into memory given its identifier name +func (ids *IdsManager) Insert(name string, value *MaybeRelocatable, vm *VirtualMachine) error { + + addr, err := ids.GetAddr(name, vm) + if err != nil { + return err + } + return vm.Segments.Memory.Insert(addr, value) +} +``` + +##### Implementing a HintProcessor: CompileHint + + +The `CompileHint` method will be in charge of converting the hint-related data from the compiled json into a format that our processor can use to execute each hint. For our `CairoVmHintProcessor` we will use the following struct: + +```go +type HintData struct { + Ids IdsManager + Code string +} +``` +Where IdsManager is the struct we just saw in the previous section, a struct which manages all kinds of interaction between the hint implemented in go and the cairo variables available to it, and Code is the python code of the hint. + +And we will implement a `CompileHint` method which receives the hint's data from the compiled program in the form of `HintParams`, and a reference to the compiled json's `ReferenceManager`, a list of references to all ids variables in the program. And performs the following steps: + +1. Create a map from variable name to HintReference +2. Iterate over the hintParams's `ReferenceIds` field (a map from an ids name to an index in the ReferenceManager). For each iteration: + + 1. Remove the path from the reference's name (shortening full paths such a "__main__.a" to just the variable name "a"), + 2. Fetch the reference from the ReferenceManager (using the index from the ReferenceIds) + 3. Parse the Reference into a `HintReference` + 4. Insert the parsed reference into the map we created in 1, using the shortened name (from 2.1) as a key + +3. Create an IdsManager using the map from 1, and the hintParam's ap tracking data +4. Create a `HintData` struct with the IdsManager and the hintParam's Code + +```go +func (p *CairoVmHintProcessor) CompileHint(hintParams *parser.HintParams, referenceManager *parser.ReferenceManager) (any, error) { + references := make(map[string]HintReference, 0) + for name, n := range hintParams.FlowTrackingData.ReferenceIds { + if int(n) >= len(referenceManager.References) { + return nil, errors.New("Reference not found in ReferenceManager") + } + split := strings.Split(name, ".") + name = split[len(split)-1] + references[name] = ParseHintReference(referenceManager.References[n]) + } + ids := NewIdsManager(references, hintParams.FlowTrackingData.APTracking) + return HintData{Ids: ids, Code: hintParams.Code}, nil +} +``` + +##### Hint Interaction: Constants + +###### How are Constants handled by hints and the cairo compiler + +Hints can also access constant variables using the ids syntax, for example, a hint can access the `MAX_SIZE` constant from a cairo program using `ids.MAX_SIZE`. While the behaviour from the hint's standpoint is identical to regular ids variables, they are handled differently by both the compiler and the vm. + +They are part of the compiled program's `Idenfifiers` field, and can be identified by the `const` type. We may also find aliases for them in the `Identifiers` section, aliases happen when a cairo file imports constants from another cairo file, in such cases we will have an identifier of type `const` under the file where the constant was declared's path, and an identifier of type `alias` under the file where the constant was imported's path, pointing to the original constant's identifier. For example: + +```json +"starkware.cairo.common.cairo_keccak.keccak.BLOCK_SIZE": { + "destination": "starkware.cairo.common.cairo_keccak.packed_keccak.BLOCK_SIZE", + "type": "alias" + }, +"starkware.cairo.common.cairo_keccak.packed_keccak.BLOCK_SIZE": { + "type": "const", + "value": 3 + }, +``` + +This is an extract from a compiled cairo program, where we can see that there is a constant `BLOCK_SIZE`, with value 3, declared in packed_keccak.cairo file, that was then imported by the keccak.cairo file. + +###### How does the vm extract the constants for hint execution + +As constants are not unique to any specific hint, they are not provided to the HintProcessor's `CompileHint` method, but are instead provided directly to the `ExecuteHint` method. Before providing these constants, we need to first extract them from the Identifiers field of the compiled program. This works as follows: + +1. Create a map to store the constants, maping full path constant names to their Felt value +2. Iterate over the program's `Identifiers` field, and check the type of each identifier. If the identifier is of type `const`, add its value to the map created in 1. If the identifier is of type `alias`, search for the identifier at its destination (we will see how to do this next), and if its of type `const`, add it to the map created in 1 under the alias' name. +3. Return the map created in 1 + +```go +func (p *Program) ExtractConstants() map[string]lambdaworks.Felt { + constants := make(map[string]lambdaworks.Felt) + for name, identifier := range p.Identifiers { + switch identifier.Type { + case "const": + constants[name] = identifier.Value + case "alias": + val, ok := searchConstFromAlias(identifier.Destination, &p.Identifiers) + if ok { + constants[name] = val + } + } + } + return constants +} +``` + +In order to search for the aliased identifier we need to do so recursively, as constants can be imported form file A into file B, then from file B into file C and so on. +To do so we use a recursive function which receives the destination field of an alias type identifier and a reference to the identifiers map. It will then look for the identifier using the received destination. If the new identifier is a constant, it wil return its value, if it is an alias it will call itself again with the new alias' destintation, and if its none, it will return false, indicating that the alias was not pointing to a constant. + +```go +func searchConstFromAlias(destination string, identifiers *map[string]Identifier) (lambdaworks.Felt, bool) { + identifier, ok := (*identifiers)[destination] + if ok { + switch identifier.Type { + case "const": + return identifier.Value, true + case "alias": + return searchConstFromAlias(identifier.Destination, identifiers) + } + } + return lambdaworks.Felt{}, false +} +``` + +###### How does the IdsManager handle constants + +Before looking into how the IdsManager handles constants, we'll have to add a new field to it: + +```go +type IdsManager struct { + References map[string]HintReference + HintApTracking parser.ApTrackingData + AccessibleScopes []string +} +``` +AccessibleScopes is a list of paths that a hint has access to, for example, if we were to write a hint in a function `foo` of a cairo program called `program`, that hint's accessible scopes will look something along the likes of `["program", "program.foo"]`. This list is taken directly from the `HintParams`' `AccessibleScopes` field in the compiled json. + +We can use this accessible scopes to determine the correct path for a cairo constant when implementing a hint. To do so, we will be searching for a constant in the map of constants provided by the vm, using the name of the constant in the hint and the possible paths in the accessible scopes, going from innermost (in the example, "program.foo"), to outermost (in the example, "program"). +We will be adding this behaviour to the `IdsManager`, by adding a function that will return the value of a constant given its name (without its full path) and the map of constants, following these steps: + +1. Iterate over the list of accessible scopes in reverse order +2. For each path in accessible scopes, append the name of the constant to get the full-path constant's name +3. Using the full-path constant names, try to fetch from the constants map +4. Once a match is found, return the value from the constant map + +```go +func (ids *IdsManager) GetConst(name string, constants *map[string]lambdaworks.Felt) (lambdaworks.Felt, error) { + // Hints should always have accessible scopes + if len(ids.AccessibleScopes) != 0 { + // Accessible scopes are listed from outer to inner + for i := len(ids.AccessibleScopes) - 1; i >= 0; i-- { + constant, ok := (*constants)[ids.AccessibleScopes[i]+"."+name] + if ok { + return constant, nil + } + } + } + return lambdaworks.FeltZero(), errors.Errorf("Missing constant %s", name) +} +``` + +##### Hint Interaction: ExecutionScopes + +Up until now we saw how hints can interact with the vm and the cairo variables, but what about the interaction between hints themselves? +To answer this question, we will introduce the concept of `Execution Scopes`, they consist of a stack of dictionaries that can hold any kind of variables. These scopes are accessible to all hints, allowing data to be shared between hints without the cairo program being aware of them. As it consists of a stack of dictionaries (from now on referred to as scopes), hints will only be able to interact with the last (or top level) scope. Hints can also remove and create new scopes, we will call these operations `ExitScope` and `EnterScope`. To better illustrate this behaviour, lets make a generic example: + +* HINT A: Adds variable n = 3 (Scopes = [{n: 3}]) +* HINT B: Fetches variable n and updates its value to 5 (Scopes = [{n: 5}]) +* HINT C: Uses the EnterScope operation (Scopes = [{n: 5}, {}]) +* HINT D: Adds variable n = 3 (Scopes = [{n: 5}, {n: 3}]) +* HINT E: Prints the value of n (3), then used the ExitScope operation (Scopes = [{n: 5}]) +* HINT F: Prints the value of n (5) + +Now that we know how execution scopes work, implementing them is quite simple: + +```go +type ExecutionScopes struct { + data []map[string]interface{} +} +``` +We have a stack (represented as a slice), of maps that connect a varaible's name, to its value, accepting any kind of variables as value + +We should also note that when creating an `ExecutionScopes`, it comes with one initial scope (called main scope), which can't be exited + +```go +func NewExecutionScopes() *ExecutionScopes { + data := make([]map[string]interface{}, 1) + data[0] = make(map[string]interface{}) + return &ExecutionScopes{data} +} +``` + +With this struct we can implement the basic operations: + +*EnterScope* + +Adds a new scope to the stack, which is received by the method +```go +func (es *ExecutionScopes) EnterScope(newScopeLocals map[string]interface{}) { + es.data = append(es.data, newScopeLocals) + +} +``` + +*ExitScope* + +Removes the last scope from the stack, guards that the main scope is not removed by the operation. + +```go +func (es *ExecutionScopes) ExitScope() error { + if len(es.data) < 2 { + return ErrCannotExitMainScop + } + es.data = es.data[len(es.data) - 1] + + return nil +} +``` + +*AssignOrUpdateVariable* + +Inserts a variable to the current scope (aka the top one in the stack), overwitting the previous value if it exists + +```go +func (es *ExecutionScopes) AssignOrUpdateVariable(varName string, varValue interface{}) { + locals, err := es.getLocalVariablesMut() + if err != nil { + return + } + (*locals)[varName] = varValue +} +``` + +*Get* + +Fetches a variable from the current scope +```go +func (es *ExecutionScopes) Get(varName string) (interface{}, error) { + locals, err := es.GetLocalVariables() + if err != nil { + return nil, err + } + val, prs := locals[varName] + if !prs { + return nil, ErrVariableNotInScope(varName) + } + return val, nil +} +``` + +*DeleteVariable* + +Removes a variable from the current scope + +```go +func (es *ExecutionScopes) DeleteVariable(varName string) { + locals, err := es.getLocalVariablesMut() + if err != nil { + return + } + delete(*locals, varName) +} +``` + +And the helper methods for these methods: + +```go +func (es *ExecutionScopes) getLocalVariablesMut() (*map[string]interface{}, error) { + locals, err := es.GetLocalVariables() + if err != nil { + return nil, err + } + return &locals, nil +} + +func (es *ExecutionScopes) GetLocalVariables() (map[string]interface{}, error) { + if len(es.data) > 0 { + return es.data[len(es.data)-1], nil + } + return nil, ExecutionScopesError(errors.Errorf("Every enter_scope() requires a corresponding exit_scope().")) +} +``` + +##### Hint Implementation Examples + +Now that we have all the necessary tools to begin implementing hints, lets look at some examples: + +###### IS_LE_FELT + +The python code we have to implement is the following: + +"memory[ap] = 0 if (ids.a % PRIME) <= (ids.b % PRIME) else 1" + +The first thing we notice is that its uses the ids variables "a" and "b" so this gives as an opportunity to use our `IdsManager`. We can also look at the context of this hint, in this case the common library function is_le_felt (in the math_cmp module) to see that ids.a and ids.b are both felt values. + +We can divide the hint into the following steps: + +1. Fetch ids.a as a Felt +2. Fetch ids.b as a Felt +3. Compare the values of a and b (we don't need to perform % PRIME, as our Felt type already takes care of it) +4. Insert either 0 or 1 at the current value of ap depending on the comparison in 3 + +And implement the hint: + +```go +func isLeFelt(ids IdsManager, vm *VirtualMachine) error { + a, err := ids.GetFelt("a", vm) + if err != nil { + return err + } + b, err := ids.GetFelt("b", vm) + if err != nil { + return err + } + if a.Cmp(b) != 1 { + return vm.Segments.Memory.Insert(vm.RunContext.Ap, NewMaybeRelocatableFelt(FeltZero())) + } + return vm.Segments.Memory.Insert(vm.RunContext.Ap, NewMaybeRelocatableFelt(FeltOne())) +} +``` + +###### ASSERT_LE_FELT_EXCLUDED_0 + +The python code we have to implement is the following: + +"memory[ap] = 1 if excluded != 0 else 0" + +This hint is quite similar to the previous example, except that instead of comparing ids variables it uses this "excluded" variable. As this variable is neither an ids, nor is it created during the hint, we can tell that it is a variable created by a previous hint, shared through the current execution scope. With this knowledge, we can divide the hint into the following set of steps: + +1. Fetch excluded from the execution scopes +2. Cast the excluded variable to a concrete type. In this case, as we have previously implemented the hint that creates this variable, we know its type is 'int' +3. Compare the values of excluded vs 0 +4. Insert either 0 or 1 at the current value of ap depending on the comparison in 3 + +```go +func assertLeFeltExcluded0(vm *VirtualMachine, scopes *ExecutionScopes) error { + // Fetch scope var + excludedAny, err := scopes.Get("excluded") + if err != nil { + return err + } + excluded, ok := excludedAny.(int) + if !ok { + return errors.New("excluded not in scope") + } + if excluded == 0 { + return vm.Segments.Memory.Insert(vm.RunContext.Ap, NewMaybeRelocatableFelt(FeltZero())) + } + return vm.Segments.Memory.Insert(vm.RunContext.Ap, NewMaybeRelocatableFelt(FeltOne())) +} +``` + +#### Proof mode + +TODO + +#### Temporary Segments + +TODO diff --git a/cairo_programs/cairo_keccak.cairo b/cairo_programs/cairo_keccak.cairo index 8adcd515..b5575e7c 100644 --- a/cairo_programs/cairo_keccak.cairo +++ b/cairo_programs/cairo_keccak.cairo @@ -27,3 +27,4 @@ func main{range_check_ptr: felt, bitwise_ptr: BitwiseBuiltin*}() { return (); } + diff --git a/cairo_programs/ec_double_assign.cairo b/cairo_programs/ec_double_assign.cairo new file mode 100644 index 00000000..16419e72 --- /dev/null +++ b/cairo_programs/ec_double_assign.cairo @@ -0,0 +1,32 @@ +%builtins range_check +from starkware.cairo.common.cairo_secp.bigint import BigInt3, nondet_bigint3 +struct EcPoint { + x: BigInt3, + y: BigInt3, +} + +func ec_double{range_check_ptr}(point: EcPoint, slope: BigInt3) -> (res: BigInt3) { + %{ + from starkware.cairo.common.cairo_secp.secp_utils import pack + SECP_P = 2**255-19 + + slope = pack(ids.slope, PRIME) + x = pack(ids.point.x, PRIME) + y = pack(ids.point.y, PRIME) + + value = new_x = (pow(slope, 2, SECP_P) - 2 * x) % SECP_P + %} + + let (new_x: BigInt3) = nondet_bigint3(); + return (res=new_x); +} + +func main{range_check_ptr}() { + let p = EcPoint(BigInt3(1,2,3), BigInt3(4,5,6)); + let s = BigInt3(7,8,9); + let (res) = ec_double(p, s); + assert res.d0 = 21935; + assert res.d1 = 12420; + assert res.d2 = 184; + return (); +} diff --git a/cairo_programs/proof_programs/usort.cairo b/cairo_programs/proof_programs/usort.cairo new file mode 120000 index 00000000..4a7e19e6 --- /dev/null +++ b/cairo_programs/proof_programs/usort.cairo @@ -0,0 +1 @@ +../usort.cairo \ No newline at end of file diff --git a/cairo_programs/usort.cairo b/cairo_programs/usort.cairo new file mode 100644 index 00000000..e5859b29 --- /dev/null +++ b/cairo_programs/usort.cairo @@ -0,0 +1,22 @@ +%builtins range_check +from starkware.cairo.common.usort import usort +from starkware.cairo.common.alloc import alloc + +func main{range_check_ptr}() -> () { + alloc_locals; + let (input_array: felt*) = alloc(); + assert input_array[0] = 2; + assert input_array[1] = 1; + assert input_array[2] = 0; + + let (output_len, output, multiplicities) = usort(input_len=3, input=input_array); + + assert output_len = 3; + assert output[0] = 0; + assert output[1] = 1; + assert output[2] = 2; + assert multiplicities[0] = 1; + assert multiplicities[1] = 1; + assert multiplicities[2] = 1; + return (); +} diff --git a/pkg/hints/ec_hint.go b/pkg/hints/ec_hint.go index 005523e0..04f0587c 100644 --- a/pkg/hints/ec_hint.go +++ b/pkg/hints/ec_hint.go @@ -7,6 +7,7 @@ import ( "github.com/lambdaclass/cairo-vm.go/pkg/builtins" "github.com/lambdaclass/cairo-vm.go/pkg/hints/hint_utils" . "github.com/lambdaclass/cairo-vm.go/pkg/hints/hint_utils" + . "github.com/lambdaclass/cairo-vm.go/pkg/lambdaworks" "github.com/lambdaclass/cairo-vm.go/pkg/types" . "github.com/lambdaclass/cairo-vm.go/pkg/types" "github.com/lambdaclass/cairo-vm.go/pkg/vm" @@ -184,6 +185,48 @@ func computeSlope(vm *VirtualMachine, execScopes ExecutionScopes, idsData IdsMan return nil } +// Implements hint: +// from starkware.cairo.common.cairo_secp.secp_utils import SECP_P, pack +// +// slope = pack(ids.slope, PRIME) +// x = pack(ids.point.x, PRIME) +// y = pack(ids.point.y, PRIME) +// +// value = new_x = (pow(slope, 2, SECP_P) - 2 * x) % SECP_P +func ecDoubleAssignNewX(vm *VirtualMachine, execScopes ExecutionScopes, ids IdsManager, secpP big.Int) error { + execScopes.AssignOrUpdateVariable("SECP_P", secpP) + + slope3, err := BigInt3FromVarName("slope", ids, vm) + if err != nil { + return err + } + packedSlope := slope3.Pack86() + slope := new(big.Int).Mod(&packedSlope, Prime()) + point, err := EcPointFromVarName("point", vm, ids) + if err != nil { + return err + } + + xPacked := point.X.Pack86() + x := new(big.Int).Mod(&xPacked, Prime()) + yPacked := point.Y.Pack86() + y := new(big.Int).Mod(&yPacked, Prime()) + + value := new(big.Int).Mul(slope, slope) + value = value.Mod(value, &secpP) + + value = value.Sub(value, x) + value = value.Sub(value, x) + value = value.Mod(value, &secpP) + + execScopes.AssignOrUpdateVariable("slope", slope) + execScopes.AssignOrUpdateVariable("x", x) + execScopes.AssignOrUpdateVariable("y", y) + execScopes.AssignOrUpdateVariable("value", *value) + execScopes.AssignOrUpdateVariable("new_x", *value) + return nil +} + /* Implements hint: %{ from starkware.cairo.common.cairo_secp.secp256r1_utils import SECP256R1_ALPHA as ALPHA %} diff --git a/pkg/hints/ec_hint_test.go b/pkg/hints/ec_hint_test.go index 3cf08a82..974576b4 100644 --- a/pkg/hints/ec_hint_test.go +++ b/pkg/hints/ec_hint_test.go @@ -235,6 +235,75 @@ func TestRunComputeSlopeOk(t *testing.T) { } } +func TestEcDoubleAssignNewXOk(t *testing.T) { + vm := NewVirtualMachine() + vm.Segments.AddSegment() + idsManager := SetupIdsForTest( + map[string][]*MaybeRelocatable{ + "slope": { + NewMaybeRelocatableFelt(FeltFromUint64(3)), + NewMaybeRelocatableFelt(FeltFromUint64(0)), + NewMaybeRelocatableFelt(FeltFromUint64(0)), + }, + "point": { + // X + NewMaybeRelocatableFelt(FeltFromUint64(2)), + NewMaybeRelocatableFelt(FeltFromUint64(0)), + NewMaybeRelocatableFelt(FeltFromUint64(0)), + // Y + NewMaybeRelocatableFelt(FeltFromUint64(4)), + NewMaybeRelocatableFelt(FeltFromUint64(0)), + NewMaybeRelocatableFelt(FeltFromUint64(0)), + }, + }, + vm, + ) + + hintProcessor := CairoVmHintProcessor{} + hintData := any(HintData{ + Ids: idsManager, + Code: EC_DOUBLE_ASSIGN_NEW_X_V1, + }) + + execScopes := types.NewExecutionScopes() + err := hintProcessor.ExecuteHint(vm, &hintData, nil, execScopes) + + if err != nil { + t.Errorf("EC_DOUBLE_ASSIGN_NEW_X hint failed with error: %s", err) + } + + slopeUncast, _ := execScopes.Get("slope") + slope := slopeUncast.(*big.Int) + xUncast, _ := execScopes.Get("x") + x := xUncast.(*big.Int) + yUncast, _ := execScopes.Get("y") + y := yUncast.(*big.Int) + valueUncast, _ := execScopes.Get("value") + value := valueUncast.(big.Int) + new_xUncast, _ := execScopes.Get("new_x") + new_x := new_xUncast.(big.Int) + + if value.Cmp(&new_x) != 0 { + t.Errorf("EC_DOUBLE_ASSIGN_NEW_X hint failed: new_x != value. %v != %v", new_x, value) + } + expectedRes := big.NewInt(5) + if value.Cmp(expectedRes) != 0 { + t.Errorf("EC_DOUBLE_ASSIGN_NEW_X hint failed: expected value (%v) to be 6", value) + } + expectedSlope := big.NewInt(3) + if slope.Cmp(expectedSlope) != 0 { + t.Errorf("EC_DOUBLE_ASSIGN_NEW_X hint failed: expected slope (%v) to be 3", slope) + } + expectedX := big.NewInt(2) + if x.Cmp(expectedX) != 0 { + t.Errorf("EC_DOUBLE_ASSIGN_NEW_X hint failed: expected x (%v) to be 2", x) + } + expectedY := big.NewInt(4) + if y.Cmp(expectedY) != 0 { + t.Errorf("EC_DOUBLE_ASSIGN_NEW_X hint failed: expected y (%v) to be 4", y) + } +} + func TestRunComputeSlopeV2Ok(t *testing.T) { vm := NewVirtualMachine() diff --git a/pkg/hints/hint_codes/ec_op_hints.go b/pkg/hints/hint_codes/ec_op_hints.go index 471246db..e8ab665c 100644 --- a/pkg/hints/hint_codes/ec_op_hints.go +++ b/pkg/hints/hint_codes/ec_op_hints.go @@ -4,10 +4,39 @@ const EC_NEGATE = "from starkware.cairo.common.cairo_secp.secp_utils import SECP const EC_NEGATE_EMBEDDED_SECP = "from starkware.cairo.common.cairo_secp.secp_utils import pack\nSECP_P = 2**255-19\n\ny = pack(ids.point.y, PRIME) % SECP_P\n# The modulo operation in python always returns a nonnegative number.\nvalue = (-y) % SECP_P" const EC_DOUBLE_SLOPE_V1 = "from starkware.cairo.common.cairo_secp.secp_utils import SECP_P, pack\nfrom starkware.python.math_utils import ec_double_slope\n\n# Compute the slope.\nx = pack(ids.point.x, PRIME)\ny = pack(ids.point.y, PRIME)\nvalue = slope = ec_double_slope(point=(x, y), alpha=0, p=SECP_P)" const COMPUTE_SLOPE_V1 = "from starkware.cairo.common.cairo_secp.secp_utils import SECP_P, pack\nfrom starkware.python.math_utils import line_slope\n\n# Compute the slope.\nx0 = pack(ids.point0.x, PRIME)\ny0 = pack(ids.point0.y, PRIME)\nx1 = pack(ids.point1.x, PRIME)\ny1 = pack(ids.point1.y, PRIME)\nvalue = slope = line_slope(point1=(x0, y0), point2=(x1, y1), p=SECP_P)" -const EC_DOUBLE_SLOPE_EXTERNAL_CONSTS = "from starkware.cairo.common.cairo_secp.secp_utils import pack\nfrom starkware.python.math_utils import ec_double_slope\n\n# Compute the slope.\nx = pack(ids.point.x, PRIME)\ny = pack(ids.point.y, PRIME)\nvalue = slope = ec_double_slope(point=(x, y), alpha=ALPHA, p=SECP_P)" -const NONDET_BIGINT3_V1 = "from starkware.cairo.common.cairo_secp.secp_utils import split\n\nsegments.write_arg(ids.res.address_, split(value))" +const EC_DOUBLE_ASSIGN_NEW_X_V1 = `from starkware.cairo.common.cairo_secp.secp_utils import SECP_P, pack + +slope = pack(ids.slope, PRIME) +x = pack(ids.point.x, PRIME) +y = pack(ids.point.y, PRIME) + +value = new_x = (pow(slope, 2, SECP_P) - 2 * x) % SECP_P` +const EC_DOUBLE_ASSIGN_NEW_X_V2 = `from starkware.cairo.common.cairo_secp.secp_utils import pack + +slope = pack(ids.slope, PRIME) +x = pack(ids.point.x, PRIME) +y = pack(ids.point.y, PRIME) + +value = new_x = (pow(slope, 2, SECP_P) - 2 * x) % SECP_P` +const EC_DOUBLE_ASSIGN_NEW_X_V3 = `from starkware.cairo.common.cairo_secp.secp_utils import pack +SECP_P = 2**255-19 + +slope = pack(ids.slope, PRIME) +x = pack(ids.point.x, PRIME) +y = pack(ids.point.y, PRIME) + +value = new_x = (pow(slope, 2, SECP_P) - 2 * x) % SECP_P` +const EC_DOUBLE_ASSIGN_NEW_X_V4 = `from starkware.cairo.common.cairo_secp.secp_utils import SECP_P, pack + +slope = pack(ids.slope, PRIME) +x = pack(ids.pt.x, PRIME) +y = pack(ids.pt.y, PRIME) + +value = new_x = (pow(slope, 2, SECP_P) - 2 * x) % SECP_P` const COMPUTE_SLOPE_V2 = "from starkware.python.math_utils import line_slope\nfrom starkware.cairo.common.cairo_secp.secp_utils import pack\nSECP_P = 2**255-19\n# Compute the slope.\nx0 = pack(ids.point0.x, PRIME)\ny0 = pack(ids.point0.y, PRIME)\nx1 = pack(ids.point1.x, PRIME)\ny1 = pack(ids.point1.y, PRIME)\nvalue = slope = line_slope(point1=(x0, y0), point2=(x1, y1), p=SECP_P)" const COMPUTE_SLOPE_WHITELIST = "from starkware.cairo.common.cairo_secp.secp_utils import SECP_P, pack\nfrom starkware.python.math_utils import div_mod\n\n# Compute the slope.\nx0 = pack(ids.pt0.x, PRIME)\ny0 = pack(ids.pt0.y, PRIME)\nx1 = pack(ids.pt1.x, PRIME)\ny1 = pack(ids.pt1.y, PRIME)\nvalue = slope = div_mod(y0 - y1, x0 - x1, SECP_P)" +const EC_DOUBLE_SLOPE_EXTERNAL_CONSTS = "from starkware.cairo.common.cairo_secp.secp_utils import pack\nfrom starkware.python.math_utils import ec_double_slope\n\n# Compute the slope.\nx = pack(ids.point.x, PRIME)\ny = pack(ids.point.y, PRIME)\nvalue = slope = ec_double_slope(point=(x, y), alpha=ALPHA, p=SECP_P)" +const NONDET_BIGINT3_V1 = "from starkware.cairo.common.cairo_secp.secp_utils import split\n\nsegments.write_arg(ids.res.address_, split(value))" const COMPUTE_SLOPE_SECP256R1 = "from starkware.cairo.common.cairo_secp.secp_utils import pack\nfrom starkware.python.math_utils import line_slope\n\n# Compute the slope.\nx0 = pack(ids.point0.x, PRIME)\ny0 = pack(ids.point0.y, PRIME)\nx1 = pack(ids.point1.x, PRIME)\ny1 = pack(ids.point1.y, PRIME)\nvalue = slope = line_slope(point1=(x0, y0), point2=(x1, y1), p=SECP_P)" const FAST_EC_ADD_ASSIGN_NEW_X = `"from starkware.cairo.common.cairo_secp.secp_utils import SECP_P, pack diff --git a/pkg/hints/hint_codes/usort_hint_codes.go b/pkg/hints/hint_codes/usort_hint_codes.go new file mode 100644 index 00000000..7dd3e944 --- /dev/null +++ b/pkg/hints/hint_codes/usort_hint_codes.go @@ -0,0 +1,32 @@ +package hint_codes + +const USORT_ENTER_SCOPE = "vm_enter_scope(dict(__usort_max_size = globals().get('__usort_max_size')))" + +const USORT_BODY = `from collections import defaultdict + +input_ptr = ids.input +input_len = int(ids.input_len) +if __usort_max_size is not None: + assert input_len <= __usort_max_size, ( + f"usort() can only be used with input_len<={__usort_max_size}. " + f"Got: input_len={input_len}." + ) + +positions_dict = defaultdict(list) +for i in range(input_len): + val = memory[input_ptr + i] + positions_dict[val].append(i) + +output = sorted(positions_dict.keys()) +ids.output_len = len(output) +ids.output = segments.gen_arg(output) +ids.multiplicities = segments.gen_arg([len(positions_dict[k]) for k in output])` + +const USORT_VERIFY = `last_pos = 0 +positions = positions_dict[ids.value][::-1]` + +const USORT_VERIFY_MULTIPLICITY_ASSERT = "assert len(positions) == 0" + +const USORT_VERIFY_MULTIPLICITY_BODY = `current_pos = positions.pop() +ids.next_item_index = current_pos - last_pos +last_pos = current_pos + 1` diff --git a/pkg/hints/hint_processor.go b/pkg/hints/hint_processor.go index 6ed7562f..ccac7bd2 100644 --- a/pkg/hints/hint_processor.go +++ b/pkg/hints/hint_processor.go @@ -92,6 +92,8 @@ func (p *CairoVmHintProcessor) ExecuteHint(vm *vm.VirtualMachine, hintData *any, return ecNegateImportSecpP(vm, *execScopes, data.Ids) case EC_NEGATE_EMBEDDED_SECP: return ecNegateEmbeddedSecpP(vm, *execScopes, data.Ids) + case EC_DOUBLE_ASSIGN_NEW_X_V1, EC_DOUBLE_ASSIGN_NEW_X_V2, EC_DOUBLE_ASSIGN_NEW_X_V3, EC_DOUBLE_ASSIGN_NEW_X_V4: + return ecDoubleAssignNewX(vm, *execScopes, data.Ids, SECP_P_V2()) case POW: return pow(data.Ids, vm) case SQRT: @@ -106,6 +108,16 @@ func (p *CairoVmHintProcessor) ExecuteHint(vm *vm.VirtualMachine, hintData *any, return memset_step_loop(data.Ids, vm, execScopes, "continue_loop") case VM_ENTER_SCOPE: return vm_enter_scope(execScopes) + case USORT_ENTER_SCOPE: + return usortEnterScope(execScopes) + case USORT_BODY: + return usortBody(data.Ids, execScopes, vm) + case USORT_VERIFY: + return usortVerify(data.Ids, execScopes, vm) + case USORT_VERIFY_MULTIPLICITY_ASSERT: + return usortVerifyMultiplicityAssert(execScopes) + case USORT_VERIFY_MULTIPLICITY_BODY: + return usortVerifyMultiplicityBody(data.Ids, execScopes, vm) case SET_ADD: return setAdd(data.Ids, vm) case FIND_ELEMENT: diff --git a/pkg/hints/usort_hints.go b/pkg/hints/usort_hints.go new file mode 100644 index 00000000..1dc0f364 --- /dev/null +++ b/pkg/hints/usort_hints.go @@ -0,0 +1,266 @@ +package hints + +import ( + "fmt" + "sort" + + "github.com/lambdaclass/cairo-vm.go/pkg/lambdaworks" + "github.com/lambdaclass/cairo-vm.go/pkg/types" + . "github.com/lambdaclass/cairo-vm.go/pkg/vm" + "github.com/lambdaclass/cairo-vm.go/pkg/vm/memory" + + . "github.com/lambdaclass/cairo-vm.go/pkg/hints/hint_utils" + "github.com/pkg/errors" +) + +// SortFelt implements sort.Interface for []lambdaworks.Felt +type SortFelt []lambdaworks.Felt + +func (s SortFelt) Len() int { return len(s) } +func (s SortFelt) Swap(i, j int) { s[i], s[j] = s[j], s[i] } +func (s SortFelt) Less(i, j int) bool { + a, b := s[i], s[j] + + return a.Cmp(b) == -1 +} + +// Implements hint: +// %{ vm_enter_scope(dict(__usort_max_size = globals().get('__usort_max_size'))) %} +func usortEnterScope(executionScopes *types.ExecutionScopes) error { + usort_max_size_interface, err := executionScopes.Get("usort_max_size") + + if err != nil { + executionScopes.EnterScope(make(map[string]interface{})) + return nil + } + + usort_max_size, cast_ok := usort_max_size_interface.(uint64) + + if !cast_ok { + return errors.New("Error casting usort_max_size into a uint64") + } + + scope := make(map[string]interface{}) + scope["usort_max_size"] = usort_max_size + executionScopes.EnterScope(scope) + + return nil +} + +func usortBody(ids IdsManager, executionScopes *types.ExecutionScopes, vm *VirtualMachine) error { + + input_ptr, err := ids.GetRelocatable("input", vm) + if err != nil { + return err + } + + input_len, err := ids.GetFelt("input_len", vm) + + if err != nil { + return err + } + input_len_u64, err := input_len.ToU64() + + if err != nil { + return err + } + + usort_max_size, err := executionScopes.Get("usort_max_size") + + if err == nil { + usort_max_size_u64, cast_ok := usort_max_size.(uint64) + + if !cast_ok { + return errors.New("Error casting usort_max_size into a uint64") + } + + if input_len_u64 > usort_max_size_u64 { + return errors.New(fmt.Sprintf("usort() can only be used with input_len<= %v. Got: input_len=%v.", usort_max_size_u64, input_len_u64)) + } + } + + positions_dict := make(map[lambdaworks.Felt][]uint64) + + for i := uint64(0); i < input_len_u64; i++ { + + val, err := vm.Segments.Memory.GetFelt(input_ptr.AddUint(uint(i))) + + if err != nil { + return err + } + + positions_dict[val] = append(positions_dict[val], i) + } + executionScopes.AssignOrUpdateVariable("positions_dict", positions_dict) + + output := make([]lambdaworks.Felt, 0, len(positions_dict)) + + for key := range positions_dict { + output = append(output, key) + } + + sort.Sort(SortFelt(output)) + + output_len := memory.NewMaybeRelocatableFelt(lambdaworks.FeltFromUint64(uint64((len(output))))) + err = ids.Insert("output_len", output_len, vm) + + if err != nil { + return err + } + + output_base := vm.Segments.AddSegment() + + for i := range output { + err = vm.Segments.Memory.Insert(output_base.AddUint(uint(i)), memory.NewMaybeRelocatableFelt(output[i])) + + if err != nil { + return err + } + } + + multiplicities_base := vm.Segments.AddSegment() + + multiplicities := make([]uint64, 0, len(output)) + + for key := range output { + multiplicities = append(multiplicities, uint64(len(positions_dict[output[key]]))) + } + + for i := range multiplicities { + err = vm.Segments.Memory.Insert(multiplicities_base.AddUint(uint(i)), memory.NewMaybeRelocatableFelt(lambdaworks.FeltFromUint64(multiplicities[i]))) + + if err != nil { + return err + } + } + + err = ids.Insert("output", memory.NewMaybeRelocatableRelocatable(output_base), vm) + + if err != nil { + return err + } + + err = ids.Insert("multiplicities", memory.NewMaybeRelocatableRelocatable(multiplicities_base), vm) + + if err != nil { + return err + } + + return nil +} + +// Implements hint: +// +// %{ +// last_pos = 0 +// positions = positions_dict[ids.value][::-1] +// %} +func usortVerify(ids IdsManager, executionScopes *types.ExecutionScopes, vm *VirtualMachine) error { + + executionScopes.AssignOrUpdateVariable("last_pos", uint64(0)) + + positions_dict_interface, err := executionScopes.Get("positions_dict") + + if err != nil { + return err + } + + positions_dict, cast_ok := positions_dict_interface.(map[lambdaworks.Felt][]uint64) + + if !cast_ok { + return errors.New("Error casting positions_dict") + } + + value, err := ids.GetFelt("value", vm) + if err != nil { + return err + } + + if err != nil { + return err + } + + positions := positions_dict[value] + + for i, j := 0, len(positions)-1; i < j; i, j = i+1, j-1 { + positions[i], positions[j] = positions[j], positions[i] + } + + executionScopes.AssignOrUpdateVariable("positions", positions) + + return nil +} + +// Implements hint: +// %{ assert len(positions) == 0 %} +func usortVerifyMultiplicityAssert(executionScopes *types.ExecutionScopes) error { + + positions_interface, err := executionScopes.Get("positions") + + if err != nil { + return err + } + + positions, cast_ok := positions_interface.([]uint64) + + if !cast_ok { + return errors.New("Error casting positions to []uint64") + } + + if len(positions) != 0 { + return errors.New("Assertion failed: len(positions) == 0") + } + + return nil + +} + +// Implements hint: +// +// %{ +// current_pos = positions.pop() +// ids.next_item_index = current_pos - last_pos +// last_pos = current_pos + 1 +// %} +func usortVerifyMultiplicityBody(ids IdsManager, executionScopes *types.ExecutionScopes, vm *VirtualMachine) error { + + positions_interface, err := executionScopes.Get("positions") + + if err != nil { + return err + } + + positions, cast_ok := positions_interface.([]uint64) + + if !cast_ok { + return errors.New("Error casting positions to []uint64") + } + + last_pos_interface, err := executionScopes.Get("last_pos") + + if err != nil { + return err + } + + last_pos, cast_ok := last_pos_interface.(uint64) + + if !cast_ok { + return errors.New("Error casting last_pos to uint64") + } + + current_pos := positions[len(positions)-1] + + executionScopes.AssignOrUpdateVariable("positions", positions[:len(positions)-1]) + + next_item_index := current_pos - last_pos + + err = ids.Insert("next_item_index", memory.NewMaybeRelocatableFelt(lambdaworks.FeltFromUint64(next_item_index)), vm) + + if err != nil { + return err + } + + executionScopes.AssignOrUpdateVariable("last_pos", current_pos+1) + + return nil +} diff --git a/pkg/hints/usort_hints_test.go b/pkg/hints/usort_hints_test.go new file mode 100644 index 00000000..c390869e --- /dev/null +++ b/pkg/hints/usort_hints_test.go @@ -0,0 +1,240 @@ +package hints_test + +import ( + "reflect" + "sort" + "testing" + + . "github.com/lambdaclass/cairo-vm.go/pkg/hints" + . "github.com/lambdaclass/cairo-vm.go/pkg/hints/hint_codes" + . "github.com/lambdaclass/cairo-vm.go/pkg/hints/hint_utils" + "github.com/lambdaclass/cairo-vm.go/pkg/lambdaworks" + "github.com/lambdaclass/cairo-vm.go/pkg/types" + . "github.com/lambdaclass/cairo-vm.go/pkg/vm" + . "github.com/lambdaclass/cairo-vm.go/pkg/vm/memory" +) + +func TestSortFeltArray(t *testing.T) { + array := []lambdaworks.Felt{lambdaworks.FeltFromUint(6), lambdaworks.FeltFromUint(0), lambdaworks.FeltFromUint(100), lambdaworks.FeltFromUint(1), lambdaworks.FeltFromUint(50)} + + sort.Sort(SortFelt(array)) + + sortedarray := []lambdaworks.Felt{lambdaworks.FeltFromUint(0), lambdaworks.FeltFromUint(1), lambdaworks.FeltFromUint(6), lambdaworks.FeltFromUint(50), lambdaworks.FeltFromUint(100)} + + if !reflect.DeepEqual(array, sortedarray) { + t.Errorf("Error sorting felt array") + } + +} + +func TestUsortWithMaxSize(t *testing.T) { + vm := NewVirtualMachine() + scopes := types.NewExecutionScopes() + scopes.AssignOrUpdateVariable("usort_max_size", uint64(1)) + idsManager := SetupIdsForTest( + map[string][]*MaybeRelocatable{}, + vm, + ) + hintProcessor := CairoVmHintProcessor{} + hintData := any(HintData{ + Ids: idsManager, + Code: USORT_ENTER_SCOPE, + }) + err := hintProcessor.ExecuteHint(vm, &hintData, nil, scopes) + if err != nil { + t.Errorf("USORT_ENTER_SCOPE hint execution failed") + } + + usort_max_size_interface, err := scopes.Get("usort_max_size") + + if err != nil { + t.Errorf("Error assigning usort_max_size") + } + + usort_max_size := usort_max_size_interface.(uint64) + + if usort_max_size != uint64(1) { + t.Errorf("Error assigning usort_max_size") + } + +} +func TestUsortOutOfRange(t *testing.T) { + vm := NewVirtualMachine() + vm.Segments.AddSegment() + vm.Segments.AddSegment() + vm.Segments.AddSegment() + scopes := types.NewExecutionScopes() + scopes.AssignOrUpdateVariable("usort_max_size", uint64(1)) + idsManager := SetupIdsForTest( + map[string][]*MaybeRelocatable{ + "input": {NewMaybeRelocatableRelocatable(NewRelocatable(2, 1))}, + "input_len": {NewMaybeRelocatableFelt(lambdaworks.FeltFromUint64(5))}, + }, + vm, + ) + hintProcessor := CairoVmHintProcessor{} + hintData := any(HintData{ + Ids: idsManager, + Code: USORT_BODY, + }) + err := hintProcessor.ExecuteHint(vm, &hintData, nil, scopes) + if err == nil { + t.Errorf("USORT_BODY hint should have failed") + } + +} + +func TestUsortVerify(t *testing.T) { + vm := NewVirtualMachine() + vm.Segments.AddSegment() + vm.Segments.AddSegment() + vm.Segments.AddSegment() + scopes := types.NewExecutionScopes() + positions_dict := make(map[lambdaworks.Felt][]uint64) + positions_dict[lambdaworks.FeltFromUint64(0)] = []uint64{2} + positions_dict[lambdaworks.FeltFromUint64(1)] = []uint64{1} + positions_dict[lambdaworks.FeltFromUint64(2)] = []uint64{0} + + scopes.AssignOrUpdateVariable("positions_dict", positions_dict) + + idsManager := SetupIdsForTest( + map[string][]*MaybeRelocatable{ + "value": {NewMaybeRelocatableFelt(lambdaworks.FeltFromUint64(0))}, + }, + vm, + ) + hintProcessor := CairoVmHintProcessor{} + hintData := any(HintData{ + Ids: idsManager, + Code: USORT_VERIFY, + }) + err := hintProcessor.ExecuteHint(vm, &hintData, nil, scopes) + if err != nil { + t.Errorf("USORT_VERIFY failed") + } + + positions_interface, err := scopes.Get("positions") + + if err != nil { + t.Errorf("Error assigning positions_interface") + } + + positions := positions_interface.([]uint64) + + if !reflect.DeepEqual(positions, []uint64{2}) { + t.Errorf("Error assigning positions") + } + + last_pos_interface, err := scopes.Get("last_pos") + + if err != nil { + t.Errorf("Error assigning last_pos") + } + + last_pos := last_pos_interface.(uint64) + + if last_pos != uint64(0) { + t.Errorf("Error assigning last_pos") + } + +} + +func TestUsortVerifyMultiplicityAssert(t *testing.T) { + vm := NewVirtualMachine() + scopes := types.NewExecutionScopes() + + idsManager := SetupIdsForTest( + map[string][]*MaybeRelocatable{}, + vm, + ) + hintProcessor := CairoVmHintProcessor{} + hintData := any(HintData{ + Ids: idsManager, + Code: USORT_VERIFY_MULTIPLICITY_ASSERT, + }) + err := hintProcessor.ExecuteHint(vm, &hintData, nil, scopes) + if err == nil { + t.Errorf("USORT_VERIFY_MULTIPLICITY_ASSERT should have failed") + } + + positions := []uint64{0} + + scopes.AssignOrUpdateVariable("positions", positions) + + err = hintProcessor.ExecuteHint(vm, &hintData, nil, scopes) + if err == nil { + t.Errorf("USORT_VERIFY_MULTIPLICITY_ASSERT should have failed") + } + + positions = []uint64{} + + scopes.AssignOrUpdateVariable("positions", positions) + + err = hintProcessor.ExecuteHint(vm, &hintData, nil, scopes) + if err != nil { + t.Errorf("USORT_VERIFY_MULTIPLICITY_ASSERT failed") + } + +} + +func TestUsortVerifyMultiplicityBody(t *testing.T) { + vm := NewVirtualMachine() + vm.Segments.AddSegment() + scopes := types.NewExecutionScopes() + + scopes.AssignOrUpdateVariable("positions", []uint64{1, 0, 4, 7, 10}) + scopes.AssignOrUpdateVariable("last_pos", uint64(3)) + + idsManager := SetupIdsForTest( + map[string][]*MaybeRelocatable{ + "next_item_index": {nil}, + }, + vm, + ) + hintProcessor := CairoVmHintProcessor{} + hintData := any(HintData{ + Ids: idsManager, + Code: USORT_VERIFY_MULTIPLICITY_BODY, + }) + err := hintProcessor.ExecuteHint(vm, &hintData, nil, scopes) + if err != nil { + t.Errorf("USORT_VERIFY_MULTIPLICITY_BODY failed") + } + + // Check scopes variables + positions_interface, err := scopes.Get("positions") + + if err != nil { + t.Errorf("Error assigning positions_interface") + } + + positions := positions_interface.([]uint64) + + if !reflect.DeepEqual(positions, []uint64{1, 0, 4, 7}) { + t.Errorf("Error assigning positions") + } + + last_pos_interface, err := scopes.Get("last_pos") + + if err != nil { + t.Errorf("Error assigning last_pos") + } + + last_pos := last_pos_interface.(uint64) + + if last_pos != uint64(11) { + t.Errorf("Error assigning last_pos") + } + + // Check VM inserts + next_item_index, err := idsManager.GetFelt("next_item_index", vm) + + if err != nil { + t.Errorf("Error assigning next_item_index") + } + + if next_item_index != lambdaworks.FeltFromUint(7) { + t.Errorf("Error assigning next_item_index") + } + +} diff --git a/pkg/types/exec_scope_test.go b/pkg/types/exec_scope_test.go index 119f266d..6eb1d97e 100644 --- a/pkg/types/exec_scope_test.go +++ b/pkg/types/exec_scope_test.go @@ -258,3 +258,21 @@ func TestErrExitMainScope(t *testing.T) { t.Errorf("TestErrExitMainScope should fail with error: %s and fails with: %s", types.ErrCannotExitMainScop, err) } } + +func TestFetchScopeVar(t *testing.T) { + scope := make(map[string]interface{}) + scope["k"] = lambdaworks.FeltOne() + + scopes := types.NewExecutionScopes() + scopes.EnterScope(scope) + + result, err := types.FetchScopeVar[lambdaworks.Felt]("k", scopes) + if err != nil { + t.Errorf("TestGetLocalVariables failed with error: %s", err) + + } + expected := lambdaworks.FeltOne() + if expected != result { + t.Errorf("TestGetLocalVariables failed, expected: %s, got: %s", expected.ToSignedFeltString(), result.ToSignedFeltString()) + } +} diff --git a/pkg/types/exec_scopes.go b/pkg/types/exec_scopes.go index 357a1cec..bb372f40 100644 --- a/pkg/types/exec_scopes.go +++ b/pkg/types/exec_scopes.go @@ -18,6 +18,10 @@ func ErrVariableNotInScope(varName string) error { return ExecutionScopesError(errors.Errorf("Variable %s not in scope", varName)) } +func ErrVariableHasWrongType(varName string) error { + return ExecutionScopesError(errors.Errorf("Scope variable %s has wrong type", varName)) +} + func NewExecutionScopes() *ExecutionScopes { data := make([]map[string]interface{}, 1) data[0] = make(map[string]interface{}) @@ -82,3 +86,20 @@ func (es *ExecutionScopes) Get(varName string) (interface{}, error) { } return val, nil } + +// Generic version of ExecutionScopes.Get which also handles casting +func FetchScopeVar[T interface{}](varName string, scopes *ExecutionScopes) (T, error) { + locals, err := scopes.GetLocalVariables() + if err != nil { + return *new(T), err + } + valAny, prs := locals[varName] + if !prs { + return *new(T), ErrVariableNotInScope(varName) + } + val, ok := valAny.(T) + if !ok { + return *new(T), ErrVariableHasWrongType(varName) + } + return val, nil +} diff --git a/pkg/vm/cairo_run/cairo_run_test.go b/pkg/vm/cairo_run/cairo_run_test.go index 3f54eacd..845a3076 100644 --- a/pkg/vm/cairo_run/cairo_run_test.go +++ b/pkg/vm/cairo_run/cairo_run_test.go @@ -309,6 +309,14 @@ func TestSplitFeltHint(t *testing.T) { testProgram("split_felt", t) } +func TestUsort(t *testing.T) { + testProgram("usort", t) +} + +func TestUsortProofMode(t *testing.T) { + testProgramProof("usort", t) +} + func TestSplitFeltHintProofMode(t *testing.T) { testProgramProof("split_felt", t) } @@ -321,6 +329,10 @@ func TestSplitIntHintProofMode(t *testing.T) { testProgramProof("split_int", t) } +func TestEcDoubleAssign(t *testing.T) { + testProgram("ec_double_assign", t) +} + func TestIntegrationEcDoubleSlope(t *testing.T) { testProgram("ec_double_slope", t) }