Skip to content

Commit

Permalink
[psychlab] Additional level: Harlow task
Browse files Browse the repository at this point in the history
  • Loading branch information
DeepMind Lab Team authored and tkoeppe committed Feb 4, 2019
1 parent 6ab68c3 commit 42a1d8a
Show file tree
Hide file tree
Showing 3 changed files with 351 additions and 0 deletions.
6 changes: 6 additions & 0 deletions RELEASE_NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@

## Next Release

### New Levels:

1. Psychlab.

1. contributed/psychlab/harlow

### Minor Improvements:

1. Improve documentation of how to configure non-hermetic dependencies (Lua,
Expand Down
322 changes: 322 additions & 0 deletions game_scripts/levels/contributed/psychlab/factories/harlow_factory.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,322 @@
--[[ Copyright (C) 2019 Google Inc.
This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
]]

local brady_konkle_oliva2008 = require 'datasets.brady_konkle_oliva2008'
local game = require 'dmlab.system.game'
local log = require 'common.log'
local helpers = require 'common.helpers'
local random = require 'common.random'
local psychlab_factory = require 'factories.psychlab.factory'
local psychlab_helpers = require 'factories.psychlab.helpers'
local point_and_click = require 'factories.psychlab.point_and_click'
local random = require 'common.random'
local tensor = require 'dmlab.system.tensor'

--[[ The Harlow task from:
Harlow, H. F. (1949). The formation of learning sets. Psychological review,
56(1), 51.
As implemented in:
Wang, J. X., Kurth-Nelson, Z., Tirumala, D., Soyer, H., Leibo, J. Z., Munos, R.,
... & Botvinick, M. (2016). Learning to reinforcement learn. arXiv preprint
arXiv:1611.05763.
and
Wang, J. X., Kurth-Nelson, Z., Kumaran, D., Tirumala, D., Soyer, H., Leibo,
J. Z., ... & Botvinick, M. (2018). Prefrontal cortex as a meta-reinforcement
learning system. Nature neuroscience, 21(6), 860.
Two random images (selected at the beginning of the episode)
are shown per episode for N (default 6) trials.
One image yields positive rewards if agent saccades to it, other
yields negative rewards. Between trials, agent saccades to a red
fixation cross to start the trial.
Images can be placed on either the left or right side of the
screen, determined randomly at the beginning of the trial.
Any images can be used, but should be placed in /data and numbered
according to the template '%04d.png'. There should be at least
TRAIN_BATCH + TEST_BATCH images. Set IS_TRAIN = false for
evaluation images (held out from training).
]]

local TIME_TO_FIXATE_CROSS = 4 -- in frames
local TIME_TO_FIXATE_TARGET = 4 -- in frames
local FAST_INTER_TRIAL_INTERVAL = 5 -- in frames
local SCREEN_SIZE = {width = 512, height = 512}
local BG_COLOR = {0, 0, 0}
local TRIALS_PER_EPISODE_CAP = 6

local TARGET_SIZE = 0.4

local FIXATION_REWARD = 1
local CORRECT_REWARD = 5
local INCORRECT_REWARD = -5

local FIXATION_SIZE = 0.1
local FIXATION_COLOR = {255, 0, 0} -- Red
local CENTER = {0.5, 0.5}
local LEFT = {.25, .5}
local RIGHT = {.75, .5}

local TRAIN_BATCH = 1000
local TEST_BATCH = 1000
local IS_TRAIN = true

local factory = {}

function factory.createLevelApi(kwargs)
kwargs.timeToFixateCross = kwargs.timeToFixateCross or TIME_TO_FIXATE_CROSS
kwargs.timeToFixateTarget = kwargs.timeToFixateTarget or TIME_TO_FIXATE_TARGET
kwargs.fastInterTrialInterval = kwargs.fastInterTrialInterval or
FAST_INTER_TRIAL_INTERVAL
kwargs.screenSize = kwargs.screenSize or SCREEN_SIZE
kwargs.bgColor = kwargs.bgColor or BG_COLOR
kwargs.trialsPerEpisodeCap = kwargs.trialsPerEpisodeCap or
TRIALS_PER_EPISODE_CAP
kwargs.targetSize = kwargs.targetSize or TARGET_SIZE
kwargs.fixationReward = kwargs.fixationReward or FIXATION_REWARD
kwargs.correctReward = kwargs.correctReward or CORRECT_REWARD
kwargs.incorrectReward = kwargs.incorrectReward or INCORRECT_REWARD
kwargs.fixationSize = kwargs.fixationSize or FIXATION_SIZE
kwargs.fixationColor = kwargs.fixationColor or FIXATION_COLOR
kwargs.center = kwargs.center or CENTER
kwargs.left = kwargs.left or LEFT
kwargs.right = kwargs.right or RIGHT
kwargs.trainBatch = kwargs.trainBatch or TRAIN_BATCH
kwargs.testBatch = kwargs.testBatch or TEST_BATCH
kwargs.isTrain = kwargs.isTrain or IS_TRAIN

local ARG = {
screenSize = kwargs.screenSize,
}

-- Class definition for harlow psychlab environment.
local env = {}
env.__index = env

setmetatable(env, {
__call = function (cls, ...)
local self = setmetatable({}, cls)
self:_init(...)
return self
end
})

-- 'init' gets called at the start of each episode.
function env:_init(pac, opts)
log.info('opts passed to _init:\n' .. helpers.tostring(opts))
log.info('ARG passed to _init:\n' .. helpers.tostring(ARG))

if self.dataset == nil then
self.dataset = brady_konkle_oliva2008(opts)
end

self.screenSize = kwargs.screenSize
self.numTrials = kwargs.trialsPerEpisodeCap

self.trainBatch = kwargs.trainBatch
self.testBatch = kwargs.testBatch
self.isTrain = kwargs.isTrain

self:setupImages()
-- Store a copy of the 'point_and_click' api.
self.pac = pac
end

--[[ Reset is called after init. It is called only once per episode.
Note: the episodeId passed to this function may not be correct if the job
has resumed from a checkpoint after preemption.
]]
function env:reset(episodeId, seed)
random:seed(seed)

self.pac:setBackgroundColor(kwargs.bgColor)
self.pac:clearWidgets()
psychlab_helpers.addFixation(self, kwargs.fixationSize)

self.currentTrial = {}
self._previousTrialId = 0
self.trialNum = 0

psychlab_helpers.setTrialsPerEpisodeCap(self, kwargs.trialsPerEpisodeCap)

local targetPositions = {}
for i = 1, self.numTrials / 2 do
table.insert(targetPositions, 1)
end
for i = self.numTrials / 2 + 1, self.numTrials do
table.insert(targetPositions, 2)
end
local batch_offset
if self.isTrain then
self._numBatch = self.trainBatch
batch_offset = 0
else
self._numBatch = self.testBatch
batch_offset = self.trainBatch
end
self._targetPositions = random:shuffle(targetPositions)
-- randomly pick a correctID and an incorrectID
self.correctID = random:uniformInt(1, self._numBatch) + batch_offset
self.incorrectID = random:uniformInt(1, self._numBatch - 1) + batch_offset
if self.incorrectID >= self.correctID then
self.incorrectID = self.incorrectID + 1
end
log.info('Episode set up, corrID=' .. self.correctID ..
', incorrID=' .. self.incorrectID)
end

function env:setupImages()
self.images = {}

self.images.fixation = psychlab_helpers.getFixationImage(self.screenSize,
kwargs.bgColor, kwargs.fixationColor, kwargs.fixationSize)

self.target = tensor.ByteTensor(256, 256, 3)
end

function env:finishTrial(delay)
self.currentTrial.reactionTime =
game:episodeTimeSeconds() - self._currentTrialStartTime

self.currentTrial.stepCount = self.pac:elapsedSteps()
psychlab_helpers.publishTrialData(self.currentTrial, kwargs.schema)
psychlab_helpers.finishTrialCommon(self, delay, kwargs.fixationSize)
end

function env:fixationCallback(name, mousePos, hoverTime, userData)
if hoverTime == kwargs.timeToFixateCross then
self.pac:addReward(kwargs.fixationReward)
self.pac:removeWidget('fixation')
self.pac:removeWidget('center_of_fixation')

self.currentTrial.trialId = self._previousTrialId + 1
-- Measure reaction time since the trial started.
self._currentTrialStartTime = game:episodeTimeSeconds()
self.pac:resetSteps()

self:startTrial()
end
end

function env:correctResponseCallback(name, mousePos, hoverTime, userData)
if hoverTime == kwargs.timeToFixateTarget then
self.currentTrial.response = name
self.currentTrial.correct = 1
self.pac:addReward(kwargs.correctReward)
self:finishTrial(kwargs.fastInterTrialInterval)
end
end

function env:incorrectResponseCallback(name, mousePos, hoverTime, userData)
if hoverTime == kwargs.timeToFixateTarget then
self.currentTrial.response = name
self.currentTrial.correct = 0
self.pac:addReward(kwargs.incorrectReward)
self:finishTrial(kwargs.fastInterTrialInterval)
end
end

function env:startTrial()
self.trialNum = self.trialNum + 1
-- Fixation initiates the next trial
self.currentTrial.trial = self.trialNum
self._targetPos = self._targetPositions[self.trialNum]
self:addObjects(self.correctID, self.incorrectID, self._targetPos)
end

function env:addTargetImage(targetImage, targetSize, targetPos, targetName)
assert(targetImage, 'targetImage must not be nil')
assert(targetSize, 'targetSize must not be nil')
assert(targetSize > 0 and targetSize <= 1, 'targetSize must in (0, 1]')
self.target:copy(targetImage)
local sizeInPixels = psychlab_helpers.getSizeInPixels(self.screenSize,
targetSize)
local scaledImage = psychlab_helpers.scaleImage(self.target,
sizeInPixels.width,
sizeInPixels.height)

self.pac:addWidget{
name = targetName,
image = scaledImage,
pos = targetPos,
size = {targetSize, targetSize},
}
end

function env:addObjects(correctID, incorrectID, targetPos)
local posCorrect, posIncorrect, posCorrectButton, posIncorrectButton
self.currentTrial.correctResponse = targetPos
local upperLeft = psychlab_helpers.getUpperLeftFromCenter
if targetPos == 1 then
posCorrect = upperLeft(kwargs.left, kwargs.targetSize)
posIncorrect = upperLeft(kwargs.right, kwargs.targetSize)
posCorrectButton = upperLeft(kwargs.left, kwargs.targetSize / 2)
posIncorrectButton = upperLeft(kwargs.right, kwargs.targetSize / 2)
elseif targetPos == 2 then
posCorrect = upperLeft(kwargs.right, kwargs.targetSize)
posIncorrect = upperLeft(kwargs.left, kwargs.targetSize)
posCorrectButton = upperLeft(kwargs.right, kwargs.targetSize / 2)
posIncorrectButton = upperLeft(kwargs.left, kwargs.targetSize / 2)
else
error('targetPos not valid')
end
self:addTargetImage(self.dataset:getImage(correctID),
kwargs.targetSize,
posCorrect,
'correct_image')
self:addTargetImage(self.dataset:getImage(incorrectID),
kwargs.targetSize,
posIncorrect,
'incorrect_image')

-- Add buttons, which are invisible and colocated with the images
self.pac:addWidget{
name = 'correct_image_button',
image = nil,
pos = posCorrectButton,
size = {kwargs.targetSize / 2, kwargs.targetSize / 2},
mouseHoverCallback = self.correctResponseCallback,
}
self.pac:addWidget{
name = 'incorrect_image_button',
image = nil,
pos = posIncorrectButton,
size = {kwargs.targetSize / 2, kwargs.targetSize / 2},
mouseHoverCallback = self.incorrectResponseCallback,
}
end

function env:removeArray()
self.pac:removeWidget('correct_image')
self.pac:removeWidget('incorrect_image')
self.pac:removeWidget('correct_image_button')
self.pac:removeWidget('incorrect_image_button')
end

return psychlab_factory.createLevelApi{
env = point_and_click,
envOpts = {environment = env, screenSize = ARG.screenSize}
}
end

return factory
23 changes: 23 additions & 0 deletions game_scripts/levels/contributed/psychlab/harlow.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
--[[ Copyright (C) 2019 Google Inc.
This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
]]

local factory = require 'levels.contributed.psychlab.factories.harlow_factory'

return factory.createLevelApi{
schema = 'harlow',
episodeLengthSeconds = 300,
}

0 comments on commit 42a1d8a

Please sign in to comment.