diff --git a/cmd/dashboard/controller/user.go b/cmd/dashboard/controller/user.go index b082eaa..16eb94f 100644 --- a/cmd/dashboard/controller/user.go +++ b/cmd/dashboard/controller/user.go @@ -8,6 +8,7 @@ import ( "golang.org/x/crypto/bcrypt" "github.com/nezhahq/nezha/model" + "github.com/nezhahq/nezha/pkg/utils" "github.com/nezhahq/nezha/service/singleton" ) @@ -208,7 +209,7 @@ func batchBlockOnlineUser(c *gin.Context) (any, error) { return nil, err } - if err := singleton.BlockByIPs(list); err != nil { + if err := singleton.BlockByIPs(utils.Unique(list)); err != nil { return nil, newGormError("%v", err) } diff --git a/cmd/dashboard/controller/waf.go b/cmd/dashboard/controller/waf.go index a7dd375..96c3349 100644 --- a/cmd/dashboard/controller/waf.go +++ b/cmd/dashboard/controller/waf.go @@ -6,6 +6,7 @@ import ( "github.com/gin-gonic/gin" "github.com/nezhahq/nezha/model" + "github.com/nezhahq/nezha/pkg/utils" "github.com/nezhahq/nezha/service/singleton" ) @@ -68,7 +69,7 @@ func batchDeleteBlockedAddress(c *gin.Context) (any, error) { return nil, err } - if err := model.BatchClearIP(singleton.DB, list); err != nil { + if err := model.BatchClearIP(singleton.DB, utils.Unique(list)); err != nil { return nil, newGormError("%v", err) } diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index d3efd5a..98ab9bf 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -152,3 +152,15 @@ func MapValuesToSlice[Map ~map[K]V, K comparable, V any](m Map) []V { s := make([]V, 0, len(m)) return slices.AppendSeq(s, maps.Values(m)) } + +func Unique[T comparable](s []T) []T { + m := make(map[T]struct{}) + ret := make([]T, 0, len(s)) + for _, v := range s { + if _, ok := m[v]; !ok { + m[v] = struct{}{} + ret = append(ret, v) + } + } + return ret +} diff --git a/pkg/utils/utils_test.go b/pkg/utils/utils_test.go index 34c6f8b..d9ee1ba 100644 --- a/pkg/utils/utils_test.go +++ b/pkg/utils/utils_test.go @@ -134,3 +134,49 @@ func TestBinaryToIPString(t *testing.T) { } } } + +func TestUnique(t *testing.T) { + cases := []struct { + input []string + output []string + }{ + { + input: []string{"a", "b", "c", "a", "b", "c"}, + output: []string{"a", "b", "c"}, + }, + { + input: []string{"a", "b", "c"}, + output: []string{"a", "b", "c"}, + }, + { + input: []string{"a", "a", "a"}, + output: []string{"a"}, + }, + { + input: []string{"a", "b", "c", "d", "e", "f", "g", "h", "i", "j"}, + output: []string{"a", "b", "c", "d", "e", "f", "g", "h", "i", "j"}, + }, + { + input: []string{"a", "b", "c", "d", "e", "f", "g", "h", "i", "a"}, + output: []string{"a", "b", "c", "d", "e", "f", "g", "h", "i"}, + }, + { + input: []string{"a", "b", "c", "d", "e", "f", "g", "h", "i", "a", "b", "c", "d", "e", "f", "g", "h", "i"}, + output: []string{"a", "b", "c", "d", "e", "f", "g", "h", "i"}, + }, + { + input: []string{}, + output: []string{}, + }, + { + input: []string{"a"}, + output: []string{"a"}, + }, + } + + for _, c := range cases { + if !reflect.DeepEqual(Unique(c.input), c.output) { + t.Fatalf("Expected %v, but got %v", c.output, Unique(c.input)) + } + } +}