Skip to content

Commit

Permalink
Merge pull request #2597 from NationalSecurityAgency/2.12.X-merge3
Browse files Browse the repository at this point in the history
2.12.x merge3
  • Loading branch information
sudo-may authored Jun 20, 2024
2 parents d948fdd + 8df1c12 commit 0c5079f
Show file tree
Hide file tree
Showing 9 changed files with 194 additions and 36 deletions.
50 changes: 31 additions & 19 deletions e2e-tests/cypress/support/commands.js
Original file line number Diff line number Diff line change
Expand Up @@ -205,24 +205,26 @@ Cypress.Commands.add("saveVideoAttrs", (projNum, skillNum, videoAttrs) => {
formData.set('isAlreadyHosted', videoAttrs.isAlreadyHosted);
}
let requestDone = false;
if (videoAttrs.file) {
const fileType = videoAttrs.file.endsWith('mp4') ? 'video/mp4' : 'video/webm';
cy.readFile(`cypress/fixtures/${videoAttrs.file}`, 'binary')
.then((binaryFile) => {
const blob = Cypress.Blob.binaryStringToBlob(binaryFile, fileType);
formData.set('file', blob, videoAttrs.file);
cy.request('POST', url, formData).then(() => {
requestDone = true;
});
cy.getCookie('XSRF-TOKEN').should('exist').then((xsrfCookie) => {
if (videoAttrs.file) {
const fileType = videoAttrs.file.endsWith('mp4') ? 'video/mp4' : 'video/webm';
cy.readFile(`cypress/fixtures/${videoAttrs.file}`, 'binary')
.then((binaryFile) => {
const blob = Cypress.Blob.binaryStringToBlob(binaryFile, fileType);
formData.set('file', blob, videoAttrs.file);
cy.request({ method: 'POST', url, body: formData, headers: {'X-XSRF-TOKEN': xsrfCookie.value} }).then(() => {
requestDone = true;
});
});
} else {
cy.request({ method: 'POST', url, body: formData, headers: {'X-XSRF-TOKEN': xsrfCookie.value} }).then(() => {
requestDone = true;
});
} else {
cy.request('POST', url, formData).then(() => {
requestDone = true;
});
}
}

cy.waitUntil(() => requestDone, {
timeout: 30000, // waits up to 30 seconds, default is 5 seconds
cy.waitUntil(() => requestDone, {
timeout: 30000, // waits up to 30 seconds, default is 5 seconds
});
});
});

Expand Down Expand Up @@ -296,6 +298,9 @@ Cypress.Commands.add("login", (user, pass) => {
timeout: 30000, // waits up to 30 seconds, default is 5 seconds
});
cy.log(`Logged in as [${user}] with [${pass}]`);
cy.request('/app/userInfo').then((response) => {
cy.wrap(response.body).as('userInfo');
})
});

Cypress.Commands.add("resetEmail", () => {
Expand Down Expand Up @@ -1289,12 +1294,19 @@ Cypress.Commands.add("validateElementsOrder", (selector, containsValues) => {
}
});

Cypress.Commands.add('formRequest', (method, url, formData, onComplete) => {
Cypress.Commands.add('formRequest', (method, url, formData, onComplete, includeXSRF = false) => {
const xhr = new XMLHttpRequest();
xhr.open(method, url)
xhr.onload = function () { onComplete(xhr) }
xhr.onerror = function () { onComplete(xhr) }
xhr.send(formData)
if (includeXSRF) {
cy.getCookie('XSRF-TOKEN').should('exist').then((xsrfCookie) => {
xhr.setRequestHeader('X-XSRF-TOKEN', xsrfCookie.value)
xhr.send(formData)
});
} else {
xhr.send(formData)
}
})

Cypress.Commands.add("uploadCustomIcon", (fileName, url) => {
Expand All @@ -1309,7 +1321,7 @@ Cypress.Commands.add("uploadCustomIcon", (fileName, url) => {
expect(response.status)
.to
.eq(200);
});
}, true);
});
});

Expand Down
94 changes: 92 additions & 2 deletions service/src/main/java/skills/auth/PortalWebSecurityHelper.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
*/
package skills.auth

