diff --git a/hrp/internal/boomer/boomer.go b/hrp/internal/boomer/boomer.go index 2dc40056..cfbd17ea 100644 --- a/hrp/internal/boomer/boomer.go +++ b/hrp/internal/boomer/boomer.go @@ -458,10 +458,10 @@ func (b *Boomer) RecordFailure(requestType, name string, responseTime int64, exc // Start starts to run func (b *Boomer) Start(Args *Profile) error { - if b.masterRunner.isStarted() { + if b.masterRunner.isStarting() { return errors.New("already started") } - if b.masterRunner.getState() == StateStopping { + if b.masterRunner.isStopping() { return errors.New("Please wait for all workers to finish") } b.SetSpawnCount(Args.SpawnCount) @@ -473,7 +473,7 @@ func (b *Boomer) Start(Args *Profile) error { // ReBalance starts to rebalance load test func (b *Boomer) ReBalance(Args *Profile) error { - if !b.masterRunner.isStarted() { + if !b.masterRunner.isStarting() { return errors.New("no start") } b.SetSpawnCount(Args.SpawnCount) diff --git a/hrp/internal/boomer/runner.go b/hrp/internal/boomer/runner.go index 675f69b9..207c18dd 100644 --- a/hrp/internal/boomer/runner.go +++ b/hrp/internal/boomer/runner.go @@ -379,7 +379,7 @@ func (r *runner) spawnWorkers(spawnCount int64, spawnRate float64, quit chan boo log.Info().Msg("Quitting spawning workers") return default: - if r.isStarted() && r.controller.acquire() { + if r.isStarting() && r.controller.acquire() { // spawn workers with rate limit sleepTime := time.Duration(1000000/r.controller.getSpawnRate()) * time.Microsecond time.Sleep(sleepTime) @@ -436,7 +436,7 @@ func (r *runner) spawnWorkers(spawnCount int64, spawnRate float64, quit chan boo }) <-r.rebalance - if r.isStarted() { + if r.isStarting() { // rebalance spawn count r.controller.setSpawn(r.getSpawnCount(), r.getSpawnRate()) } @@ -528,7 +528,7 @@ func (r *runner) statsStart() { case <-ticker.C: r.reportStats() // close reportedChan and return if the last stats is reported successfully - if !r.isStarted() { + if !r.isStarting() && !r.isStopping() { close(r.reportedChan) log.Info().Msg("Quitting statsStart") return @@ -575,10 +575,14 @@ func (r *runner) updateState(state int32) { atomic.StoreInt32(&r.state, state) } -func (r *runner) isStarted() bool { +func (r *runner) isStarting() bool { return r.getState() == StateRunning || r.getState() == StateSpawning } +func (r *runner) isStopping() bool { + return r.getState() == StateStopping +} + type localRunner struct { runner @@ -618,6 +622,7 @@ func (r *localRunner) start() { defer func() { r.wgMu.Lock() // block concurrent waitgroup adds in GoAttach while stopping + r.updateState(StateStopping) close(r.stoppingChan) close(r.rebalance) r.wgMu.Unlock() @@ -625,7 +630,7 @@ func (r *localRunner) start() { // wait for goroutines before closing r.wg.Wait() - r.updateState(StateStopping) + close(r.doneChan) // wait until all stats are reported successfully <-r.reportedChan @@ -636,7 +641,6 @@ func (r *localRunner) start() { // output teardown r.outputOnStop() - close(r.doneChan) r.updateState(StateQuitting) }() @@ -647,7 +651,7 @@ func (r *localRunner) start() { } func (r *localRunner) stop() { - if r.runner.isStarted() { + if r.runner.isStarting() { r.runner.stop() } } @@ -898,6 +902,7 @@ func (r *workerRunner) start() { defer func() { r.wgMu.Lock() // block concurrent waitgroup adds in GoAttach while stopping + r.updateState(StateStopping) close(r.stoppingChan) close(r.rebalance) r.wgMu.Unlock() @@ -905,14 +910,12 @@ func (r *workerRunner) start() { // wait for goroutines before closing r.wg.Wait() - r.updateState(StateStopping) + close(r.doneChan) <-r.reportedChan r.reportTestResult() r.outputOnStop() - - close(r.doneChan) }() // start stats report @@ -922,7 +925,7 @@ func (r *workerRunner) start() { } func (r *workerRunner) stop() { - if r.isStarted() { + if r.isStarting() { r.runner.stop() } } @@ -1064,7 +1067,7 @@ func (r *masterRunner) clientListener() { break } workerInfo.setState(StateQuitting) - if r.isStarted() { + if r.isStarting() { if r.server.getClientsLength() > 0 { log.Warn().Str("worker id", workerInfo.ID).Msg("worker quited, ready to rebalance the load of each worker") err := r.rebalance() @@ -1092,19 +1095,18 @@ func (r *masterRunner) run() { } defer func() { - r.wgMu.Lock() // block concurrent waitgroup adds in GoAttach while stopping - close(r.stoppingChan) - r.wgMu.Unlock() + // disconnecting workers + close(r.server.disconnectedChan) - r.wg.Wait() + // waiting to close bidirectional stream + r.server.wg.Wait() + // close server r.server.close() - - close(r.doneChan) }() if r.autoStart { - r.goAttach(func() { + go func() { log.Info().Msg("auto start, waiting expected workers joined") var ticker = time.NewTicker(1 * time.Second) var tickerMaxWait = time.NewTicker(time.Duration(r.expectWorkersMaxWait) * time.Second) @@ -1129,7 +1131,7 @@ func (r *masterRunner) run() { os.Exit(1) } } - }) + }() } // listen and deal message from worker @@ -1143,7 +1145,7 @@ func (r *masterRunner) run() { func (r *masterRunner) start() error { numWorkers := r.server.getClientsLength() if numWorkers == 0 { - return errors.New("current workers: 0") + return errors.New("current available workers: 0") } // fetching testcase @@ -1205,30 +1207,40 @@ func (r *masterRunner) start() error { func (r *masterRunner) rebalance() error { numWorkers := r.server.getClientsLength() if numWorkers == 0 { - return errors.New("current workers: 0") + return errors.New("current available workers: 0") } workerProfile := &Profile{} if err := copier.Copy(workerProfile, r.profile); err != nil { log.Error().Err(err).Msg("copy workerProfile failed") return err } + + // spawn count + spawnCounts := builtin.SplitInteger(int(r.profile.SpawnCount), numWorkers) + + // spawn rate + spawnRate := workerProfile.SpawnRate / float64(numWorkers) + if spawnRate < 1 { + spawnRate = 1 + } + + // max RPS + maxRPSs := builtin.SplitInteger(int(workerProfile.MaxRPS), numWorkers) + cur := 0 - ints := builtin.SplitInteger(int(r.profile.SpawnCount), numWorkers) log.Info().Msg("send spawn data to worker") r.server.clients.Range(func(key, value interface{}) bool { if workerInfo, ok := value.(*WorkerNode); ok { if workerInfo.getState() == StateQuitting || workerInfo.getState() == StateMissing { return true } + if workerProfile.SpawnCount > 0 { - workerProfile.SpawnCount = int64(ints[cur]) - } - if workerProfile.SpawnRate > 0 { - workerProfile.SpawnRate = workerProfile.SpawnRate / float64(numWorkers) - } - if workerProfile.MaxRPS > 0 { - workerProfile.MaxRPS = workerProfile.MaxRPS / int64(numWorkers) + workerProfile.SpawnCount = int64(spawnCounts[cur]) } + workerProfile.MaxRPS = int64(maxRPSs[cur]) + workerProfile.SpawnRate = spawnRate + if workerInfo.getState() == StateInit { workerInfo.getStream() <- &messager.StreamResponse{ Type: "spawn", @@ -1270,7 +1282,7 @@ func (r *masterRunner) fetchTestCase() ([]byte, error) { } func (r *masterRunner) stop() error { - if r.isStarted() { + if r.isStarting() { r.updateState(StateStopping) r.server.sendBroadcasts(&genericMessage{Type: "stop", Data: map[string]int64{}}) return nil diff --git a/hrp/internal/boomer/server_grpc.go b/hrp/internal/boomer/server_grpc.go index 08323aa6..0b2943da 100644 --- a/hrp/internal/boomer/server_grpc.go +++ b/hrp/internal/boomer/server_grpc.go @@ -138,6 +138,8 @@ type grpcServer struct { fromWorker chan *genericMessage disconnectedChan chan bool shutdownChan chan bool + + wg sync.WaitGroup } var ( @@ -221,6 +223,7 @@ func newServer(masterHost string, masterPort int) (server *grpcServer) { fromWorker: make(chan *genericMessage, 100), disconnectedChan: make(chan bool), shutdownChan: make(chan bool), + wg: sync.WaitGroup{}, } return server } @@ -290,6 +293,8 @@ func (s *grpcServer) valid(token string) (isValid bool) { } func (s *grpcServer) BidirectionalStreamingMessage(srv messager.Message_BidirectionalStreamingMessageServer) error { + s.wg.Add(1) + defer s.wg.Done() token, ok := extractToken(srv.Context()) if !ok { return status.Error(codes.Unauthenticated, "missing token header") @@ -303,32 +308,34 @@ func (s *grpcServer) BidirectionalStreamingMessage(srv messager.Message_Bidirect go s.sendMsg(srv, token) FOR: for { - msg, err := srv.Recv() - if st, ok := status.FromError(err); ok { - switch st.Code() { - case codes.OK: - s.fromWorker <- newGenericMessage(msg.Type, msg.Data, msg.NodeID) - log.Info(). - Str("nodeID", msg.NodeID). - Str("type", msg.Type). - Interface("data", msg.Data). - Msg("receive data from worker") - case codes.Unavailable, codes.Canceled, codes.DeadlineExceeded: - s.fromWorker <- newQuitMessage(token) - break FOR - default: - log.Error().Err(err).Msg("failed to get stream from client") - break FOR + select { + case <-srv.Context().Done(): + break FOR + case <-s.disconnectedChannel(): + break FOR + default: + msg, err := srv.Recv() + if st, ok := status.FromError(err); ok { + switch st.Code() { + case codes.OK: + s.fromWorker <- newGenericMessage(msg.Type, msg.Data, msg.NodeID) + log.Info(). + Str("nodeID", msg.NodeID). + Str("type", msg.Type). + Interface("data", msg.Data). + Msg("receive data from worker") + case codes.Unavailable, codes.Canceled, codes.DeadlineExceeded: + s.fromWorker <- newQuitMessage(token) + break FOR + default: + log.Error().Err(err).Msg("failed to get stream from client") + break FOR + } } } } - // disconnected to worker - select { - case <-srv.Context().Done(): - return srv.Context().Err() - case <-s.disconnectedChan: - } - log.Warn().Str("worker id", token).Msg("worker quited") + + log.Info().Str("worker id", token).Msg("bidirectional stream closed") return nil } @@ -338,6 +345,8 @@ func (s *grpcServer) sendMsg(srv messager.Message_BidirectionalStreamingMessageS select { case <-srv.Context().Done(): return + case <-s.disconnectedChannel(): + return case res := <-stream: if s, ok := status.FromError(srv.Send(res)); ok { switch s.Code() { @@ -406,7 +415,6 @@ func (s *grpcServer) close() { ctx, cancel := context.WithTimeout(context.Background(), timeout) s.stopServer(ctx) cancel() - close(s.disconnectedChan) } func (s *grpcServer) recvChannel() chan *genericMessage {