package gadb

import (
	"fmt"
	"io"
	"net"
	"regexp"
	"strconv"
	"strings"
	"time"

	"github.com/pkg/errors"
	"github.com/rs/zerolog/log"

	"github.com/httprunner/httprunner/v5/code"
)

// ErrConnBroken is returned when a write on the socket makes no progress
// (zero bytes written without an error).
var ErrConnBroken = errors.New("socket connection broken")

// DefaultAdbReadTimeout is the read timeout used when newTransport is called
// without an explicit one.
//
// NOTE: despite its time.Duration type the value is a plain count of seconds;
// ReadBytesN multiplies it by time.Second before arming the read deadline.
var DefaultAdbReadTimeout time.Duration = 300

// regexDeviceOffline matches dial error messages reporting a missing device.
var regexDeviceOffline = regexp.MustCompile("device .* not found")

// transport is a connection that exchanges length-prefixed messages
// (4 hex digits of length followed by the payload) over a socket.
type transport struct {
	sock net.Conn
	// readTimeout is a number of seconds stored in a time.Duration
	// (see DefaultAdbReadTimeout).
	readTimeout time.Duration
}

// newTransport creates a new tcp socket connection to address.
//
// readTimeout is optional; only its first element is used and it defaults to
// DefaultAdbReadTimeout. On dial failure the returned error wraps
// code.DeviceOfflineError when the message matches regexDeviceOffline and
// code.DeviceConnectionError otherwise.
func newTransport(address string, readTimeout ...time.Duration) (tp transport, err error) {
	if len(readTimeout) == 0 {
		readTimeout = []time.Duration{DefaultAdbReadTimeout}
	}
	tp = transport{
		readTimeout: readTimeout[0],
	}

	tp.sock, err = net.Dial("tcp", address)
	if err == nil {
		// dial success
		return tp, nil
	}

	switch msg := err.Error(); {
	case strings.Contains(msg, "connect: connection refused"):
		// connection refused
		err = errors.Wrap(code.DeviceConnectionError, msg)
	case regexDeviceOffline.MatchString(msg):
		// device offline
		err = errors.Wrap(code.DeviceOfflineError, msg)
	default:
		// other connection errors
		err = errors.Wrap(code.DeviceConnectionError, msg)
	}
	return tp, err
}

// Send writes command to the socket prefixed with its length encoded as four
// hexadecimal digits.
//
// Commands longer than 0xFFFF bytes cannot be described by the four-digit
// header and are rejected with an error before anything is written.
func (t transport) Send(command string) (err error) {
	const maxCommandLen = 0xFFFF
	if len(command) > maxCommandLen {
		return errors.Errorf("command too long: %d bytes exceeds the %d byte limit",
			len(command), maxCommandLen)
	}
	msg := fmt.Sprintf("%04x%s", len(command), command)
	return _send(t.sock, []byte(msg))
}

// SendBytes writes b to the socket as-is, without a length prefix.
func (t transport) SendBytes(b []byte) (err error) {
	return _send(t.sock, b)
}

// Conn returns the underlying socket connection.
func (t transport) Conn() net.Conn {
	return t.sock
}

// VerifyResponse reads a 4-byte status from the socket and returns nil when it
// is "OKAY". For any other status it reads the length-prefixed message that
// follows, logs it as a warning and returns it as an error.
func (t transport) VerifyResponse() (err error) {
	var status string
	if status, err = t.ReadStringN(4); err != nil {
		return err
	}
	if status == "OKAY" {
		return nil
	}

	var sError string
	if sError, err = t.UnpackString(); err != nil {
		return err
	}
	log.Warn().Str("status", status).Str("err", sError).
		Msg("verify adb response failed")
	return errors.New(sError)
}

// ReadStringAll is ReadBytesAll with the result converted to a string.
func (t transport) ReadStringAll() (s string, err error) {
	var raw []byte
	raw, err = t.ReadBytesAll()
	return string(raw), err
}

