diff --git a/README.md b/README.md index a718960..8eb9509 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,10 @@ environment variables must be set. Supply source and destination URL endpoints. - sqsmv -src https://region.queue.amazonaws.com/123/queue-a -dest https://region.queue.amazonaws.com/123/queue-b + sqsmv [-max 101] -src https://region.queue.amazonaws.com/123/queue-a -dest https://region.queue.amazonaws.com/123/queue-b + +The optional [-max int] flag allows one to specify the maximum number of messages to be moved from source to target. +If specified, the number must be greater than zero. If not specified, all available messages in the source queue will be moved. ## Seeing is believing :) diff --git a/main.go b/main.go index 8411db1..ec1f53e 100644 --- a/main.go +++ b/main.go @@ -3,6 +3,7 @@ package main import ( "flag" "log" + "math" "os" "sync" @@ -12,8 +13,10 @@ import ( ) func main() { + maxInt := int(^uint(0) >> 1) src := flag.String("src", "", "source queue") dest := flag.String("dest", "", "destination queue") + maxMsgsToMove := flag.Int("max", maxInt, "max number of messages to move") flag.Parse() if *src == "" || *dest == "" { @@ -23,6 +26,12 @@ func main() { log.Printf("source queue : %v", *src) log.Printf("destination queue : %v", *dest) + log.Printf("max number of messages to move : %v", *maxMsgsToMove) + + if *maxMsgsToMove <= 0 { + log.Printf("max number of message to move : %v must be greater than zero", *maxMsgsToMove) + os.Exit(1) + } // enable automatic use of AWS_PROFILE like awscli and other tools do. opts := session.Options{ @@ -36,7 +45,7 @@ func main() { client := sqs.New(session) - maxMessages := int64(10) + maxMessages := int64(math.Min(float64(*maxMsgsToMove), float64(10))) waitTime := int64(0) messageAttributeNames := aws.StringSlice([]string{"All"}) @@ -47,6 +56,8 @@ func main() { MessageAttributeNames: messageAttributeNames, } + var mutex = &sync.Mutex{} + var count int lastMessageCount := int(1) // loop as long as there are messages on the queue for { @@ -56,7 +67,7 @@ func main() { panic(err) } - if lastMessageCount == 0 && len(resp.Messages) == 0 { + if count >= *maxMsgsToMove || (lastMessageCount == 0 && len(resp.Messages) == 0) { // no messages returned twice now, the queue is probably empty log.Printf("done") return @@ -69,33 +80,45 @@ func main() { wg.Add(len(resp.Messages)) for _, m := range resp.Messages { + if count >= *maxMsgsToMove { + break + } + go func(m *sqs.Message) { defer wg.Done() - // write the message to the destination queue - smi := sqs.SendMessageInput{ - MessageAttributes: m.MessageAttributes, - MessageBody: m.Body, - QueueUrl: dest, - } - - _, err := client.SendMessage(&smi) - - if err != nil { - log.Printf("ERROR sending message to destination %v", err) - return - } - - // message was sent, dequeue from source queue - dmi := &sqs.DeleteMessageInput{ - QueueUrl: src, - ReceiptHandle: m.ReceiptHandle, + if count == maxInt { + mutex.Lock() + defer mutex.Unlock() } - if _, err := client.DeleteMessage(dmi); err != nil { - log.Printf("ERROR dequeueing message ID %v : %v", - *m.ReceiptHandle, - err) + if count < *maxMsgsToMove { + // write the message to the destination queue + smi := sqs.SendMessageInput{ + MessageAttributes: m.MessageAttributes, + MessageBody: m.Body, + QueueUrl: dest, + } + + _, err := client.SendMessage(&smi) + + if err != nil { + log.Printf("ERROR sending message to destination %v", err) + return + } + + // message was sent, dequeue from source queue + dmi := &sqs.DeleteMessageInput{ + QueueUrl: src, + ReceiptHandle: m.ReceiptHandle, + } + + if _, err := client.DeleteMessage(dmi); err != nil { + log.Printf("ERROR dequeueing message ID %v : %v", + *m.ReceiptHandle, + err) + } + count++ } }(m) }