import jakarta.servlet.FilterChain
import jakarta.servlet.ServletException
import jakarta.servlet.http.HttpServletRequest
import jakarta.servlet.http.HttpServletResponse
import org.springframework.beans.factory.annotation.Autowired
import org.springframework.beans.factory.annotation.Value
import org.springframework.context.annotation.DependsOn
Expand All @@ -24,11 +28,24 @@ import org.springframework.security.authorization.AuthorizationManager
import org.springframework.security.authorization.AuthorizationManagers
import org.springframework.security.config.annotation.web.builders.HttpSecurity
import org.springframework.security.web.access.intercept.RequestAuthorizationContext
import org.springframework.security.web.authentication.www.BasicAuthenticationFilter
import org.springframework.security.web.csrf.CookieCsrfTokenRepository
import org.springframework.security.web.csrf.CsrfToken
import org.springframework.security.web.csrf.CsrfTokenRequestAttributeHandler
import org.springframework.security.web.csrf.CsrfTokenRequestHandler
import org.springframework.security.web.csrf.XorCsrfTokenRequestAttributeHandler
import org.springframework.security.web.util.matcher.AntPathRequestMatcher
import org.springframework.security.web.util.matcher.OrRequestMatcher
import org.springframework.security.web.util.matcher.RequestMatcher
import org.springframework.stereotype.Component
import org.springframework.util.StringUtils
import org.springframework.web.filter.OncePerRequestFilter
import skills.auth.inviteOnly.InviteOnlyProjectAuthorizationManager
import skills.auth.userCommunity.UserCommunityAuthorizationManager
import skills.storage.model.auth.RoleName

import java.util.function.Supplier

