Skip to content

Commit

Permalink
Refactor getServingInfo tests
Browse files Browse the repository at this point in the history
  • Loading branch information
caiocamatta-stripe committed Feb 28, 2024
1 parent acc83c2 commit 2048bb2
Showing 1 changed file with 22 additions and 19 deletions.
41 changes: 22 additions & 19 deletions online/src/test/scala/ai/chronon/online/FetcherBaseTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import org.mockito.stubbing.Answer
import org.mockito.{Answers, ArgumentCaptor}
import org.scalatest.matchers.should.Matchers
import org.scalatestplus.mockito.MockitoSugar
import org.junit.Assert.assertSame

import scala.concurrent.duration.DurationInt
import scala.concurrent.{Await, ExecutionContext, Future}
Expand Down Expand Up @@ -149,41 +148,45 @@ class FetcherBaseTest extends MockitoSugar with Matchers with MockitoHelper {
actualRequest.get.keys shouldBe query.keyMapping.get
}

// updateServingInfo() is called when the batch response is from the KV store.
@Test
def test_getServingInfo_ShouldCallUpdateServingInfoIfBatchResponseIsFromKvStore(): Unit = {
val baseFetcher = new FetcherBase(mock[KVStore])
val spiedFetcherBase = spy(baseFetcher)
val oldServingInfo = mock[GroupByServingInfoParsed]
val updatedServingInfo = mock[GroupByServingInfoParsed]
doReturn(updatedServingInfo).when(fetcherBase).updateServingInfo(any(), any())

val batchTimedValuesSuccess = Success(Seq(TimedValue(Array(1.toByte), 2000L)))
val kvStoreBatchResponses = BatchResponses(batchTimedValuesSuccess)
doReturn(updatedServingInfo).when(spiedFetcherBase).updateServingInfo(any(), any())

val result = fetcherBase.getServingInfo(oldServingInfo, kvStoreBatchResponses)

// updateServingInfo is called
val result = spiedFetcherBase.getServingInfo(oldServingInfo, kvStoreBatchResponses)
assertSame(result, updatedServingInfo)
verify(spiedFetcherBase).updateServingInfo(any(), any())
result shouldEqual updatedServingInfo
verify(fetcherBase).updateServingInfo(any(), any())
}

// If a batch response is cached, the serving info should be refreshed. This is needed to prevent
// the serving info from becoming stale if all the requests are cached.
@Test
def test_getServingInfo_ShouldRefreshServingInfoIfBatchResponseIsCached(): Unit = {
val baseFetcher = new FetcherBase(mock[KVStore])
val spiedFetcherBase = spy(baseFetcher)
val ttlCache = mock[TTLCache[String, Try[GroupByServingInfoParsed]]]
doReturn(ttlCache).when(fetcherBase).getGroupByServingInfo

val oldServingInfo = mock[GroupByServingInfoParsed]
val metaData = mock[MetaData]
doReturn(Success(oldServingInfo)).when(ttlCache).refresh(any[String])

val metaDataMock = mock[MetaData]
val groupByOpsMock = mock[GroupByOps]
doReturn("test").when(metaDataMock).name
doReturn(metaDataMock).when(groupByOpsMock).metaData
doReturn(groupByOpsMock).when(oldServingInfo).groupByOps

val cachedBatchResponses = BatchResponses(mock[FinalBatchIr])
val ttlCache = mock[TTLCache[String, Try[GroupByServingInfoParsed]]]
doReturn(ttlCache).when(spiedFetcherBase).getGroupByServingInfo
doReturn(Success(oldServingInfo)).when(ttlCache).refresh(any[String])
metaData.name = "test"
groupByOpsMock.metaData = metaData
when(oldServingInfo.groupByOps).thenReturn(groupByOpsMock)
val result = fetcherBase.getServingInfo(oldServingInfo, cachedBatchResponses)

// FetcherBase.updateServingInfo is not called, but getGroupByServingInfo.refresh() is.
val result = spiedFetcherBase.getServingInfo(oldServingInfo, cachedBatchResponses)
assertSame(result, oldServingInfo)
result shouldEqual oldServingInfo
verify(ttlCache).refresh(any())
verify(spiedFetcherBase, never()).updateServingInfo(any(), any())
verify(fetcherBase, never()).updateServingInfo(any(), any())
}
}

0 comments on commit 2048bb2

Please sign in to comment.