diff --git a/types.go b/types.go index efa3dd5..debe846 100644 --- a/types.go +++ b/types.go @@ -17,14 +17,13 @@ type OAuthResponse struct { } type User struct { - Id string `bson:"_id"` - Username string `bson:"username"` - Emails []string `bson:"emails"` - Fullname string `bson:"fullname"` - GivenName string `bson:"given_name"` - FamilyName string `bson:"family_name"` - IsRegistered bool `bson:"is_registered"` - VerificationKey string `bson:"verification_key"` + Id string `bson:"_id"` + Username string `bson:"username"` + Fullname string `bson:"fullname"` + GivenName string `bson:"given_name"` + FamilyName string `bson:"family_name"` + IsRegistered bool `bson:"is_registered"` + VerificationKey string `bson:"verification_key"` } type ServiceResponse struct { diff --git a/views.go b/views.go index 69ed3e6..9d5fad0 100644 --- a/views.go +++ b/views.go @@ -22,7 +22,12 @@ func LoginPOST(c echo.Context) error { username := strings.ToLower(strings.TrimSpace(c.FormValue("username"))) // fakeCAS does not check password - err := DatabaseConnection.QueryRow("SELECT is_registered FROM osf_osfuser WHERE username = $1 OR $1 = ANY(emails)", username).Scan(&isRegistered) + err := DatabaseConnection.QueryRow(` + SELECT is_registered + FROM osf_osfuser + WHERE username = $1 + OR EXISTS(SELECT * FROM osf_email WHERE osf_email.user_id = osf_osfuser.id AND osf_email.address = $1) + `, username).Scan(&isRegistered) if err != nil { if err != sql.ErrNoRows { @@ -77,7 +82,12 @@ func LoginGET(c echo.Context) error { var verification string uname := strings.ToLower(strings.TrimSpace(c.FormValue("username"))) - err = DatabaseConnection.QueryRow("SELECT verification_key FROM osf_osfuser WHERE username = $1 OR $1 = ANY(emails)", uname).Scan(&verification) + err = DatabaseConnection.QueryRow(` + SELECT verification_key + FROM osf_osfuser + WHERE username = $1 + OR EXISTS(SELECT * FROM osf_email WHERE osf_email.user_id = osf_osfuser.id AND osf_email.address = $1) + `, username).Scan(&verification) if err != nil { if err != sql.ErrNoRows { @@ -118,12 +128,14 @@ func ServiceValidate(c echo.Context) error { osf_osfuser.username, osf_osfuser.given_name, osf_osfuser.family_name - From osf_guid - LEFT JOIN django_content_type - ON django_content_type.model = 'osfuser' - JOIN osf_osfuser - ON django_content_type.id = osf_guid.content_type_id AND object_id = osf_osfuser.id - WHERE username = $1 OR $1 = ANY(emails) + FROM osf_osfuser + RIGHT JOIN osf_guid + ON osf_guid.object_id = osf_osfuser.id + WHERE osf_guid.content_type_id = (SELECT id FROM django_content_type WHERE model = 'osfuser' LIMIT 1) + AND ( + EXISTS (SELECT * FROM osf_email WHERE osf_email.user_id = osf_osfuser.id AND osf_email.address = $1) + OR (osf_osfuser.username = $1) + ); ` err := DatabaseConnection.QueryRow(queryString, ticket).Scan(&result.Id, &result.Username, &result.GivenName, &result.FamilyName) if err != nil { @@ -133,7 +145,7 @@ func ServiceValidate(c echo.Context) error { fmt.Println("User", ticket, "not found.") return c.NoContent(http.StatusNotFound) } - fmt.Println("User found: username =", result.Username, ", guid =", result.Id,) + fmt.Println("User found: username =", result.Username, ", guid =", result.Id) response := ServiceResponse{ Xmlns: "http://www.yale.edu/tp/cas", @@ -180,7 +192,7 @@ func OAuth(c echo.Context) error { fmt.Println("Access token", tokenId, "not found") return c.NoContent(http.StatusNotFound) } - fmt.Println("User found for token: username =", result.Username, ", guid =", result.Id,) + fmt.Println("User found for token: username =", result.Username, ", guid =", result.Id) return c.JSON(200, OAuthResponse{ Id: result.Id,