Skip to content

Commit

Permalink
Merge pull request #103 from 0xPolygonHermez/feature/serialWitness
Browse files Browse the repository at this point in the history
serial witness retrieval
  • Loading branch information
joanestebanr authored Jul 15, 2024
2 parents 93f964f + be09346 commit 160c58e
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 51 deletions.
56 changes: 28 additions & 28 deletions aggregator/aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ func (a *Aggregator) handleReceivedDataStream(entry *datastreamer.FileEntry, cli
a.currentStreamBatch.Timestamp = sequence.Timestamp

// Calculate Acc Input Hash
oldBatch, _, err := a.state.GetBatch(ctx, a.currentStreamBatch.BatchNumber-1, nil)
oldBatch, _, _, err := a.state.GetBatch(ctx, a.currentStreamBatch.BatchNumber-1, nil)
if err != nil {
log.Errorf("Error getting batch %d: %v", a.currentStreamBatch.BatchNumber-1, err)
return err
Expand All @@ -343,7 +343,14 @@ func (a *Aggregator) handleReceivedDataStream(entry *datastreamer.FileEntry, cli

a.currentStreamBatch.AccInputHash = accInputHash

err = a.state.AddBatch(ctx, &a.currentStreamBatch, a.currentBatchStreamData, nil)
// Get Witness
witness, err := getWitness(a.currentStreamBatch.BatchNumber, a.cfg.WitnessURL, a.cfg.UseFullWitness)
if err != nil {
log.Errorf("Failed to get witness for batch %d, err: %v", a.currentStreamBatch.BatchNumber, err)
return err
}

err = a.state.AddBatch(ctx, &a.currentStreamBatch, a.currentBatchStreamData, witness, nil)
if err != nil {
log.Errorf("Error adding batch: %v", err)
return err
Expand Down Expand Up @@ -467,7 +474,7 @@ func (a *Aggregator) Start(ctx context.Context) error {

// Store Acc Input Hash of the latest verified batch
dummyBatch := state.Batch{BatchNumber: lastVerifiedBatchNumber, AccInputHash: *accInputHash}
err = a.state.AddBatch(ctx, &dummyBatch, []byte{0}, nil)
err = a.state.AddBatch(ctx, &dummyBatch, []byte{0}, []byte{0}, nil)
if err != nil {
return err
}
Expand Down Expand Up @@ -624,7 +631,7 @@ func (a *Aggregator) sendFinalProof() {

a.startProofVerification()

finalBatch, _, err := a.state.GetBatch(ctx, proof.BatchNumberFinal, nil)
finalBatch, _, _, err := a.state.GetBatch(ctx, proof.BatchNumberFinal, nil)
if err != nil {
log.Errorf("Failed to retrieve batch with number [%d]: %v", proof.BatchNumberFinal, err)
a.endProofVerification()
Expand Down Expand Up @@ -770,7 +777,7 @@ func (a *Aggregator) buildFinalProof(ctx context.Context, prover proverInterface
if string(finalProof.Public.NewStateRoot) == mockedStateRoot && string(finalProof.Public.NewLocalExitRoot) == mockedLocalExitRoot {
// This local exit root and state root come from the mock
// prover, use the one captured by the executor instead
finalBatch, _, err := a.state.GetBatch(ctx, proof.BatchNumberFinal, nil)
finalBatch, _, _, err := a.state.GetBatch(ctx, proof.BatchNumberFinal, nil)
if err != nil {
return nil, fmt.Errorf("failed to retrieve batch with number [%d]", proof.BatchNumberFinal)
}
Expand Down Expand Up @@ -1171,7 +1178,7 @@ func (a *Aggregator) getVerifiedBatchAccInputHash(ctx context.Context, batchNumb
return &accInputHash, nil
}

func (a *Aggregator) getAndLockBatchToProve(ctx context.Context, prover proverInterface) (*state.Batch, *state.Proof, error) {
func (a *Aggregator) getAndLockBatchToProve(ctx context.Context, prover proverInterface) (*state.Batch, []byte, *state.Proof, error) {
proverID := prover.ID()
proverName := prover.Name()

Expand All @@ -1187,7 +1194,7 @@ func (a *Aggregator) getAndLockBatchToProve(ctx context.Context, prover proverIn
// Get last virtual batch number from L1
lastVerifiedBatchNumber, err := a.etherman.GetLatestVerifiedBatchNum()
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}

proofExists := true
Expand All @@ -1199,20 +1206,20 @@ func (a *Aggregator) getAndLockBatchToProve(ctx context.Context, prover proverIn
proofExists, err = a.state.CheckProofExistsForBatch(ctx, batchNumberToVerify, nil)
if err != nil {
log.Infof("Error checking proof exists for batch %d", batchNumberToVerify)
return nil, nil, err
return nil, nil, nil, err
}
}

// Check if the batch has been sequenced
sequence, err := a.l1Syncr.GetSequenceByBatchNumber(ctx, batchNumberToVerify)
if err != nil && !errors.Is(err, entities.ErrNotFound) {
return nil, nil, err
return nil, nil, nil, err
}

// Not found, so it it not possible to verify the batch yet
if sequence == nil || errors.Is(err, entities.ErrNotFound) {
log.Infof("No sequence found for batch %d", batchNumberToVerify)
return nil, nil, state.ErrNotFound
return nil, nil, nil, state.ErrNotFound
}

stateSequence := state.Sequence{
Expand All @@ -1223,12 +1230,12 @@ func (a *Aggregator) getAndLockBatchToProve(ctx context.Context, prover proverIn
err = a.state.AddSequence(ctx, stateSequence, nil)
if err != nil {
log.Infof("Error storing sequence for batch %d", batchNumberToVerify)
return nil, nil, err
return nil, nil, nil, err
}

batch, _, err := a.state.GetBatch(ctx, batchNumberToVerify, nil)
batch, _, witness, err := a.state.GetBatch(ctx, batchNumberToVerify, nil)
if err != nil {
return batch, nil, err
return batch, witness, nil, err
}

// All the data required to generate a proof is ready
Expand All @@ -1241,12 +1248,12 @@ func (a *Aggregator) getAndLockBatchToProve(ctx context.Context, prover proverIn
isProfitable, err := a.profitabilityChecker.IsProfitable(ctx, big.NewInt(0))
if err != nil {
log.Errorf("Failed to check aggregator profitability, err: %v", err)
return nil, nil, err
return nil, nil, nil, err
}

if !isProfitable {
log.Infof("Batch is not profitable, pol collateral %d", big.NewInt(0))
return nil, nil, err
return nil, nil, nil, err
}

now := time.Now().Round(time.Microsecond)
Expand All @@ -1262,10 +1269,10 @@ func (a *Aggregator) getAndLockBatchToProve(ctx context.Context, prover proverIn
err = a.state.AddGeneratedProof(ctx, proof, nil)
if err != nil {
log.Errorf("Failed to add batch proof, err: %v", err)
return nil, nil, err
return nil, nil, nil, err
}

return batch, proof, nil
return batch, witness, proof, nil
}

func (a *Aggregator) tryGenerateBatchProof(ctx context.Context, prover proverInterface) (bool, error) {
Expand All @@ -1276,7 +1283,7 @@ func (a *Aggregator) tryGenerateBatchProof(ctx context.Context, prover proverInt
)
log.Debug("tryGenerateBatchProof start")

batchToProve, proof, err0 := a.getAndLockBatchToProve(ctx, prover)
batchToProve, witness, proof, err0 := a.getAndLockBatchToProve(ctx, prover)
if errors.Is(err0, state.ErrNotFound) {
// nothing to proof, swallow the error
log.Debug("Nothing to generate proof")
Expand Down Expand Up @@ -1305,7 +1312,7 @@ func (a *Aggregator) tryGenerateBatchProof(ctx context.Context, prover proverInt
}()

log.Infof("Sending zki + batch to the prover, batchNumber [%d]", batchToProve.BatchNumber)
inputProver, err := a.buildInputProver(ctx, batchToProve)
inputProver, err := a.buildInputProver(ctx, batchToProve, witness)
if err != nil {
err = fmt.Errorf("failed to build input prover, %w", err)
log.Error(FirstToUpper(err.Error()))
Expand Down Expand Up @@ -1397,7 +1404,7 @@ func (a *Aggregator) resetVerifyProofTime() {
a.timeSendFinalProof = time.Now().Add(a.cfg.VerifyProofInterval.Duration)
}

func (a *Aggregator) buildInputProver(ctx context.Context, batchToVerify *state.Batch) (*prover.StatelessInputProver, error) {
func (a *Aggregator) buildInputProver(ctx context.Context, batchToVerify *state.Batch, witness []byte) (*prover.StatelessInputProver, error) {
isForcedBatch := false
batchRawData := &state.BatchRawV2{}
var err error
Expand Down Expand Up @@ -1491,15 +1498,8 @@ func (a *Aggregator) buildInputProver(ctx context.Context, batchToVerify *state.
}*/
}

// Get Witness
witness, err := getWitness(batchToVerify.BatchNumber, a.cfg.WitnessURL, a.cfg.UseFullWitness)
if err != nil {
log.Errorf("Failed to get witness, err: %v", err)
return nil, err
}

// Get Old Acc Input Hash
oldBatch, _, err := a.state.GetBatch(ctx, batchToVerify.BatchNumber-1, nil)
oldBatch, _, _, err := a.state.GetBatch(ctx, batchToVerify.BatchNumber-1, nil)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions aggregator/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ type stateInterface interface {
CleanupLockedProofs(ctx context.Context, duration string, dbTx pgx.Tx) (int64, error)
CheckProofExistsForBatch(ctx context.Context, batchNumber uint64, dbTx pgx.Tx) (bool, error)
AddSequence(ctx context.Context, sequence state.Sequence, dbTx pgx.Tx) error
AddBatch(ctx context.Context, batch *state.Batch, datastream []byte, dbTx pgx.Tx) error
GetBatch(ctx context.Context, batchNumber uint64, dbTx pgx.Tx) (*state.Batch, []byte, error)
AddBatch(ctx context.Context, batch *state.Batch, datastream []byte, witness []byte, dbTx pgx.Tx) error
GetBatch(ctx context.Context, batchNumber uint64, dbTx pgx.Tx) (*state.Batch, []byte, []byte, error)
DeleteBatchesOlderThanBatchNumber(ctx context.Context, batchNumber uint64, dbTx pgx.Tx) error
DeleteBatchesNewerThanBatchNumber(ctx context.Context, batchNumber uint64, dbTx pgx.Tx) error
}
15 changes: 4 additions & 11 deletions aggregator/prover/prover.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ func (p *Prover) Addr() string {

// Status gets the prover status.
func (p *Prover) Status() (*GetStatusResponse, error) {
start := time.Now()
req := &AggregatorMessage{
Request: &AggregatorMessage_GetStatusRequest{
GetStatusRequest: &GetStatusRequest{},
Expand All @@ -79,11 +78,9 @@ func (p *Prover) Status() (*GetStatusResponse, error) {
if err != nil {
return nil, err
}
log.Infof("Prover status call")
if msg, ok := res.Response.(*ProverMessage_GetStatusResponse); ok {
return msg.GetStatusResponse, nil
}
log.Infof("Prover %s status call took %v", p.ID(), time.Since(start))
return nil, fmt.Errorf("%w, wanted %T, got %T", ErrBadProverResponse, &ProverMessage_GetStatusResponse{}, res.Response)
}

Expand Down Expand Up @@ -119,12 +116,11 @@ func (p *Prover) BatchProof(input *StatelessInputProver) (*string, error) {
GenStatelessBatchProofRequest: &GenStatelessBatchProofRequest{Input: input},
},
}
start := time.Now()

res, err := p.call(req)
if err != nil {
return nil, err
}
log.Infof("Prover %s batch proof call took %v", p.ID(), time.Since(start))

if msg, ok := res.Response.(*ProverMessage_GenBatchProofResponse); ok {
switch msg.GenBatchProofResponse.Result {
Expand Down Expand Up @@ -157,12 +153,11 @@ func (p *Prover) AggregatedProof(inputProof1, inputProof2 string) (*string, erro
},
},
}
start := time.Now()

res, err := p.call(req)
if err != nil {
return nil, err
}
log.Infof("Prover %s aggregated proof call took %v", p.ID(), time.Since(start))

if msg, ok := res.Response.(*ProverMessage_GenAggregatedProofResponse); ok {
switch msg.GenAggregatedProofResponse.Result {
Expand Down Expand Up @@ -199,12 +194,11 @@ func (p *Prover) FinalProof(inputProof string, aggregatorAddr string) (*string,
},
},
}
start := time.Now()

res, err := p.call(req)
if err != nil {
return nil, err
}
log.Infof("Prover %s final proof call took %v", p.ID(), time.Since(start))

if msg, ok := res.Response.(*ProverMessage_GenFinalProofResponse); ok {
switch msg.GenFinalProofResponse.Result {
Expand Down Expand Up @@ -235,12 +229,11 @@ func (p *Prover) CancelProofRequest(proofID string) error {
CancelRequest: &CancelRequest{Id: proofID},
},
}
start := time.Now()

res, err := p.call(req)
if err != nil {
return err
}
log.Infof("Prover %s cancel proof call took %v", p.ID(), time.Since(start))

if msg, ok := res.Response.(*ProverMessage_CancelResponse); ok {
switch msg.CancelResponse.Result {
Expand Down
8 changes: 8 additions & 0 deletions db/migrations/aggregator/002.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
-- +migrate Up
DELETE FROM aggregator.batch;
ALTER TABLE aggregator.batch
ADD COLUMN IF NOT EXISTS witness varchar NOT NULL;

-- +migrate Down
ALTER TABLE aggregator.batch
DROP COLUMN IF NOT EXISTS witness;
4 changes: 2 additions & 2 deletions state/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ type storage interface {
CleanupGeneratedProofs(ctx context.Context, batchNumber uint64, dbTx pgx.Tx) error
CleanupLockedProofs(ctx context.Context, duration string, dbTx pgx.Tx) (int64, error)
CheckProofExistsForBatch(ctx context.Context, batchNumber uint64, dbTx pgx.Tx) (bool, error)
AddBatch(ctx context.Context, batch *Batch, datastream []byte, dbTx pgx.Tx) error
GetBatch(ctx context.Context, batchNumber uint64, dbTx pgx.Tx) (*Batch, []byte, error)
AddBatch(ctx context.Context, batch *Batch, datastream []byte, witness []byte, dbTx pgx.Tx) error
GetBatch(ctx context.Context, batchNumber uint64, dbTx pgx.Tx) (*Batch, []byte, []byte, error)
DeleteBatchesOlderThanBatchNumber(ctx context.Context, batchNumber uint64, dbTx pgx.Tx) error
DeleteBatchesNewerThanBatchNumber(ctx context.Context, batchNumber uint64, dbTx pgx.Tx) error
}
17 changes: 9 additions & 8 deletions state/pgstatestorage/batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,25 @@ import (
)

// AddBatch stores a batch
func (p *PostgresStorage) AddBatch(ctx context.Context, batch *state.Batch, datastream []byte, dbTx pgx.Tx) error {
const addInputHashSQL = "INSERT INTO aggregator.batch (batch_num, batch, datastream) VALUES ($1, $2, $3) ON CONFLICT (batch_num) DO UPDATE SET batch = $2, datastream = $3"
func (p *PostgresStorage) AddBatch(ctx context.Context, batch *state.Batch, datastream []byte, witness []byte, dbTx pgx.Tx) error {
const addInputHashSQL = "INSERT INTO aggregator.batch (batch_num, batch, datastream, witness) VALUES ($1, $2, $3, $4) ON CONFLICT (batch_num) DO UPDATE SET batch = $2, datastream = $3, witness = $4"
e := p.getExecQuerier(dbTx)
_, err := e.Exec(ctx, addInputHashSQL, batch.BatchNumber, &batch, common.Bytes2Hex(datastream))
_, err := e.Exec(ctx, addInputHashSQL, batch.BatchNumber, &batch, common.Bytes2Hex(datastream), common.Bytes2Hex(witness))
return err
}

// GetBatch gets a batch by a given batch number
func (p *PostgresStorage) GetBatch(ctx context.Context, batchNumber uint64, dbTx pgx.Tx) (*state.Batch, []byte, error) {
const getInputHashSQL = "SELECT batch, datastream FROM aggregator.batch WHERE batch_num = $1"
func (p *PostgresStorage) GetBatch(ctx context.Context, batchNumber uint64, dbTx pgx.Tx) (*state.Batch, []byte, []byte, error) {
const getInputHashSQL = "SELECT batch, datastream, witness FROM aggregator.batch WHERE batch_num = $1"
e := p.getExecQuerier(dbTx)
var batch *state.Batch
var streamStr string
err := e.QueryRow(ctx, getInputHashSQL, batchNumber).Scan(&batch, &streamStr)
var witnessStr string
err := e.QueryRow(ctx, getInputHashSQL, batchNumber).Scan(&batch, &streamStr, &witnessStr)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}
return batch, common.Hex2Bytes(streamStr), nil
return batch, common.Hex2Bytes(streamStr), common.Hex2Bytes(witnessStr), nil
}

// DeleteBatchesOlderThanBatchNumber deletes batches previous to the given batch number
Expand Down

0 comments on commit 160c58e

Please sign in to comment.