fix: state machine

This commit is contained in:
徐聪
2022-07-14 19:56:27 +08:00
parent 8fa04335e9
commit 76e058d709
3 changed files with 78 additions and 58 deletions

View File

@@ -458,10 +458,10 @@ func (b *Boomer) RecordFailure(requestType, name string, responseTime int64, exc
// Start starts to run
func (b *Boomer) Start(Args *Profile) error {
if b.masterRunner.isStarted() {
if b.masterRunner.isStarting() {
return errors.New("already started")
}
if b.masterRunner.getState() == StateStopping {
if b.masterRunner.isStopping() {
return errors.New("Please wait for all workers to finish")
}
b.SetSpawnCount(Args.SpawnCount)
@@ -473,7 +473,7 @@ func (b *Boomer) Start(Args *Profile) error {
// ReBalance starts to rebalance load test
func (b *Boomer) ReBalance(Args *Profile) error {
if !b.masterRunner.isStarted() {
if !b.masterRunner.isStarting() {
return errors.New("no start")
}
b.SetSpawnCount(Args.SpawnCount)

View File

@@ -379,7 +379,7 @@ func (r *runner) spawnWorkers(spawnCount int64, spawnRate float64, quit chan boo
log.Info().Msg("Quitting spawning workers")
return
default:
if r.isStarted() && r.controller.acquire() {
if r.isStarting() && r.controller.acquire() {
// spawn workers with rate limit
sleepTime := time.Duration(1000000/r.controller.getSpawnRate()) * time.Microsecond
time.Sleep(sleepTime)
@@ -436,7 +436,7 @@ func (r *runner) spawnWorkers(spawnCount int64, spawnRate float64, quit chan boo
})
<-r.rebalance
if r.isStarted() {
if r.isStarting() {
// rebalance spawn count
r.controller.setSpawn(r.getSpawnCount(), r.getSpawnRate())
}
@@ -528,7 +528,7 @@ func (r *runner) statsStart() {
case <-ticker.C:
r.reportStats()
// close reportedChan and return if the last stats is reported successfully
if !r.isStarted() {
if !r.isStarting() && !r.isStopping() {
close(r.reportedChan)
log.Info().Msg("Quitting statsStart")
return
@@ -575,10 +575,14 @@ func (r *runner) updateState(state int32) {
atomic.StoreInt32(&r.state, state)
}
func (r *runner) isStarted() bool {
func (r *runner) isStarting() bool {
return r.getState() == StateRunning || r.getState() == StateSpawning
}
func (r *runner) isStopping() bool {
return r.getState() == StateStopping
}
type localRunner struct {
runner
@@ -618,6 +622,7 @@ func (r *localRunner) start() {
defer func() {
r.wgMu.Lock() // block concurrent waitgroup adds in GoAttach while stopping
r.updateState(StateStopping)
close(r.stoppingChan)
close(r.rebalance)
r.wgMu.Unlock()
@@ -625,7 +630,7 @@ func (r *localRunner) start() {
// wait for goroutines before closing
r.wg.Wait()
r.updateState(StateStopping)
close(r.doneChan)
// wait until all stats are reported successfully
<-r.reportedChan
@@ -636,7 +641,6 @@ func (r *localRunner) start() {
// output teardown
r.outputOnStop()
close(r.doneChan)
r.updateState(StateQuitting)
}()
@@ -647,7 +651,7 @@ func (r *localRunner) start() {
}
func (r *localRunner) stop() {
if r.runner.isStarted() {
if r.runner.isStarting() {
r.runner.stop()
}
}
@@ -898,6 +902,7 @@ func (r *workerRunner) start() {
defer func() {
r.wgMu.Lock() // block concurrent waitgroup adds in GoAttach while stopping
r.updateState(StateStopping)
close(r.stoppingChan)
close(r.rebalance)
r.wgMu.Unlock()
@@ -905,14 +910,12 @@ func (r *workerRunner) start() {
// wait for goroutines before closing
r.wg.Wait()
r.updateState(StateStopping)
close(r.doneChan)
<-r.reportedChan
r.reportTestResult()
r.outputOnStop()
close(r.doneChan)
}()
// start stats report
@@ -922,7 +925,7 @@ func (r *workerRunner) start() {
}
func (r *workerRunner) stop() {
if r.isStarted() {
if r.isStarting() {
r.runner.stop()
}
}
@@ -1064,7 +1067,7 @@ func (r *masterRunner) clientListener() {
break
}
workerInfo.setState(StateQuitting)
if r.isStarted() {
if r.isStarting() {
if r.server.getClientsLength() > 0 {
log.Warn().Str("worker id", workerInfo.ID).Msg("worker quited, ready to rebalance the load of each worker")
err := r.rebalance()
@@ -1092,19 +1095,18 @@ func (r *masterRunner) run() {
}
defer func() {
r.wgMu.Lock() // block concurrent waitgroup adds in GoAttach while stopping
close(r.stoppingChan)
r.wgMu.Unlock()
// disconnecting workers
close(r.server.disconnectedChan)
r.wg.Wait()
// waiting to close bidirectional stream
r.server.wg.Wait()
// close server
r.server.close()
close(r.doneChan)
}()
if r.autoStart {
r.goAttach(func() {
go func() {
log.Info().Msg("auto start, waiting expected workers joined")
var ticker = time.NewTicker(1 * time.Second)
var tickerMaxWait = time.NewTicker(time.Duration(r.expectWorkersMaxWait) * time.Second)
@@ -1129,7 +1131,7 @@ func (r *masterRunner) run() {
os.Exit(1)
}
}
})
}()
}
// listen and deal message from worker
@@ -1143,7 +1145,7 @@ func (r *masterRunner) run() {
func (r *masterRunner) start() error {
numWorkers := r.server.getClientsLength()
if numWorkers == 0 {
return errors.New("current workers: 0")
return errors.New("current available workers: 0")
}
// fetching testcase
@@ -1205,30 +1207,40 @@ func (r *masterRunner) start() error {
func (r *masterRunner) rebalance() error {
numWorkers := r.server.getClientsLength()
if numWorkers == 0 {
return errors.New("current workers: 0")
return errors.New("current available workers: 0")
}
workerProfile := &Profile{}
if err := copier.Copy(workerProfile, r.profile); err != nil {
log.Error().Err(err).Msg("copy workerProfile failed")
return err
}
// spawn count
spawnCounts := builtin.SplitInteger(int(r.profile.SpawnCount), numWorkers)
// spawn rate
spawnRate := workerProfile.SpawnRate / float64(numWorkers)
if spawnRate < 1 {
spawnRate = 1
}
// max RPS
maxRPSs := builtin.SplitInteger(int(workerProfile.MaxRPS), numWorkers)
cur := 0
ints := builtin.SplitInteger(int(r.profile.SpawnCount), numWorkers)
log.Info().Msg("send spawn data to worker")
r.server.clients.Range(func(key, value interface{}) bool {
if workerInfo, ok := value.(*WorkerNode); ok {
if workerInfo.getState() == StateQuitting || workerInfo.getState() == StateMissing {
return true
}
if workerProfile.SpawnCount > 0 {
workerProfile.SpawnCount = int64(ints[cur])
}
if workerProfile.SpawnRate > 0 {
workerProfile.SpawnRate = workerProfile.SpawnRate / float64(numWorkers)
}
if workerProfile.MaxRPS > 0 {
workerProfile.MaxRPS = workerProfile.MaxRPS / int64(numWorkers)
workerProfile.SpawnCount = int64(spawnCounts[cur])
}
workerProfile.MaxRPS = int64(maxRPSs[cur])
workerProfile.SpawnRate = spawnRate
if workerInfo.getState() == StateInit {
workerInfo.getStream() <- &messager.StreamResponse{
Type: "spawn",
@@ -1270,7 +1282,7 @@ func (r *masterRunner) fetchTestCase() ([]byte, error) {
}
func (r *masterRunner) stop() error {
if r.isStarted() {
if r.isStarting() {
r.updateState(StateStopping)
r.server.sendBroadcasts(&genericMessage{Type: "stop", Data: map[string]int64{}})
return nil

View File

@@ -138,6 +138,8 @@ type grpcServer struct {
fromWorker chan *genericMessage
disconnectedChan chan bool
shutdownChan chan bool
wg sync.WaitGroup
}
var (
@@ -221,6 +223,7 @@ func newServer(masterHost string, masterPort int) (server *grpcServer) {
fromWorker: make(chan *genericMessage, 100),
disconnectedChan: make(chan bool),
shutdownChan: make(chan bool),
wg: sync.WaitGroup{},
}
return server
}
@@ -290,6 +293,8 @@ func (s *grpcServer) valid(token string) (isValid bool) {
}
func (s *grpcServer) BidirectionalStreamingMessage(srv messager.Message_BidirectionalStreamingMessageServer) error {
s.wg.Add(1)
defer s.wg.Done()
token, ok := extractToken(srv.Context())
if !ok {
return status.Error(codes.Unauthenticated, "missing token header")
@@ -303,32 +308,34 @@ func (s *grpcServer) BidirectionalStreamingMessage(srv messager.Message_Bidirect
go s.sendMsg(srv, token)
FOR:
for {
msg, err := srv.Recv()
if st, ok := status.FromError(err); ok {
switch st.Code() {
case codes.OK:
s.fromWorker <- newGenericMessage(msg.Type, msg.Data, msg.NodeID)
log.Info().
Str("nodeID", msg.NodeID).
Str("type", msg.Type).
Interface("data", msg.Data).
Msg("receive data from worker")
case codes.Unavailable, codes.Canceled, codes.DeadlineExceeded:
s.fromWorker <- newQuitMessage(token)
break FOR
default:
log.Error().Err(err).Msg("failed to get stream from client")
break FOR
select {
case <-srv.Context().Done():
break FOR
case <-s.disconnectedChannel():
break FOR
default:
msg, err := srv.Recv()
if st, ok := status.FromError(err); ok {
switch st.Code() {
case codes.OK:
s.fromWorker <- newGenericMessage(msg.Type, msg.Data, msg.NodeID)
log.Info().
Str("nodeID", msg.NodeID).
Str("type", msg.Type).
Interface("data", msg.Data).
Msg("receive data from worker")
case codes.Unavailable, codes.Canceled, codes.DeadlineExceeded:
s.fromWorker <- newQuitMessage(token)
break FOR
default:
log.Error().Err(err).Msg("failed to get stream from client")
break FOR
}
}
}
}
// disconnected to worker
select {
case <-srv.Context().Done():
return srv.Context().Err()
case <-s.disconnectedChan:
}
log.Warn().Str("worker id", token).Msg("worker quited")
log.Info().Str("worker id", token).Msg("bidirectional stream closed")
return nil
}
@@ -338,6 +345,8 @@ func (s *grpcServer) sendMsg(srv messager.Message_BidirectionalStreamingMessageS
select {
case <-srv.Context().Done():
return
case <-s.disconnectedChannel():
return
case res := <-stream:
if s, ok := status.FromError(srv.Send(res)); ok {
switch s.Code() {
@@ -406,7 +415,6 @@ func (s *grpcServer) close() {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
s.stopServer(ctx)
cancel()
close(s.disconnectedChan)
}
func (s *grpcServer) recvChannel() chan *genericMessage {