Skip to content

Commit

Permalink
Use http.ServeMux instead of chi.Mux (#86)
Browse files Browse the repository at this point in the history
Go 1.22から、標準net/httpでもHTTPメソッドの指定とパスパラメータの抽出ができるようになりました。

今までchi(それ以前はgorilla/mux)を使ってやっていましたが、wsnet2ではそんなに高度なことはやっていないので、標準net/httpに書き換えて依存を減らしてみました。

ただ、パスパラメータ部分の正規表現マッチングはnet/httpではできないので、ハンドラ側でチェックする形に変更しました。

これによりマッチしないときのレスポンスが404から400に変わるケースがありますが、正しいクライアントならそもそも起こらないので問題ないはずです。
  • Loading branch information
makiuchi-d authored Apr 2, 2024
2 parents 2821274 + adc9567 commit 1cbc278
Show file tree
Hide file tree
Showing 11 changed files with 161 additions and 67 deletions.
20 changes: 20 additions & 0 deletions server/client/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,26 @@ func WatchDirect(ctx context.Context, grpccon *grpc.ClientConn, wshost, appid, r
return connectToRoom(ctx, accinfo, res, warn)
}

// WatchByNumber : 部屋番号で観戦入室
func WatchByNumber(ctx context.Context, accinfo *AccessInfo, number int32, query *Query, warn func(error)) (*Room, *Connection, error) {
var q []lobby.PropQueries
if query != nil {
q = []lobby.PropQueries(*query)
}
param := lobby.JoinParam{
Queries: q,
ClientInfo: &pb.ClientInfo{Id: accinfo.UserId},
EncMACKey: accinfo.EncMACKey,
}

res, err := lobbyRequest(ctx, accinfo, fmt.Sprintf("/rooms/watch/number/%d", number), param)
if err != nil {
return nil, nil, xerrors.Errorf("lobbyRequest: %w", err)
}

return connectToRoom(ctx, accinfo, res.Room, warn)
}

// Search : 部屋を検索する
func Search(ctx context.Context, accinfo *AccessInfo, param *lobby.SearchParam) ([]*pb.RoomInfo, error) {
res, err := lobbyRequest(ctx, accinfo, "/rooms/search", param)
Expand Down
30 changes: 30 additions & 0 deletions server/cmd/wsnet2-bot/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,22 @@ func joinRoom(ctx context.Context, player, roomId string, query *client.Query) (
return client.Join(ctx, accinfo, roomId, query, cinfo, nil)
}

// joinByNumber joins the player to a room specified by the number
func joinByNumber(ctx context.Context, player string, number int32, query *client.Query) (*client.Room, *client.Connection, error) {
accinfo, err := client.GenAccessInfo(lobbyURL, appId, appKey, player)
if err != nil {
return nil, nil, err
}

if query == nil {
query = client.NewQuery()
}

cinfo := &pb.ClientInfo{Id: player}

return client.JoinByNumber(ctx, accinfo, number, query, cinfo, nil)
}

// joinRandom joins the player to a room randomly
func joinRandom(ctx context.Context, player string, group uint32, query *client.Query) (*client.Room, *client.Connection, error) {
accinfo, err := client.GenAccessInfo(lobbyURL, appId, appKey, player)
Expand All @@ -200,6 +216,20 @@ func watchRoom(ctx context.Context, watcher, roomId string, query *client.Query)
return client.Watch(ctx, accinfo, roomId, query, nil)
}

// watchByNumber joins the watcher to a room specified by the number
func watchByNumber(ctx context.Context, watcher string, number int32, query *client.Query) (*client.Room, *client.Connection, error) {
accinfo, err := client.GenAccessInfo(lobbyURL, appId, appKey, watcher)
if err != nil {
return nil, nil, err
}

if query == nil {
query = client.NewQuery()
}

return client.WatchByNumber(ctx, accinfo, number, query, nil)
}

// searchCurrent search current rooms
func searchCurrent(ctx context.Context, cid string) ([]*pb.RoomInfo, error) {
accinfo, err := client.GenAccessInfo(lobbyURL, appId, appKey, cid)
Expand Down
18 changes: 13 additions & 5 deletions server/cmd/wsnet2-bot/cmd/scenario.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ func scenarioJoinRoom(ctx context.Context) error {
clearEventBuffer(conn)

// 正常入室
_, p2, err := joinRoom(ctx, "joinroom_player2", room.Id, nil)
_, p2, err := joinByNumber(ctx, "joinroom_player2", *room.Number, nil)
if err != nil {
return fmt.Errorf("join-room: player2: %w", err)
}
Expand Down Expand Up @@ -286,6 +286,14 @@ func scenarioJoinRoom(ctx context.Context) error {
discardEvents(w1)
defer cleanupConn(ctx, w1)

_, w2, err := watchByNumber(ctx, "joinroom_watcher2", *room.Number, nil)
if err != nil {
return fmt.Errorf("join-room: watcher2: %w", err)
}
logger.Infof("join-room: watcher2 ok")
discardEvents(w2)
defer cleanupConn(ctx, w2)

clearEventBuffer(conn)

// MaxPlayerを+2増やしwatchable=falseに
Expand All @@ -310,12 +318,12 @@ func scenarioJoinRoom(ctx context.Context) error {
defer cleanupConn(ctx, p4)

// 観戦はエラー
_, w2, err := watchRoom(ctx, "joinroom_watcher2", room.Id, nil)
_, w3, err := watchRoom(ctx, "joinroom_watcher3", room.Id, nil)
if !errors.Is(err, client.ErrNoRoomFound) {
cleanupConn(ctx, w2)
return fmt.Errorf("join-room: watcher2 wants NoRoomFound: %v", err)
cleanupConn(ctx, w3)
return fmt.Errorf("join-room: watcher3 wants NoRoomFound: %v", err)
}
logger.Infof("join-room: watcher2 ok (no room found)")
logger.Infof("join-room: watcher3 ok (no room found)")

clearEventBuffer(conn)

Expand Down
11 changes: 11 additions & 0 deletions server/game/repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"math/big"
"math/rand"
"reflect"
"regexp"
"strings"
"sync"
"time"
Expand All @@ -26,6 +27,8 @@ import (
const (
// RoomID文字列長
lenId = 16

idPattern = "^[0-9a-f]+$"
)

var (
Expand All @@ -34,12 +37,16 @@ var (
roomHistoryInsertQuery string

randsrc *rand.Rand

rerid *regexp.Regexp
)

func init() {
initQueries()
seed, _ := crand.Int(crand.Reader, big.NewInt(math.MaxInt64))
randsrc = rand.New(rand.NewSource(seed.Int64()))

rerid = regexp.MustCompile(idPattern)
}

func dbCols(t reflect.Type) []string {
Expand Down Expand Up @@ -82,6 +89,10 @@ func RandomHex(n int) string {
return hex.EncodeToString(b)
}

func IsValidRoomId(id string) bool {
return rerid.Match([]byte(id))
}

type Repository struct {
hostId uint32

Expand Down
14 changes: 14 additions & 0 deletions server/game/repository_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,20 @@ func TestQueries(t *testing.T) {
}
}

func TestIsValidRoomId(t *testing.T) {
tests := map[string]bool{
"123456789abcdef": true,
"123456789ABCDEF": false,
"": false,
}

for id, valid := range tests {
if IsValidRoomId(id) != valid {
t.Errorf("IsValidRoomId(%v) wants %v", id, valid)
}
}
}

func newDbMock(t *testing.T) (*sqlx.DB, sqlmock.Sqlmock) {
db, mock, err := sqlmock.New()
if err != nil {
Expand Down
12 changes: 8 additions & 4 deletions server/game/service/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"strings"
"time"

"github.com/go-chi/chi/v5"
"github.com/shiguredo/websocket"
"golang.org/x/xerrors"

Expand Down Expand Up @@ -68,8 +67,8 @@ func (sv *GameService) serveWebSocket(ctx context.Context) <-chan error {
}

ws := &WSHandler{sv}
r := chi.NewMux()
r.Get("/room/{id:[0-9a-f]+}", ws.HandleRoom)
r := http.NewServeMux()
r.HandleFunc("GET /room/{id}", ws.HandleRoom)

sv.wsURLFormat = fmt.Sprintf("%s://%s:%d/room/%%s",
scheme, sv.conf.PublicName, sv.conf.WebsocketPort)
Expand All @@ -87,7 +86,12 @@ func (sv *GameService) serveWebSocket(ctx context.Context) <-chan error {
}

func (s *WSHandler) HandleRoom(w http.ResponseWriter, r *http.Request) {
roomId := chi.URLParam(r, "id")
roomId := r.PathValue("id")
if !game.IsValidRoomId(roomId) {
http.Error(w, "Not Found", http.StatusNotFound)
return
}

appId := r.Header.Get("Wsnet2-App")
clientId := r.Header.Get("Wsnet2-User")
logger := log.GetLoggerWith(
Expand Down
1 change: 0 additions & 1 deletion server/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ go 1.22.0

require (
github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/go-chi/chi/v5 v5.0.12
github.com/go-sql-driver/mysql v1.7.1
github.com/google/go-cmp v0.6.0
github.com/jmoiron/sqlx v1.3.5
Expand Down
2 changes: 0 additions & 2 deletions server/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ
github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-chi/chi/v5 v5.0.12 h1:9euLV5sTrTNTRUU9POmDUvfxyj6LAABLUcEWO+JJb4s=
github.com/go-chi/chi/v5 v5.0.12/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8=
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI=
github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
Expand Down
12 changes: 8 additions & 4 deletions server/hub/service/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"strings"
"time"

"github.com/go-chi/chi/v5"
"github.com/shiguredo/websocket"
"golang.org/x/xerrors"

Expand Down Expand Up @@ -68,8 +67,8 @@ func (sv *HubService) serveWebSocket(ctx context.Context) <-chan error {
}

ws := &WSHandler{sv}
r := chi.NewMux()
r.Get("/room/{id:[0-9a-f]+}", ws.HandleRoom)
r := http.NewServeMux()
r.HandleFunc("GET /room/{id}", ws.HandleRoom)

sv.wsURLFormat = fmt.Sprintf("%s://%s:%d/room/%%s",
scheme, sv.conf.PublicName, sv.conf.WebsocketPort)
Expand All @@ -87,7 +86,12 @@ func (sv *HubService) serveWebSocket(ctx context.Context) <-chan error {
}

func (s *WSHandler) HandleRoom(w http.ResponseWriter, r *http.Request) {
roomId := chi.URLParam(r, "id")
roomId := r.PathValue("id")
if !game.IsValidRoomId(roomId) {
http.Error(w, "Not Found", http.StatusNotFound)
return
}

appId := r.Header.Get("Wsnet2-App")
clientId := r.Header.Get("Wsnet2-User")
logger := log.GetLoggerWith(
Expand Down
Loading

0 comments on commit 1cbc278

Please sign in to comment.