nezha/cmd/agent/main.go

291 lines
6.6 KiB
Go
Raw Normal View History

2019-12-05 09:36:58 -05:00
package main
import (
"context"
2021-01-16 01:11:51 -05:00
"crypto/tls"
"errors"
"flag"
2019-12-07 05:14:40 -05:00
"fmt"
2019-12-05 09:36:58 -05:00
"log"
"net"
"net/http"
2019-12-08 10:49:38 -05:00
"os"
"os/exec"
"strings"
2019-12-07 05:14:40 -05:00
"time"
2019-12-05 09:36:58 -05:00
2020-11-29 11:07:27 -05:00
"github.com/blang/semver"
"github.com/genkiroid/cert"
"github.com/go-ping/ping"
2020-12-12 12:31:22 -05:00
"github.com/p14yground/go-github-selfupdate/selfupdate"
2019-12-05 09:36:58 -05:00
"google.golang.org/grpc"
2019-12-05 10:42:20 -05:00
2021-03-20 02:53:10 -04:00
"github.com/naiba/nezha/cmd/agent/monitor"
2020-11-10 21:07:45 -05:00
"github.com/naiba/nezha/model"
"github.com/naiba/nezha/pkg/utils"
2020-11-10 21:07:45 -05:00
pb "github.com/naiba/nezha/proto"
"github.com/naiba/nezha/service/dao"
"github.com/naiba/nezha/service/rpc"
2019-12-05 09:36:58 -05:00
)
2019-12-08 10:49:38 -05:00
// Settings filled in at startup: server and clientSecret come from the -s and
// -p flags parsed in main, and version is supplied at build time (main notes
// it comes from GoReleaser). run skips the self-update loop when version is empty.
var server string
var clientSecret string
var version string
2020-11-30 01:24:00 -05:00
var (
client pb.NezhaServiceClient
ctx = context.Background()
delayWhenError = time.Second * 10 // Agent 重连间隔
updateCh = make(chan struct{}) // Agent 自动更新间隔
httpClient = &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
},
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
2020-11-30 01:24:00 -05:00
)
2020-11-29 11:07:27 -05:00
func doSelfUpdate() {
defer func() {
2020-11-30 00:49:31 -05:00
time.Sleep(time.Minute * 20)
2020-11-29 11:07:27 -05:00
updateCh <- struct{}{}
}()
v := semver.MustParse(version)
log.Println("Check update", v)
2020-11-29 11:07:27 -05:00
latest, err := selfupdate.UpdateSelf(v, "naiba/nezha")
if err != nil {
log.Println("Binary update failed:", err)
return
}
if latest.Version.Equals(v) {
// latest version is the same as current version. It means current binary is up to date.
log.Println("Current binary is the latest version", version)
} else {
log.Println("Successfully updated to version", latest.Version)
2020-12-20 00:03:03 -05:00
os.Exit(1)
2020-11-29 11:07:27 -05:00
}
}
// init sets cert.TimeoutSeconds (seconds, per the field name) for the cert
// package, which doTask uses to check the SSL certificate of https targets.
func init() {
	const certTimeoutSeconds = 30
	cert.TimeoutSeconds = certTimeoutSeconds
}
2019-12-05 09:36:58 -05:00
func main() {
2020-11-29 23:45:51 -05:00
// 来自于 GoReleaser 的版本号
dao.Version = version
2019-12-08 10:49:38 -05:00
var debug bool
flag.String("i", "", "unused 旧Agent配置兼容")
flag.BoolVar(&debug, "d", false, "允许不安全连接")
flag.StringVar(&server, "s", "localhost:5555", "管理面板RPC端口")
flag.StringVar(&clientSecret, "p", "", "Agent连接Secret")
flag.Parse()
2019-12-09 03:02:49 -05:00
dao.Conf = &model.Config{
Debug: debug,
2019-12-05 09:36:58 -05:00
}
if server == "" || clientSecret == "" {
flag.Usage()
return
}
run()
}
func run() {
2019-12-09 03:02:49 -05:00
auth := rpc.AuthHandler{
ClientSecret: clientSecret,
2019-12-05 09:36:58 -05:00
}
2019-12-10 04:57:57 -05:00
// 上报服务器信息
go reportState()
2021-03-20 11:50:16 -04:00
// 更新IP信息
go monitor.UpdateIP()
2019-12-10 04:57:57 -05:00
if version != "" {
go func() {
for range updateCh {
go doSelfUpdate()
}
}()
updateCh <- struct{}{}
}
2020-11-29 11:07:27 -05:00
2019-12-10 04:57:57 -05:00
var err error
var conn *grpc.ClientConn
2020-02-09 09:04:59 -05:00
retry := func() {
log.Println("Error to close connection ...")
if conn != nil {
conn.Close()
}
time.Sleep(delayWhenError)
log.Println("Try to reconnect ...")
}
2019-12-09 05:14:31 -05:00
for {
2019-12-09 10:45:23 -05:00
conn, err = grpc.Dial(server, grpc.WithInsecure(), grpc.WithPerRPCCredentials(&auth))
2019-12-09 03:02:49 -05:00
if err != nil {
log.Printf("grpc.Dial err: %v", err)
2019-12-09 05:14:31 -05:00
retry()
2019-12-09 03:02:49 -05:00
continue
}
client = pb.NewNezhaServiceClient(conn)
// 第一步注册
_, err = client.ReportSystemInfo(ctx, monitor.GetHost().PB())
2019-12-09 05:14:31 -05:00
if err != nil {
log.Printf("client.ReportSystemInfo err: %v", err)
2019-12-09 05:14:31 -05:00
retry()
continue
}
// 执行 Task
tasks, err := client.RequestTask(ctx, monitor.GetHost().PB())
2019-12-09 03:02:49 -05:00
if err != nil {
log.Printf("client.RequestTask err: %v", err)
2019-12-09 05:14:31 -05:00
retry()
2019-12-09 03:02:49 -05:00
continue
}
err = receiveTasks(tasks)
log.Printf("receiveTasks exit to main: %v", err)
2019-12-09 05:14:31 -05:00
retry()
2019-12-07 05:14:40 -05:00
}
2019-12-09 03:02:49 -05:00
}
2019-12-07 05:14:40 -05:00
func receiveTasks(tasks pb.NezhaService_RequestTaskClient) error {
2019-12-09 03:02:49 -05:00
var err error
defer log.Printf("receiveTasks exit %v => %v", time.Now(), err)
2019-12-09 03:02:49 -05:00
for {
var task *pb.Task
task, err = tasks.Recv()
2019-12-09 03:02:49 -05:00
if err != nil {
return err
}
go doTask(task)
}
}
func doTask(task *pb.Task) {
var result pb.TaskResult
result.Id = task.GetId()
result.Type = task.GetType()
switch task.GetType() {
case model.TaskTypeHTTPGET:
start := time.Now()
resp, err := httpClient.Get(task.GetData())
if err == nil {
// 检查 HTTP Response 状态
result.Delay = float32(time.Since(start).Microseconds()) / 1000.0
if resp.StatusCode > 399 || resp.StatusCode < 200 {
err = errors.New("\n应用错误" + resp.Status)
}
}
if err == nil {
// 检查 SSL 证书信息
if strings.HasPrefix(task.GetData(), "https://") {
c := cert.NewCert(task.GetData()[8:])
if c.Error != "" {
result.Data = "SSL证书错误" + c.Error
} else {
result.Data = c.Issuer + "|" + c.NotAfter
result.Successful = true
}
} else {
result.Successful = true
}
} else {
// HTTP 请求失败
result.Data = err.Error()
}
case model.TaskTypeICMPPing:
pinger, err := ping.NewPinger(task.GetData())
if err == nil {
2021-01-18 00:45:06 -05:00
pinger.SetPrivileged(true)
pinger.Count = 10
pinger.Timeout = time.Second * 20
err = pinger.Run() // Blocks until finished.
}
if err == nil {
result.Delay = float32(pinger.Statistics().AvgRtt.Microseconds()) / 1000.0
result.Successful = true
} else {
result.Data = err.Error()
}
case model.TaskTypeTCPPing:
start := time.Now()
conn, err := net.DialTimeout("tcp", task.GetData(), time.Second*10)
if err == nil {
conn.Write([]byte("ping\n"))
conn.Close()
result.Delay = float32(time.Since(start).Microseconds()) / 1000.0
result.Successful = true
} else {
result.Data = err.Error()
2019-12-09 03:02:49 -05:00
}
case model.TaskTypeCommand:
startedAt := time.Now()
var cmd *exec.Cmd
2021-01-28 21:40:57 -05:00
var endCh = make(chan struct{})
2021-01-28 22:59:35 -05:00
pg, err := utils.NewProcessExitGroup()
if err != nil {
// 进程组创建失败,直接退出
result.Data = err.Error()
client.ReportTask(ctx, &result)
return
}
2021-01-28 21:40:57 -05:00
timeout := time.NewTimer(time.Hour * 2)
if utils.IsWindows() {
cmd = exec.Command("cmd", "/c", task.GetData())
} else {
cmd = exec.Command("sh", "-c", task.GetData())
}
2021-01-28 22:59:35 -05:00
pg.AddProcess(cmd)
2021-01-28 21:40:57 -05:00
go func() {
select {
case <-timeout.C:
result.Data = "任务执行超时\n"
close(endCh)
2021-01-29 01:29:31 -05:00
pg.Dispose()
2021-01-28 21:40:57 -05:00
case <-endCh:
2021-01-29 01:29:31 -05:00
timeout.Stop()
}
2021-01-28 21:40:57 -05:00
}()
output, err := cmd.Output()
if err != nil {
result.Data += fmt.Sprintf("%s\n%s", string(output), err.Error())
} else {
close(endCh)
result.Data = string(output)
result.Successful = true
}
result.Delay = float32(time.Since(startedAt).Seconds())
default:
log.Printf("Unknown action: %v", task)
2019-12-07 05:14:40 -05:00
}
client.ReportTask(ctx, &result)
2019-12-09 03:02:49 -05:00
}
2019-12-05 09:36:58 -05:00
2019-12-09 03:02:49 -05:00
func reportState() {
var lastReportHostInfo time.Time
2019-12-09 03:02:49 -05:00
var err error
defer log.Printf("reportState exit %v => %v", time.Now(), err)
2019-12-09 03:02:49 -05:00
for {
if client != nil {
2019-12-09 10:45:23 -05:00
monitor.TrackNetworkSpeed()
_, err = client.ReportSystemState(ctx, monitor.GetState(dao.ReportDelay).PB())
2019-12-09 05:14:31 -05:00
if err != nil {
log.Printf("reportState error %v", err)
time.Sleep(delayWhenError)
}
if lastReportHostInfo.Before(time.Now().Add(-10 * time.Minute)) {
lastReportHostInfo = time.Now()
client.ReportSystemInfo(ctx, monitor.GetHost().PB())
}
2019-12-05 10:42:20 -05:00
}
}
2019-12-05 09:36:58 -05:00
}