@Component
@DependsOn(['inviteOnlyProjectAuthorizationManager', 'userCommunityAuthorizationManager'])
class PortalWebSecurityHelper {
Expand All @@ -45,15 +62,25 @@ class PortalWebSecurityHelper {
@Value('#{"${management.endpoints.web.path-mapping.prometheus:prometheus}"}')
String prometheusPath

@Value('#{"${skills.config.disableCsrfProtection:false}"}')
Boolean disableCsrfProtection

@Autowired
InviteOnlyProjectAuthorizationManager inviteOnlyProjectAuthorizationManager

@Autowired
UserCommunityAuthorizationManager userCommunityAuthorizationManager

HttpSecurity configureHttpSecurity(HttpSecurity http) {

http.csrf().disable()
if (disableCsrfProtection) {
http.csrf().disable()
} else {
http.csrf((csrf) -> csrf
.requireCsrfProtectionMatcher(new MultipartRequestMatcher())
.csrfTokenRepository(CookieCsrfTokenRepository.withHttpOnlyFalse())
.csrfTokenRequestHandler(new SpaCsrfTokenRequestHandler()))
.addFilterAfter(new CsrfCookieFilter(), BasicAuthenticationFilter.class)
}

if (publiclyExposePrometheusMetrics) {
http.authorizeHttpRequests().requestMatchers(HttpMethod.GET, "${managementPath}/${prometheusPath}").permitAll()
Expand Down Expand Up @@ -93,3 +120,66 @@ class PortalWebSecurityHelper {
return AuthorizationManagers.allOf(([AuthorityAuthorizationManager.hasAnyAuthority(authorities)] + managers) as AuthorizationManager[])
}
}

final class SpaCsrfTokenRequestHandler extends CsrfTokenRequestAttributeHandler {
private final CsrfTokenRequestHandler delegate = new XorCsrfTokenRequestAttributeHandler()

@Override
void handle(HttpServletRequest request, HttpServletResponse response, Supplier<CsrfToken> csrfToken) {
/*
* Always use XorCsrfTokenRequestAttributeHandler to provide BREACH protection of
* the CsrfToken when it is rendered in the response body.
*/
this.delegate.handle(request, response, csrfToken)
}

@Override
String resolveCsrfTokenValue(HttpServletRequest request, CsrfToken csrfToken) {
/*
* If the request contains a request header, use CsrfTokenRequestAttributeHandler
* to resolve the CsrfToken. This applies when a single-page application includes
* the header value automatically, which was obtained via a cookie containing the
* raw CsrfToken.
*/
if (StringUtils.hasText(request.getHeader(csrfToken.getHeaderName()))) {
return super.resolveCsrfTokenValue(request, csrfToken)
}
/*
* In all other cases (e.g. if the request contains a request parameter), use
* XorCsrfTokenRequestAttributeHandler to resolve the CsrfToken. This applies
* when a server-side rendered form includes the _csrf request parameter as a
* hidden input.
*/
return this.delegate.resolveCsrfTokenValue(request, csrfToken)
}
}

final class CsrfCookieFilter extends OncePerRequestFilter {

@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException {
CsrfToken csrfToken = (CsrfToken) request.getAttribute("_csrf")
// Render the token value to a cookie by causing the deferred token to be loaded
csrfToken.getToken()

filterChain.doFilter(request, response)
}
}

final class MultipartRequestMatcher implements RequestMatcher {

private final HashSet<String> allowedMethods = new HashSet<>(Arrays.asList("GET", "HEAD", "TRACE", "OPTIONS"))
private final OrRequestMatcher pathMatcher = new OrRequestMatcher(
new AntPathRequestMatcher("/api/upload"),
new AntPathRequestMatcher("/admin/projects/*/icons/upload"),
new AntPathRequestMatcher("/supervisor/icons/upload"),
new AntPathRequestMatcher("/admin/projects/*/skills/*/video"),
)

@Override
boolean matches(HttpServletRequest request) {
Boolean matches = (pathMatcher.matches(request) && !this.allowedMethods.contains(request.getMethod()))
return matches
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class CustomValidator {
private static final Pattern HTML = ~/(?s)<[\/]?\w+(?: .+?)*>/
private static final Pattern CODEBLOCK = ~/(?ms)(^[`]{3}$.*?^[`]{3}$)/

private static final Pattern TABLE_FIX = ~/(?m)(^\n)(^[|].+[|]$\n^[|].*[-]{3,}.*[|]$)/
private static final Pattern TABLE_FIX = ~/(?ms)(^\n)(^[|].+[|]$\n^[|].*[-]{3,}.*[|]$)/
private static final Pattern CODEBLOCK_FIX = ~/(?m)(^\n)(^[`]{3}$)/
private static final Pattern LIST_FIX = ~/(?m)(^\n)(\s*\d\. |\* |- .*$)/

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@ import org.apache.commons.lang3.time.DurationFormatUtils
import org.springframework.http.HttpEntity
import org.springframework.http.HttpHeaders
import org.springframework.http.MediaType
import org.springframework.http.ResponseEntity
import org.springframework.util.LinkedMultiValueMap
import org.springframework.util.MultiValueMap
import org.springframework.web.client.RestTemplate
import org.springframework.web.client.HttpClientErrorException
import skills.intTests.utils.DefaultIntSpec
import skills.intTests.utils.EmailUtils
import skills.intTests.utils.RestTemplateWrapper
import skills.intTests.utils.SkillsService
import skills.utils.WaitFor
import spock.lang.IgnoreIf
Expand All @@ -39,6 +41,7 @@ class PasswordResetSpec extends DefaultIntSpec {

GreenMail greenMail = new GreenMail(ServerSetupTest.SMTP)
SkillsService rootSkillsService
RestTemplate template = new RestTemplate()

def setup() {
greenMail.start()
Expand All @@ -57,6 +60,8 @@ class PasswordResetSpec extends DefaultIntSpec {
"publicUrl" : "http://localhost:${localPort}/".toString(),
"fromEmail" : "resetspec@skilltreetests"
])
template.interceptors.add(new RestTemplateWrapper.StatefulRestTemplateInterceptor())
template.getForEntity("http://localhost:${localPort}/app/users/validExistingDashboardUserId/[email protected]", String.class)
}

def cleanup(){
Expand All @@ -69,7 +74,6 @@ class PasswordResetSpec extends DefaultIntSpec {

when:
//post request with an unauthenticated client to ensure that the url is publicly available
RestTemplate template = new RestTemplate()
HttpHeaders headers = new HttpHeaders()
headers.setContentType(MediaType.MULTIPART_FORM_DATA)
MultiValueMap body = new LinkedMultiValueMap<>()
Expand All @@ -95,7 +99,6 @@ class PasswordResetSpec extends DefaultIntSpec {
def "reset password with token from email"() {
SkillsService aUser = createService("[email protected]", "somepassword")
//post request with an unauthenticated client to ensure that the url is publicly available
RestTemplate template = new RestTemplate()
HttpHeaders headers = new HttpHeaders()
headers.setContentType(MediaType.MULTIPART_FORM_DATA)
MultiValueMap body = new LinkedMultiValueMap<>()
Expand Down Expand Up @@ -131,7 +134,6 @@ class PasswordResetSpec extends DefaultIntSpec {
def "reset password with invalid token fails"() {
SkillsService aUser = createService("[email protected]", "somepassword")
//post request with an unauthenticated client to ensure that the url is publicly available
RestTemplate template = new RestTemplate()
HttpHeaders headers = new HttpHeaders()
headers.setContentType(MediaType.MULTIPART_FORM_DATA)
MultiValueMap body = new LinkedMultiValueMap<>()
Expand Down Expand Up @@ -164,7 +166,6 @@ class PasswordResetSpec extends DefaultIntSpec {

SkillsService aUser = createService("[email protected]", "somepassword")
//post request with an unauthenticated client to ensure that the url is publicly available
RestTemplate template = new RestTemplate()
HttpHeaders headers = new HttpHeaders()
headers.setContentType(MediaType.MULTIPART_FORM_DATA)
MultiValueMap body = new LinkedMultiValueMap<>()
Expand Down
6 changes: 5 additions & 1 deletion service/src/test/java/skills/intTests/WebsocketSpecs.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import org.springframework.web.socket.sockjs.client.SockJsClient
import org.springframework.web.socket.sockjs.client.Transport
import org.springframework.web.socket.sockjs.client.WebSocketTransport
import skills.intTests.utils.DefaultIntSpec
import skills.intTests.utils.RestTemplateWrapper
import skills.intTests.utils.SkillsFactory
import skills.intTests.utils.TestUtils
import skills.services.events.CompletionItem
Expand Down Expand Up @@ -303,6 +304,7 @@ class WebsocketSpecs extends DefaultIntSpec {
} else {
xhrTransport = new RestTemplateXhrTransport()
}
((RestTemplate) xhrTransport.getRestTemplate()).interceptors.add(new RestTemplateWrapper.StatefulRestTemplateInterceptor())
xhrTransport.xhrStreamingDisabled = xhrPolling
transports.add(xhrTransport)
} else {
Expand Down Expand Up @@ -471,6 +473,8 @@ class WebsocketSpecs extends DefaultIntSpec {
.setConnectionManager(connectionManager)
.build()
HttpComponentsClientHttpRequestFactory requestFactory = new HttpComponentsClientHttpRequestFactory(client)
return new RestTemplate(requestFactory)
RestTemplate restTemplate = new RestTemplate(requestFactory)
restTemplate.getInterceptors().add(new RestTemplateWrapper.StatefulRestTemplateInterceptor())
return restTemplate
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,18 +69,36 @@ class RestTemplateWrapper extends RestTemplate {
* Need for load balancer support as it uses cookies to keep track which server currently connected to
*/
static class StatefulRestTemplateInterceptor implements ClientHttpRequestInterceptor {
private String cookie;
private List<String> cookies;
private String xsrfToken;

@Override
public ClientHttpResponse intercept(HttpRequest request, byte[] body, ClientHttpRequestExecution execution) throws IOException {
if (cookie != null) {
request.getHeaders().add(HttpHeaders.COOKIE, cookie);

HttpHeaders requstHeaders = request.getHeaders()
if (cookies) {
requstHeaders.addAll(HttpHeaders.COOKIE, cookies);
}
if (xsrfToken != null) {
requstHeaders.add("X-XSRF-TOKEN" , xsrfToken);
}
log.debug("REQUEST: [{}], headers [{}]", request.URI, request.headers)
ClientHttpResponse response = execution.execute(request, body);

if (cookie == null) {
cookie = response.getHeaders().getFirst(HttpHeaders.SET_COOKIE);
log.debug("Setting cookie to [{}]", cookie)
HttpHeaders headers = response.getHeaders();

List<String> returnedCookies = headers.getOrEmpty(HttpHeaders.SET_COOKIE)
if (returnedCookies && cookies == null) {
cookies = returnedCookies
log.info("Setting cookies to {}", returnedCookies)
printf "Setting cookies to ${returnedCookies}"
}
if (returnedCookies && !xsrfToken) {
String cookieXSRF = returnedCookies.find { it.startsWith("XSRF-TOKEN=") }
if (cookieXSRF) {
xsrfToken = (cookieXSRF =~ /XSRF-TOKEN=([^;]*)/)[0][1]
log.debug("Response: [{}], set xsrfToken to [{}]", request.URI, xsrfToken)
}
}
return response;
}
Expand Down Expand Up @@ -177,7 +195,7 @@ class RestTemplateWrapper extends RestTemplate {

ResponseEntity response = putForEntity(skillsServiceUrl + '/createAccount', userInfo)
if ( response.statusCode != HttpStatus.OK) {
throw new SkillsClientException(response.body, skillsServiceUrl, response.statusCode)
throw new SkillsClientException((String)response.body, skillsServiceUrl, (HttpStatus)response.statusCode)
}
accountCreated = true
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class SkillsClientException extends RuntimeException {
def res

SkillsClientException(String message, String url, HttpStatus httpStatus) {
super(message)
super((String)message)
this.url = url
this.httpStatus = httpStatus
}
Expand Down
5 changes: 4 additions & 1 deletion service/src/test/java/skills/intTests/utils/WSHelper.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import org.springframework.core.io.ClassPathResource
import org.springframework.core.io.FileSystemResource
import org.springframework.core.io.Resource
import org.springframework.http.*
import org.springframework.http.client.ClientHttpRequestInterceptor
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory
import org.springframework.http.client.support.BasicAuthenticationInterceptor
import org.springframework.util.LinkedMultiValueMap
Expand Down Expand Up @@ -101,7 +102,9 @@ class WSHelper {
}

void setProxyCredentials(String clientId, String secretCode) {
oAuthRestTemplate.setInterceptors([new BasicAuthenticationInterceptor(clientId, secretCode)])
List<ClientHttpRequestInterceptor> existingInterceptors = oAuthRestTemplate.getInterceptors()
existingInterceptors.add(new BasicAuthenticationInterceptor(clientId, secretCode))
oAuthRestTemplate.setInterceptors(existingInterceptors)
}

def appPut(String endpoint, def params) {
Expand Down
Loading

0 comments on commit 0c5079f

Please sign in to comment.