diff --git a/server/client/request.go b/server/client/request.go index f5f19967..9489404c 100644 --- a/server/client/request.go +++ b/server/client/request.go @@ -145,6 +145,26 @@ func WatchDirect(ctx context.Context, grpccon *grpc.ClientConn, wshost, appid, r return connectToRoom(ctx, accinfo, res, warn) } +// WatchByNumber : 部屋番号で観戦入室 +func WatchByNumber(ctx context.Context, accinfo *AccessInfo, number int32, query *Query, warn func(error)) (*Room, *Connection, error) { + var q []lobby.PropQueries + if query != nil { + q = []lobby.PropQueries(*query) + } + param := lobby.JoinParam{ + Queries: q, + ClientInfo: &pb.ClientInfo{Id: accinfo.UserId}, + EncMACKey: accinfo.EncMACKey, + } + + res, err := lobbyRequest(ctx, accinfo, fmt.Sprintf("/rooms/watch/number/%d", number), param) + if err != nil { + return nil, nil, xerrors.Errorf("lobbyRequest: %w", err) + } + + return connectToRoom(ctx, accinfo, res.Room, warn) +} + // Search : 部屋を検索する func Search(ctx context.Context, accinfo *AccessInfo, param *lobby.SearchParam) ([]*pb.RoomInfo, error) { res, err := lobbyRequest(ctx, accinfo, "/rooms/search", param) diff --git a/server/cmd/wsnet2-bot/cmd/root.go b/server/cmd/wsnet2-bot/cmd/root.go index 89a7ed3d..d0ca1abf 100644 --- a/server/cmd/wsnet2-bot/cmd/root.go +++ b/server/cmd/wsnet2-bot/cmd/root.go @@ -174,6 +174,22 @@ func joinRoom(ctx context.Context, player, roomId string, query *client.Query) ( return client.Join(ctx, accinfo, roomId, query, cinfo, nil) } +// joinByNumber joins the player to a room specified by the number +func joinByNumber(ctx context.Context, player string, number int32, query *client.Query) (*client.Room, *client.Connection, error) { + accinfo, err := client.GenAccessInfo(lobbyURL, appId, appKey, player) + if err != nil { + return nil, nil, err + } + + if query == nil { + query = client.NewQuery() + } + + cinfo := &pb.ClientInfo{Id: player} + + return client.JoinByNumber(ctx, accinfo, number, query, cinfo, nil) +} + // joinRandom joins the player to a room randomly func joinRandom(ctx context.Context, player string, group uint32, query *client.Query) (*client.Room, *client.Connection, error) { accinfo, err := client.GenAccessInfo(lobbyURL, appId, appKey, player) @@ -200,6 +216,20 @@ func watchRoom(ctx context.Context, watcher, roomId string, query *client.Query) return client.Watch(ctx, accinfo, roomId, query, nil) } +// watchByNumber joins the watcher to a room specified by the number +func watchByNumber(ctx context.Context, watcher string, number int32, query *client.Query) (*client.Room, *client.Connection, error) { + accinfo, err := client.GenAccessInfo(lobbyURL, appId, appKey, watcher) + if err != nil { + return nil, nil, err + } + + if query == nil { + query = client.NewQuery() + } + + return client.WatchByNumber(ctx, accinfo, number, query, nil) +} + // searchCurrent search current rooms func searchCurrent(ctx context.Context, cid string) ([]*pb.RoomInfo, error) { accinfo, err := client.GenAccessInfo(lobbyURL, appId, appKey, cid) diff --git a/server/cmd/wsnet2-bot/cmd/scenario.go b/server/cmd/wsnet2-bot/cmd/scenario.go index 0a422a29..1da87440 100644 --- a/server/cmd/wsnet2-bot/cmd/scenario.go +++ b/server/cmd/wsnet2-bot/cmd/scenario.go @@ -257,7 +257,7 @@ func scenarioJoinRoom(ctx context.Context) error { clearEventBuffer(conn) // 正常入室 - _, p2, err := joinRoom(ctx, "joinroom_player2", room.Id, nil) + _, p2, err := joinByNumber(ctx, "joinroom_player2", *room.Number, nil) if err != nil { return fmt.Errorf("join-room: player2: %w", err) } @@ -286,6 +286,14 @@ func scenarioJoinRoom(ctx context.Context) error { discardEvents(w1) defer cleanupConn(ctx, w1) + _, w2, err := watchByNumber(ctx, "joinroom_watcher2", *room.Number, nil) + if err != nil { + return fmt.Errorf("join-room: watcher2: %w", err) + } + logger.Infof("join-room: watcher2 ok") + discardEvents(w2) + defer cleanupConn(ctx, w2) + clearEventBuffer(conn) // MaxPlayerを+2増やしwatchable=falseに @@ -310,12 +318,12 @@ func scenarioJoinRoom(ctx context.Context) error { defer cleanupConn(ctx, p4) // 観戦はエラー - _, w2, err := watchRoom(ctx, "joinroom_watcher2", room.Id, nil) + _, w3, err := watchRoom(ctx, "joinroom_watcher3", room.Id, nil) if !errors.Is(err, client.ErrNoRoomFound) { - cleanupConn(ctx, w2) - return fmt.Errorf("join-room: watcher2 wants NoRoomFound: %v", err) + cleanupConn(ctx, w3) + return fmt.Errorf("join-room: watcher3 wants NoRoomFound: %v", err) } - logger.Infof("join-room: watcher2 ok (no room found)") + logger.Infof("join-room: watcher3 ok (no room found)") clearEventBuffer(conn) diff --git a/server/game/repository.go b/server/game/repository.go index 3c692b69..c39b1e2d 100644 --- a/server/game/repository.go +++ b/server/game/repository.go @@ -10,6 +10,7 @@ import ( "math/big" "math/rand" "reflect" + "regexp" "strings" "sync" "time" @@ -26,6 +27,8 @@ import ( const ( // RoomID文字列長 lenId = 16 + + idPattern = "^[0-9a-f]+$" ) var ( @@ -34,12 +37,16 @@ var ( roomHistoryInsertQuery string randsrc *rand.Rand + + rerid *regexp.Regexp ) func init() { initQueries() seed, _ := crand.Int(crand.Reader, big.NewInt(math.MaxInt64)) randsrc = rand.New(rand.NewSource(seed.Int64())) + + rerid = regexp.MustCompile(idPattern) } func dbCols(t reflect.Type) []string { @@ -82,6 +89,10 @@ func RandomHex(n int) string { return hex.EncodeToString(b) } +func IsValidRoomId(id string) bool { + return rerid.Match([]byte(id)) +} + type Repository struct { hostId uint32 diff --git a/server/game/repository_test.go b/server/game/repository_test.go index 627feddb..778e5c71 100644 --- a/server/game/repository_test.go +++ b/server/game/repository_test.go @@ -37,6 +37,20 @@ func TestQueries(t *testing.T) { } } +func TestIsValidRoomId(t *testing.T) { + tests := map[string]bool{ + "123456789abcdef": true, + "123456789ABCDEF": false, + "": false, + } + + for id, valid := range tests { + if IsValidRoomId(id) != valid { + t.Errorf("IsValidRoomId(%v) wants %v", id, valid) + } + } +} + func newDbMock(t *testing.T) (*sqlx.DB, sqlmock.Sqlmock) { db, mock, err := sqlmock.New() if err != nil { diff --git a/server/game/service/websocket.go b/server/game/service/websocket.go index 32101795..4cf724ce 100644 --- a/server/game/service/websocket.go +++ b/server/game/service/websocket.go @@ -11,7 +11,6 @@ import ( "strings" "time" - "github.com/go-chi/chi/v5" "github.com/shiguredo/websocket" "golang.org/x/xerrors" @@ -68,8 +67,8 @@ func (sv *GameService) serveWebSocket(ctx context.Context) <-chan error { } ws := &WSHandler{sv} - r := chi.NewMux() - r.Get("/room/{id:[0-9a-f]+}", ws.HandleRoom) + r := http.NewServeMux() + r.HandleFunc("GET /room/{id}", ws.HandleRoom) sv.wsURLFormat = fmt.Sprintf("%s://%s:%d/room/%%s", scheme, sv.conf.PublicName, sv.conf.WebsocketPort) @@ -87,7 +86,12 @@ func (sv *GameService) serveWebSocket(ctx context.Context) <-chan error { } func (s *WSHandler) HandleRoom(w http.ResponseWriter, r *http.Request) { - roomId := chi.URLParam(r, "id") + roomId := r.PathValue("id") + if !game.IsValidRoomId(roomId) { + http.Error(w, "Not Found", http.StatusNotFound) + return + } + appId := r.Header.Get("Wsnet2-App") clientId := r.Header.Get("Wsnet2-User") logger := log.GetLoggerWith( diff --git a/server/go.mod b/server/go.mod index 0b381ae6..76ea7bfc 100644 --- a/server/go.mod +++ b/server/go.mod @@ -4,7 +4,6 @@ go 1.22.0 require ( github.com/DATA-DOG/go-sqlmock v1.5.2 - github.com/go-chi/chi/v5 v5.0.12 github.com/go-sql-driver/mysql v1.7.1 github.com/google/go-cmp v0.6.0 github.com/jmoiron/sqlx v1.3.5 diff --git a/server/go.sum b/server/go.sum index 50dab919..13e4a0be 100644 --- a/server/go.sum +++ b/server/go.sum @@ -3,8 +3,6 @@ github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/go-chi/chi/v5 v5.0.12 h1:9euLV5sTrTNTRUU9POmDUvfxyj6LAABLUcEWO+JJb4s= -github.com/go-chi/chi/v5 v5.0.12/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= diff --git a/server/hub/service/websocket.go b/server/hub/service/websocket.go index 356366a2..ac7b1941 100644 --- a/server/hub/service/websocket.go +++ b/server/hub/service/websocket.go @@ -11,7 +11,6 @@ import ( "strings" "time" - "github.com/go-chi/chi/v5" "github.com/shiguredo/websocket" "golang.org/x/xerrors" @@ -68,8 +67,8 @@ func (sv *HubService) serveWebSocket(ctx context.Context) <-chan error { } ws := &WSHandler{sv} - r := chi.NewMux() - r.Get("/room/{id:[0-9a-f]+}", ws.HandleRoom) + r := http.NewServeMux() + r.HandleFunc("GET /room/{id}", ws.HandleRoom) sv.wsURLFormat = fmt.Sprintf("%s://%s:%d/room/%%s", scheme, sv.conf.PublicName, sv.conf.WebsocketPort) @@ -87,7 +86,12 @@ func (sv *HubService) serveWebSocket(ctx context.Context) <-chan error { } func (s *WSHandler) HandleRoom(w http.ResponseWriter, r *http.Request) { - roomId := chi.URLParam(r, "id") + roomId := r.PathValue("id") + if !game.IsValidRoomId(roomId) { + http.Error(w, "Not Found", http.StatusNotFound) + return + } + appId := r.Header.Get("Wsnet2-App") clientId := r.Header.Get("Wsnet2-User") logger := log.GetLoggerWith( diff --git a/server/lobby/service/api.go b/server/lobby/service/api.go index 5e756256..320358ce 100644 --- a/server/lobby/service/api.go +++ b/server/lobby/service/api.go @@ -9,11 +9,11 @@ import ( "net" "net/http" "os" + "regexp" "strconv" "strings" "time" - "github.com/go-chi/chi/v5" "github.com/vmihailenco/msgpack/v5" "golang.org/x/xerrors" @@ -61,7 +61,7 @@ func (sv *LobbyService) serveAPI(ctx context.Context) <-chan error { } } - r := chi.NewMux() + r := http.NewServeMux() sv.registerRoutes(r) errCh <- http.Serve(listener, r) @@ -75,21 +75,21 @@ func handleHealth(w http.ResponseWriter, r *http.Request) { w.Write([]byte("wsnet2 works\n")) } -func (sv *LobbyService) registerRoutes(r chi.Router) { - r.Get("/health", handleHealth) - r.Get("/health/", handleHealth) - - r.Post("/rooms", sv.handleCreateRoom) - r.Post("/rooms/join/id/{roomId}", sv.handleJoinRoom) - r.Post("/rooms/join/number/{roomNumber:[0-9]+}", sv.handleJoinRoomByNumber) - r.Post("/rooms/join/random/{searchGroup:[0-9]+}", sv.handleJoinRoomAtRandom) - r.Post("/rooms/search", sv.handleSearchRooms) - r.Post("/rooms/search/ids", sv.handleSearchByIds) - r.Post("/rooms/search/numbers", sv.handleSearchByNumbers) - r.Post("/rooms/search/current", sv.handleSearchCurrentRooms) - r.Post("/rooms/watch/id/{roomId}", sv.handleWatchRoom) - r.Post("/rooms/watch/number/{roomNumber:[0-9]+}", sv.handleWatchRoomByNumber) - r.Post("/_admin/kick", sv.handleAdminKick) +func (sv *LobbyService) registerRoutes(r *http.ServeMux) { + r.HandleFunc("GET /health", handleHealth) + r.HandleFunc("GET /health/{$}", handleHealth) + + r.HandleFunc("POST /rooms", sv.handleCreateRoom) + r.HandleFunc("POST /rooms/join/id/{roomId}", sv.handleJoinRoom) + r.HandleFunc("POST /rooms/join/number/{roomNumber}", sv.handleJoinRoomByNumber) + r.HandleFunc("POST /rooms/join/random/{searchGroup}", sv.handleJoinRoomAtRandom) + r.HandleFunc("POST /rooms/search", sv.handleSearchRooms) + r.HandleFunc("POST /rooms/search/ids", sv.handleSearchByIds) + r.HandleFunc("POST /rooms/search/numbers", sv.handleSearchByNumbers) + r.HandleFunc("POST /rooms/search/current", sv.handleSearchCurrentRooms) + r.HandleFunc("POST /rooms/watch/id/{roomId}", sv.handleWatchRoom) + r.HandleFunc("POST /rooms/watch/number/{roomNumber}", sv.handleWatchRoomByNumber) + r.HandleFunc("POST /_admin/kick", sv.handleAdminKick) } type header struct { @@ -126,7 +126,9 @@ func prepareLogger(handler string, hdr header, r *http.Request) log.Logger { log.KeyRequestedAt, float64(time.Now().UnixMilli())/1000, log.KeyApp, hdr.appId, log.KeyClient, hdr.userId, - log.KeyRemoteAddr, raddr) + log.KeyRemoteAddr, raddr, + log.KeyPath, r.URL.Path, + ) if err != nil { l.Errorf("SplitHostPort: %v", err) } @@ -245,37 +247,33 @@ func (sv *LobbyService) handleCreateRoom(w http.ResponseWriter, r *http.Request) renderJoinedRoomResponse(w, room, logger) } +var ( + idRegexp = regexp.MustCompile("^[0-9a-f]+$") +) + type JoinVars struct { - ctx *chi.Context + r *http.Request } -func NewJoinVars(r *http.Request) *JoinVars { - return &JoinVars{ - ctx: chi.RouteContext(r.Context()), - } +func NewJoinVars(r *http.Request) JoinVars { + return JoinVars{r: r} } -func (vars JoinVars) roomId() string { - id := vars.ctx.URLParam("roomId") - return id +func (vars JoinVars) roomId() (string, bool) { + id := vars.r.PathValue("roomId") + return id, idRegexp.MatchString(id) } -func (vars JoinVars) roomNumber() (number int32) { - v := vars.ctx.URLParam("roomNumber") - if v != "" { - n, _ := strconv.ParseInt(v, 10, 32) - number = int32(n) - } - return number +func (vars JoinVars) roomNumber() (int32, bool) { + v := vars.r.PathValue("roomNumber") + n, err := strconv.ParseInt(v, 10, 32) + return int32(n), err == nil && n > 0 } -func (vars JoinVars) searchGroup() (sg uint32) { - v := vars.ctx.URLParam("searchGroup") - if v != "" { - n, _ := strconv.ParseInt(v, 10, 32) - sg = uint32(n) - } - return sg +func (vars JoinVars) searchGroup() (uint32, bool) { + v := vars.r.PathValue("searchGroup") + n, err := strconv.ParseInt(v, 10, 32) + return uint32(n), err == nil && n >= 0 } func (sv *LobbyService) handleJoinRoom(w http.ResponseWriter, r *http.Request) { @@ -306,8 +304,8 @@ func (sv *LobbyService) handleJoinRoom(w http.ResponseWriter, r *http.Request) { } vars := NewJoinVars(r) - roomId := vars.roomId() - if roomId == "" { + roomId, ok := vars.roomId() + if !ok { renderErrorResponse( w, "Invalid room id", http.StatusBadRequest, xerrors.Errorf("Invalid room id"), logger) return @@ -351,10 +349,10 @@ func (sv *LobbyService) handleJoinRoomByNumber(w http.ResponseWriter, r *http.Re } vars := NewJoinVars(r) - roomNumber := vars.roomNumber() - if roomNumber == 0 { + roomNumber, ok := vars.roomNumber() + if !ok { renderErrorResponse( - w, "Invalid room number", http.StatusBadRequest, xerrors.Errorf("Invalid room number: 0"), logger) + w, "Invalid room number", http.StatusBadRequest, xerrors.Errorf("Invalid room number"), logger) return } logger = logger.With(log.KeyRoomNumber, roomNumber) @@ -396,7 +394,13 @@ func (sv *LobbyService) handleJoinRoomAtRandom(w http.ResponseWriter, r *http.Re } vars := NewJoinVars(r) - searchGroup := vars.searchGroup() + searchGroup, ok := vars.searchGroup() + if !ok { + renderErrorResponse( + w, "Failed to join room", http.StatusBadRequest, xerrors.Errorf("Invalid search group"), logger) + return + } + logger = logger.With(log.KeySearchGroup, searchGroup) room, err := sv.roomService.JoinAtRandom(ctx, h.appId, searchGroup, param.Queries, param.ClientInfo, macKey, logger) @@ -552,8 +556,8 @@ func (sv *LobbyService) handleWatchRoom(w http.ResponseWriter, r *http.Request) } vars := NewJoinVars(r) - roomId := vars.roomId() - if roomId == "" { + roomId, ok := vars.roomId() + if !ok { renderErrorResponse( w, "Invalid room id", http.StatusBadRequest, xerrors.Errorf("Invalid room id"), logger) return @@ -597,10 +601,10 @@ func (sv *LobbyService) handleWatchRoomByNumber(w http.ResponseWriter, r *http.R } vars := NewJoinVars(r) - roomNumber := vars.roomNumber() - if roomNumber == 0 { + roomNumber, ok := vars.roomNumber() + if !ok { renderErrorResponse( - w, "Invalid room number", http.StatusBadRequest, xerrors.Errorf("Invalid room number: 0"), logger) + w, "Invalid room number", http.StatusBadRequest, xerrors.Errorf("Invalid room number"), logger) return } logger = logger.With(log.KeyRoomNumber, roomNumber) diff --git a/server/log/log.go b/server/log/log.go index e4b7fdc3..94cb9212 100644 --- a/server/log/log.go +++ b/server/log/log.go @@ -54,6 +54,8 @@ const ( KeyRemoteAddr = "remoteAddr" // Requested at (unix timestamp, float64) KeyRequestedAt = "requestedAt" + // URL Path + KeyPath = "path" // Room ID KeyRoom = "room" // Room count