// ReadBytesAll reads from the socket until EOF or an error occurs.
// No read deadline is armed by this method.
func (t transport) ReadBytesAll() (raw []byte, err error) {
	return io.ReadAll(t.sock)
}

// UnpackString is UnpackBytes with the result converted to a string.
func (t transport) UnpackString() (s string, err error) {
	var raw []byte
	raw, err = t.UnpackBytes()
	return string(raw), err
}

// UnpackBytes reads one length-prefixed message: a 4-hex-digit length header
// followed by that many payload bytes. A header that is not an unsigned
// hexadecimal number yields an error.
func (t transport) UnpackBytes() (raw []byte, err error) {
	var length string
	if length, err = t.ReadStringN(4); err != nil {
		return nil, err
	}

	// The header is exactly four hex digits, so it always fits in 16 bits.
	// Parsing as unsigned rejects signed headers such as "-001", which would
	// otherwise become a negative read size.
	var size uint64
	if size, err = strconv.ParseUint(length, 16, 16); err != nil {
		return nil, err
	}
	return t.ReadBytesN(int(size))
}

// ReadStringN reads exactly size bytes and returns them as a string.
func (t transport) ReadStringN(size int) (s string, err error) {
	var raw []byte
	if raw, err = t.ReadBytesN(size); err != nil {
		return "", err
	}
	return string(raw), nil
}

// ReadBytesN reads exactly size bytes from the socket, after arming a read
// deadline of readTimeout seconds from now.
func (t transport) ReadBytesN(size int) (raw []byte, err error) {
	// Best effort: a failure to set the deadline is deliberately ignored and
	// the read proceeds without one.
	_ = t.sock.SetReadDeadline(time.Now().Add(time.Second * t.readTimeout))
	return _readN(t.sock, size)
}

// Close closes the underlying socket. It is a no-op when the transport has no
// socket (for example after a failed dial).
func (t transport) Close() (err error) {
	if t.sock == nil {
		return nil
	}
	// DisableTimeWait is defined elsewhere in the package and only applies to
	// TCP connections; skip it for any other net.Conn instead of panicking on
	// the type assertion. Its error is ignored on purpose (best effort).
	if tcpConn, ok := t.sock.(*net.TCPConn); ok {
		_ = DisableTimeWait(tcpConn)
	}
	return t.sock.Close()
}

// SendWithCheck sends command and then verifies the response status.
func (t transport) SendWithCheck(command string) (err error) {
	if err = t.Send(command); err != nil {
		return err
	}
	return t.VerifyResponse()
}

// CreateSyncTransport sends the "sync:" command and, once it is acknowledged,
// returns a syncTransport that shares this transport's socket and read timeout.
func (t transport) CreateSyncTransport() (sTp syncTransport, err error) {
	if err = t.SendWithCheck("sync:"); err != nil {
		return syncTransport{}, err
	}
	return newSyncTransport(t.sock, t.readTimeout), nil
}

// _send writes all of msg to writer, retrying on short writes. It returns
// ErrConnBroken when a write reports zero bytes without an error.
func _send(writer io.Writer, msg []byte) (err error) {
	for totalSent := 0; totalSent < len(msg); {
		var sent int
		if sent, err = writer.Write(msg[totalSent:]); err != nil {
			return err
		}
		if sent == 0 {
			return ErrConnBroken
		}
		totalSent += sent
	}
	return nil
}

// _readN reads exactly size bytes from reader.
//
// It returns an error for a negative size, io.EOF when nothing could be read
// and io.ErrUnexpectedEOF when the stream ends early (both from io.ReadFull).
// A size of zero returns an empty slice without reading.
func _readN(reader io.Reader, size int) (raw []byte, err error) {
	if size < 0 {
		return nil, errors.Errorf("invalid read size: %d", size)
	}
	raw = make([]byte, size)
	if _, err = io.ReadFull(reader, raw); err != nil {
		return nil, err
	}
	return raw, nil
}