diff --git a/pkg/exchange/batch/trade.go b/pkg/exchange/batch/trade.go index 43603a407c..1c91da7773 100644 --- a/pkg/exchange/batch/trade.go +++ b/pkg/exchange/batch/trade.go @@ -28,6 +28,8 @@ func (e TradeBatchQuery) Query(ctx context.Context, symbol string, options *type query := &AsyncTimeRangedBatchQuery{ Type: types.Trade{}, Q: func(startTime, endTime time.Time) (interface{}, error) { + options.StartTime = &startTime + options.EndTime = &endTime return e.ExchangeTradeHistoryService.QueryTrades(ctx, symbol, options) }, T: func(obj interface{}) time.Time { diff --git a/pkg/exchange/batch/trade_test.go b/pkg/exchange/batch/trade_test.go new file mode 100644 index 0000000000..1b704136c2 --- /dev/null +++ b/pkg/exchange/batch/trade_test.go @@ -0,0 +1,140 @@ +package batch + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + + "github.com/c9s/bbgo/pkg/types" + "github.com/c9s/bbgo/pkg/types/mocks" +) + +func Test_TradeBatchQuery(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + var ( + ctx = context.Background() + timeNow = time.Now() + startTime = timeNow.Add(-24 * time.Hour) + endTime = timeNow + expSymbol = "BTCUSDT" + queryTrades1 = []types.Trade{ + { + ID: 1, + Time: types.Time(startTime.Add(time.Hour)), + }, + } + queryTrades2 = []types.Trade{ + { + ID: 2, + Time: types.Time(startTime.Add(time.Hour)), + }, + { + ID: 3, + Time: types.Time(startTime.Add(2 * time.Hour)), + }, + } + emptyTrades = []types.Trade{} + allRes = append(queryTrades1, queryTrades2...) + ) + + t.Run("succeeds", func(t *testing.T) { + var ( + expOptions = &types.TradeQueryOptions{ + StartTime: &startTime, + EndTime: &endTime, + LastTradeID: 0, + Limit: 50, + } + mockExchange = mocks.NewMockExchangeTradeHistoryService(ctrl) + ) + + mockExchange.EXPECT().QueryTrades(ctx, expSymbol, expOptions).DoAndReturn( + func(ctx context.Context, symbol string, options *types.TradeQueryOptions) ([]types.Trade, error) { + assert.Equal(t, startTime, *options.StartTime) + assert.Equal(t, endTime, *options.EndTime) + assert.Equal(t, uint64(0), options.LastTradeID) + assert.Equal(t, expOptions.Limit, options.Limit) + return queryTrades1, nil + }).Times(1) + mockExchange.EXPECT().QueryTrades(ctx, expSymbol, expOptions).DoAndReturn( + func(ctx context.Context, symbol string, options *types.TradeQueryOptions) ([]types.Trade, error) { + assert.Equal(t, queryTrades1[0].Time.Time(), *options.StartTime) + assert.Equal(t, endTime, *options.EndTime) + assert.Equal(t, queryTrades1[0].ID, options.LastTradeID) + assert.Equal(t, expOptions.Limit, options.Limit) + return queryTrades2, nil + }).Times(1) + mockExchange.EXPECT().QueryTrades(ctx, expSymbol, expOptions).DoAndReturn( + func(ctx context.Context, symbol string, options *types.TradeQueryOptions) ([]types.Trade, error) { + assert.Equal(t, queryTrades2[1].Time.Time(), *options.StartTime) + assert.Equal(t, endTime, *options.EndTime) + assert.Equal(t, queryTrades2[1].ID, options.LastTradeID) + assert.Equal(t, expOptions.Limit, options.Limit) + return emptyTrades, nil + }).Times(1) + + tradeBatchQuery := &TradeBatchQuery{ExchangeTradeHistoryService: mockExchange} + + resCh, errC := tradeBatchQuery.Query(ctx, expSymbol, expOptions) + wg := sync.WaitGroup{} + wg.Add(1) + rcvCount := 0 + go func() { + defer wg.Done() + for ch := range resCh { + assert.Equal(t, allRes[rcvCount], ch) + rcvCount++ + } + assert.NoError(t, <-errC) + }() + wg.Wait() + assert.Equal(t, rcvCount, len(allRes)) + }) + + t.Run("failed to call query trades", func(t *testing.T) { + var ( + expOptions = &types.TradeQueryOptions{ + StartTime: &startTime, + EndTime: &endTime, + LastTradeID: 0, + Limit: 50, + } + mockExchange = mocks.NewMockExchangeTradeHistoryService(ctrl) + unknownErr = errors.New("unknown err") + ) + + mockExchange.EXPECT().QueryTrades(ctx, expSymbol, expOptions).DoAndReturn( + func(ctx context.Context, symbol string, options *types.TradeQueryOptions) ([]types.Trade, error) { + assert.Equal(t, startTime, *options.StartTime) + assert.Equal(t, endTime, *options.EndTime) + assert.Equal(t, uint64(0), options.LastTradeID) + assert.Equal(t, expOptions.Limit, options.Limit) + return nil, unknownErr + }).Times(1) + + tradeBatchQuery := &TradeBatchQuery{ExchangeTradeHistoryService: mockExchange} + + resCh, errC := tradeBatchQuery.Query(ctx, expSymbol, expOptions) + wg := sync.WaitGroup{} + wg.Add(1) + rcvCount := 0 + go func() { + defer wg.Done() + for ch := range resCh { + assert.Equal(t, allRes[rcvCount], ch) + rcvCount++ + } + assert.Equal(t, 0, rcvCount) + assert.Equal(t, unknownErr, <-errC) + }() + wg.Wait() + assert.Equal(t, rcvCount, 0) + }) +}