Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

randomize initial zone selection #583

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
Open
169 changes: 87 additions & 82 deletions backend/gce.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
mathrand "math/rand"
"net/http"
"net/url"
"regexp"
"sort"
"strconv"
"strings"
Expand Down Expand Up @@ -48,7 +49,6 @@ import (
)

const (
defaultGCEZone = "us-central1-a"
defaultGCEMachineType = "n1-standard-2"
defaultGCEPremiumMachineType = "n1-standard-4"
defaultGCENetwork = "default"
Expand All @@ -59,7 +59,6 @@ const (
defaultGCEStopPollSleep = 3 * time.Second
defaultGCEStopPrePollSleep = 15 * time.Second
defaultGCESubnet = "default"
defaultGCERegion = "us-central1"
defaultGCEUploadRetries = uint64(120)
defaultGCEUploadRetrySleep = 1 * time.Second
defaultGCEImageSelectorType = "env"
Expand Down Expand Up @@ -92,10 +91,11 @@ var (
"IMAGE_SELECTOR_URL": "URL for image selector API, used only when image selector is \"api\"",
"IMAGE_[ALIAS_]{ALIAS}": "full name for a given alias given via IMAGE_ALIASES, where the alias form in the key is uppercased and normalized by replacing non-alphanumerics with _",
"MACHINE_TYPE": fmt.Sprintf("machine name (default %q)", defaultGCEMachineType),
"MINIMUM_CPU_PLATFORM": "minimum cpu platform",
"NETWORK": fmt.Sprintf("network name (default %q)", defaultGCENetwork),
"PREEMPTIBLE": "boot job instances with preemptible flag enabled (default false)",
"PREMIUM_MACHINE_TYPE": fmt.Sprintf("premium machine type (default %q)", defaultGCEPremiumMachineType),
"PROJECT_ID": "[REQUIRED] GCE project id",
"PROJECT_ID": "[REQUIRED] GCE project id (will try to auto detect it if not set)",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In that case, [REQUIRED] can be dropped

"PUBLIC_IP": "boot job instances with a public ip, disable this for NAT (default true)",
"PUBLIC_IP_CONNECT": "connect to the public ip of the instance instead of the internal, only takes effect if PUBLIC_IP is true (default true)",
"IMAGE_PROJECT_ID": "GCE project id to use for images, will use PROJECT_ID if not specified",
Expand All @@ -107,8 +107,8 @@ var (
"RATE_LIMIT_DYNAMIC_CONFIG": "get max-calls and duration dynamically through redis (default false)",
"RATE_LIMIT_DYNAMIC_CONFIG_TTL": fmt.Sprintf("time to cache dynamic config for (default %v)", defaultGCERateLimitDynamicConfigTTL),

"BACKOFF_RETRY_MAX": "Maximum allowed duration of generic exponential backoff retries (default 1m)",
"REGION": fmt.Sprintf("only takes effect when SUBNETWORK is defined; region in which to deploy (default %v)", defaultGCERegion),
"BACKOFF_RETRY_MAX": "maximum allowed duration of generic exponential backoff retries (default 1m)",
"REGION": "[REQUIRED] region in which to deploy",
"SKIP_STOP_POLL": "immediately return after issuing first instance deletion request (default false)",
"SSH_DIAL_TIMEOUT": fmt.Sprintf("connection timeout for ssh connections (default %v)", defaultGCESSHDialTimeout),
"STOP_POLL_SLEEP": fmt.Sprintf("sleep interval between polling server for instance stop status (default %v)", defaultGCEStopPollSleep),
Expand All @@ -119,7 +119,7 @@ var (
"WARMER_URL": "URL for warmer service",
"WARMER_TIMEOUT": fmt.Sprintf("timeout for requests to warmer service (default %v)", defaultGCEWarmerTimeout),
"WARMER_SSH_PASSPHRASE": fmt.Sprintf("The passphrase used to decipher instace SSH keys"),
"ZONE": fmt.Sprintf("zone name (default %q)", defaultGCEZone),
"ZONE": "zone in which to deploy job instaces into (default is to use all zones in the region)",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo (instaces)

}

errGCEMissingIPAddressError = fmt.Errorf("no IP address found")
Expand Down Expand Up @@ -182,7 +182,7 @@ type gceProvider struct {
imageProjectID string
ic *gceInstanceConfig
cfg *config.ProviderConfig
alternateZones []string
allZonesForRegion []string
machineTypeSelfLinks map[string]string

backoffRetryMax time.Duration
Expand Down Expand Up @@ -212,6 +212,7 @@ type gceProvider struct {
}

type gceInstanceConfig struct {
MinimumCpuPlatform string
MachineType string
PremiumMachineType string
Zone *compute.Zone
Expand Down Expand Up @@ -257,7 +258,6 @@ type gceStartContext struct {
instanceWarmedIP string
windowsPassword string
zoneName string
zonePinned bool
machineType string
premiumMachineType string
}
Expand Down Expand Up @@ -312,6 +312,10 @@ func newGCEProvider(cfg *config.ProviderConfig) (Provider, error) {
return nil, err
}

if !cfg.IsSet("ACCOUNT_JSON") {
return nil, fmt.Errorf("missing ACCOUNT_JSON")
}

projectID := cfg.Get("PROJECT_ID")
if metadata.OnGCE() {
projectID, err = metadata.ProjectID()
Expand All @@ -328,18 +332,32 @@ func newGCEProvider(cfg *config.ProviderConfig) (Provider, error) {
imageProjectID = cfg.Get("IMAGE_PROJECT_ID")
}

zoneName := defaultGCEZone
if metadata.OnGCE() {
zoneName, err = metadata.Zone()
region := ""
if cfg.IsSet("REGION") {
region = cfg.Get("REGION")
} else if metadata.OnGCE() {
zoneName, err := metadata.Zone()
if err != nil {
return nil, errors.Wrap(err, "could not get zone from metadata api")
}
}
if cfg.IsSet("ZONE") {
zoneName = cfg.Get("ZONE")
zone, zErr := client.Zones.Get(projectID, zoneName).Do()
if zErr != nil {
return nil, errors.Wrap(zErr, "could not get zone from compute api")
}
zoneURI := fmt.Sprintf("https://www.googleapis.com/compute/v1/projects/%s/regions/(.*)", projectID)
quotedZoneURI := strings.Replace(zoneURI, "/", "\\/", -1)
re := regexp.MustCompile(quotedZoneURI)
region = re.FindStringSubmatch(zone.Region)[1]
} else {
return nil, fmt.Errorf("missing REGION")
}

cfg.Set("ZONE", zoneName)
cfg.Set("REGION", region)

minimumCpuPlatform := ""
if cfg.IsSet("MINIMUM_CPU_PLATFORM") {
minimumCpuPlatform = cfg.Get("MINIMUM_CPU_PLATFORM")
}

mtName := defaultGCEMachineType
if cfg.IsSet("MACHINE_TYPE") {
Expand Down Expand Up @@ -599,7 +617,7 @@ func newGCEProvider(cfg *config.ProviderConfig) (Provider, error) {
projectID: projectID,
imageProjectID: imageProjectID,
cfg: cfg,
alternateZones: []string{},
allZonesForRegion: []string{},
machineTypeSelfLinks: map[string]string{},
sshDialer: sshDialer,
sshDialTimeout: sshDialTimeout,
Expand All @@ -616,6 +634,7 @@ func newGCEProvider(cfg *config.ProviderConfig) (Provider, error) {
SkipStopPoll: skipStopPoll,
Site: site,
AcceleratorConfig: defaultAcceleratorConfig,
MinimumCpuPlatform: minimumCpuPlatform,
MachineType: mtName,
PremiumMachineType: premiumMTName,
},
Expand Down Expand Up @@ -699,27 +718,27 @@ func (p *gceProvider) apiRateLimit(ctx gocontext.Context) error {
func (p *gceProvider) Setup(ctx gocontext.Context) error {
logger := context.LoggerFromContext(ctx).WithField("self", "backend/gce_provider")

logger.WithField("zone", p.cfg.Get("ZONE")).Debug("resolving configured zone")
if p.cfg.Get("ZONE") != "" {
logger.WithField("zone", p.cfg.Get("ZONE")).Debug("resolving configured zone")

err := p.backoffRetry(ctx, func() error {
p.apiRateLimit(ctx)
zone, zErr := p.client.Zones.
Get(p.projectID, p.cfg.Get("ZONE")).
Context(ctx).
Do()
if zErr == nil {
p.ic.Zone = zone
}
return zErr
})
err := p.backoffRetry(ctx, func() error {
p.apiRateLimit(ctx)
zone, zErr := p.client.Zones.Get(p.projectID, p.cfg.Get("ZONE")).Context(ctx).Do()
if zErr == nil {
logger.WithField("zone", zone).Info("resolved configured zone")
p.ic.Zone = zone
}
return zErr
})

if err != nil {
return errors.Wrap(err, "failed to resolve configured zone")
if err != nil {
return errors.Wrap(err, "failed to resolve configured zone")
}
}

logger.WithField("network", p.cfg.Get("NETWORK")).Debug("resolving configured network")

err = p.backoffRetry(ctx, func() error {
err := p.backoffRetry(ctx, func() error {
p.apiRateLimit(ctx)
nw, nwErr := p.client.Networks.
Get(p.projectID, p.cfg.Get("NETWORK")).
Expand All @@ -732,17 +751,18 @@ func (p *gceProvider) Setup(ctx gocontext.Context) error {
})

if err != nil {
return errors.Wrap(err, "failed te resolve configured network")
return errors.Wrap(err, "failed to resolve configured network")
}

region := defaultGCERegion
if metadata.OnGCE() {
logger.WithField("region", p.ic.Zone.Region).Debug("setting region from zone when on gce")
region = p.ic.Zone.Region
}
region := ""
if p.cfg.IsSet("REGION") {
logger.WithField("region", p.ic.Zone.Region).Debug("setting region from config")
logger.WithField("region", p.cfg.Get("REGION")).Info("setting region from config")
region = p.cfg.Get("REGION")
} else if metadata.OnGCE() {
logger.WithField("region", p.ic.Zone.Region).Info("setting region from zone when on gce")
region = p.ic.Zone.Region
} else {
return errors.Wrap(err, "failed to resolve configured region")
}

if p.cfg.IsSet("SUBNETWORK") {
Expand All @@ -765,33 +785,40 @@ func (p *gceProvider) Setup(ctx gocontext.Context) error {
}
}

logger.Debug("finding alternate zones")
logger.Debug("finding all zones for region")
err = p.backoffRetry(ctx, func() error {
p.apiRateLimit(ctx)

regionURL := fmt.Sprintf("https://www.googleapis.com/compute/v1/projects/%s/regions/%s", p.projectID, p.cfg.Get("REGION"))

zl, zlErr := p.client.Zones.List(p.projectID).
Context(ctx).
Filter("status eq UP").
Filter(fmt.Sprintf("region eq %s", p.ic.Zone.Region)).Do()
Filter(fmt.Sprintf("region eq %s", regionURL)).Do()

if zlErr != nil {
return zlErr
}

p.alternateZones = []string{}
p.allZonesForRegion = []string{}
for _, z := range zl.Items {
p.alternateZones = append(p.alternateZones, z.Name)
p.allZonesForRegion = append(p.allZonesForRegion, z.Name)
}

if len(p.allZonesForRegion) == 0 {
return fmt.Errorf("no zones found for region %s", p.cfg.Get("REGION"))
}

return nil
})

if err != nil {
return errors.Wrap(err, "failed to find alternate zones")
return errors.Wrap(err, "failed to find zones for region")
}

logger.Debug("building machine type self link map")

for _, zoneName := range append([]string{p.ic.Zone.Name}, p.alternateZones...) {
for _, zoneName := range p.allZonesForRegion {
for _, machineType := range []string{p.ic.MachineType, p.ic.PremiumMachineType} {
if zoneName == "" || machineType == "" {
continue
Expand Down Expand Up @@ -915,7 +942,6 @@ func (p *gceProvider) StartWithProgress(ctx gocontext.Context, startAttributes *

c := &gceStartContext{
startAttributes: startAttributes,
zoneName: p.ic.Zone.Name,
machineType: p.ic.MachineType,
premiumMachineType: p.ic.PremiumMachineType,
progresser: progresser,
Expand Down Expand Up @@ -1051,22 +1077,7 @@ func (p *gceProvider) stepInsertInstance(c *gceStartContext) multistep.StepActio

logger := context.LoggerFromContext(c.ctx).WithField("self", "backend/gce_provider")

if c.startAttributes.VMConfig.Zone != "" {
err := p.backoffRetry(ctx, func() error {
p.apiRateLimit(ctx)
zone, zErr := p.client.Zones.Get(p.projectID, c.startAttributes.VMConfig.Zone).Context(ctx).Do()
if zErr != nil {
return zErr
}
c.zoneName = zone.Name
c.zonePinned = true
return nil
})

if err != nil {
return multistep.ActionHalt
}
}
c.zoneName = p.pickZone("")

inst, err := p.buildInstance(ctx, c)
if err != nil {
Expand Down Expand Up @@ -1148,15 +1159,13 @@ func (p *gceProvider) stepInsertInstance(c *gceStartContext) multistep.StepActio

op, insErr := p.client.Instances.Insert(p.projectID, c.zoneName, c.instance).Context(c.ctx).Do()
if insErr != nil {
if !c.zonePinned {
altZone := p.pickAlternateZone(c.zoneName)
logger.WithFields(logrus.Fields{
"err": insErr,
"prev_zone": c.zoneName,
"next_zone": altZone,
}).Warn("switching zones due to error")
p.setStartContextZone(c, altZone)
}
nextZone := p.pickZone(c.zoneName)
logger.WithFields(logrus.Fields{
"err": insErr,
"prev_zone": c.zoneName,
"next_zone": nextZone,
}).Warn("switching zones due to error")
p.setStartContextZone(c, nextZone)
return insErr
}

Expand Down Expand Up @@ -1457,6 +1466,10 @@ func (p *gceProvider) buildInstance(ctx gocontext.Context, c *gceStartContext) (
},
}

if p.ic.MinimumCpuPlatform != "" {
inst.MinCpuPlatform = p.ic.MinimumCpuPlatform
}

machineType := p.ic.MachineType
if c.startAttributes.VMType == "premium" {
machineType = p.ic.PremiumMachineType
Expand Down Expand Up @@ -1620,20 +1633,12 @@ func (p *gceProvider) warmerRequestInstance(ctx gocontext.Context, zone string,
return warmerRes, nil
}

func (p *gceProvider) pickAlternateZone(zoneName string) string {
if len(p.alternateZones) == 0 {
return zoneName
func (p *gceProvider) pickZone(zoneName string) string {
if p.cfg.Get("ZONE") != "" {
return p.cfg.Get("ZONE")
}

for {
altZone := p.alternateZones[mathrand.Intn(len(p.alternateZones))]
if altZone != zoneName {
return altZone
}
if len(p.alternateZones) == 1 {
return zoneName
}
}
return p.allZonesForRegion[mathrand.Intn(len(p.allZonesForRegion))]
}

func (p *gceProvider) setStartContextZone(c *gceStartContext, zoneName string) {
Expand Down
1 change: 1 addition & 0 deletions backend/gce_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ func gceTestSetup(t *testing.T, cfg *config.ProviderConfig, resp *gceTestRespons
"PROJECT_ID": "project_id",
"IMAGE_ALIASES": "foo",
"IMAGE_ALIAS_FOO": "default",
"REGION": "us-central1",
})
}

Expand Down
2 changes: 1 addition & 1 deletion script/fmtpolice
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
set -o errexit

main() {
if [[ ! -n "${1}" ]]; then
if [[ -z "${1}" ]]; then
git ls-files '*.go' | while read -r f; do
__gofmt_check "${f}"
done
Expand Down