diff --git a/fakecas.go b/fakecas.go index 449c7f0..4500de9 100644 --- a/fakecas.go +++ b/fakecas.go @@ -4,11 +4,12 @@ import ( "database/sql" "flag" "fmt" - "github.com/labstack/echo" - "github.com/labstack/echo/middleware" - _ "github.com/lib/pq" "html/template" "os" + + "github.com/labstack/echo" + "github.com/labstack/echo/middleware" + _ "github.com/lib/pq" ) var Version string @@ -31,12 +32,12 @@ func main() { })) e.Use(middleware.Recover()) - e.Use(middleware.CORSWithConfig(middleware.CORSConfig{ + e.Use(middleware.CORSWithConfig(middleware.CORSConfig{ AllowCredentials: true, - AllowOrigins: []string{"*"}, - AllowMethods: []string{"GET", "PUT", "POST", "DELETE"}, - AllowHeaders: []string{"Range", "Content-Type", "Authorization", "X-Requested-With"}, - ExposeHeaders: []string{"Range", "Content-Type", "Authorization", "X-Requested-With"}, + AllowOrigins: []string{"*"}, + AllowMethods: []string{"GET", "PUT", "POST", "DELETE"}, + AllowHeaders: []string{"Range", "Content-Type", "Authorization", "X-Requested-With"}, + ExposeHeaders: []string{"Range", "Content-Type", "Authorization", "X-Requested-With"}, })) t, err := template.New("login").Parse(LOGINPAGE) @@ -44,7 +45,7 @@ func main() { panic(err) } temp := &Template{templates: t} - e.Renderer = temp + e.Renderer = temp e.GET("/login", LoginGET) e.POST("/login", LoginPOST) @@ -62,5 +63,5 @@ func main() { defer DatabaseConnection.Close() - e.Start(*Host) + e.Start(*Host) } diff --git a/utils.go b/utils.go index faa915d..8ff5bec 100644 --- a/utils.go +++ b/utils.go @@ -1,9 +1,10 @@ package main import ( - "github.com/labstack/echo" "io" "net/url" + + "github.com/labstack/echo" ) func ValidateService(c echo.Context) *url.URL { diff --git a/views.go b/views.go index 2acb768..9f9241a 100644 --- a/views.go +++ b/views.go @@ -3,10 +3,11 @@ package main import ( "database/sql" "fmt" - "github.com/labstack/echo" "net/http" "net/url" "strings" + + "github.com/labstack/echo" ) func LoginPOST(c echo.Context) error { @@ -161,45 +162,79 @@ func ServiceValidate(c echo.Context) error { } func OAuth(c echo.Context) error { - var ( - scopes string - result User - ) tokenId := strings.Replace(c.Request().Header.Get("Authorization"), "Bearer ", "", 1) - var queryString = ` + // Find the user that owns the token + var result User + queryString := ` SELECT DISTINCT osf_guid._id, osf_osfuser.username, osf_osfuser.given_name, - osf_osfuser.family_name, - osf_apioauth2personaltoken.scopes + osf_osfuser.family_name FROM osf_guid LEFT JOIN django_content_type ON django_content_type.model = 'osfuser' JOIN osf_osfuser ON django_content_type.id = osf_guid.content_type_id AND object_id = osf_osfuser.id JOIN osf_apioauth2personaltoken - ON osf_osfuser.id = osf_apioauth2personaltoken.owner_id - WHERE osf_apioauth2personaltoken.token_id = $1 + ON osf_osfuser.id = osf_apioauth2personaltoken.owner_id + WHERE osf_apioauth2personaltoken.token_id = $1 ` - err := DatabaseConnection.QueryRow(queryString, tokenId).Scan(&result.Id, &result.Username, &result.GivenName, &result.FamilyName, &scopes) + err := DatabaseConnection.QueryRow(queryString, tokenId).Scan(&result.Id, &result.Username, &result.GivenName, &result.FamilyName) if err != nil { if err != sql.ErrNoRows { panic(err) } - fmt.Println("Access token", tokenId, "not found") + fmt.Printf("Access token %s not found\n", tokenId) return c.NoContent(http.StatusNotFound) } - fmt.Println("User found for token: username =", result.Username, ", guid =", result.Id) + fmt.Printf("User found for token: username = %s , guid =%s\n", result.Username, result.Id) + + // Find all the scopes associated with the token + fmt.Printf("Reading scopes ... ") + queryString = ` + SELECT DISTINCT osf_apioauth2scope.name + FROM osf_apioauth2personaltoken_scopes + JOIN osf_apioauth2personaltoken + on osf_apioauth2personaltoken_scopes.apioauth2personaltoken_id = osf_apioauth2personaltoken.id + JOIN osf_apioauth2scope + on osf_apioauth2personaltoken_scopes.apioauth2scope_id = osf_apioauth2scope.id + WHERE osf_apioauth2personaltoken.token_id = $1 + ` + rows, err := DatabaseConnection.Query(queryString, tokenId) + if err != nil { + if err != sql.ErrNoRows { + panic(err) + } + fmt.Printf("No scope is found for access token %s\n", tokenId) + return c.NoContent(http.StatusNotFound) + } + defer rows.Close() + scopes := make([]string, 0) + var scope string + for rows.Next() { + err = rows.Scan(&scope) + if err != nil { + panic(err) + } + fmt.Printf("%s, ", scope) + scopes = append(scopes, scope) + } + err = rows.Err() + if err != nil { + panic(err) + } + fmt.Printf("... %d scopes in total.\n", len(scopes)) + // Return 200 with user information and scopes return c.JSON(200, OAuthResponse{ Id: result.Id, Attributes: OAuthAttributes{ LastName: result.FamilyName, FirstName: result.GivenName, }, - Scope: strings.Split(scopes, " "), + Scope: scopes, }) }