106 lines
2.1 KiB
Go
106 lines
2.1 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: BUSL-1.1
|
|
|
|
package downloader
|
|
|
|
import (
	"bytes"
	"errors"
	"io/ioutil"
	"net/http"
	"os"

	"github.com/hashicorp/go-retryablehttp"
	"github.com/hashicorp/vagrant-plugin-sdk/component"
)
|
|
|
|
// HTTPMethod is an enum of all the available HTTP methods.
type HTTPMethod int64

// Supported HTTP methods. The numeric values are assigned by iota in
// declaration order, so the order of these constants must not change.
const (
	GET    HTTPMethod = iota // 0
	DELETE                   // 1
	HEAD                     // 2
	POST                     // 3
	PUT                      // 4
)
|
|
|
|
type Downloader struct {
|
|
config DownloaderConfig
|
|
}
|
|
|
|
type DownloaderConfig struct {
|
|
Dest string
|
|
Headers http.Header
|
|
Method HTTPMethod
|
|
RetryCount int
|
|
RequestBody []byte
|
|
Src string
|
|
UrlQueryParams map[string]string
|
|
}
|
|
|
|
// Config implements Configurable
|
|
func (d *Downloader) Config() (interface{}, error) {
|
|
return &d.config, nil
|
|
}
|
|
|
|
func (d *Downloader) DownloadFunc() interface{} {
|
|
return d.Download
|
|
}
|
|
|
|
func (d *Downloader) Download() (err error) {
|
|
client := retryablehttp.NewClient()
|
|
client.RetryMax = d.config.RetryCount
|
|
var req *retryablehttp.Request
|
|
|
|
// Create request with request body if one is provided
|
|
if d.config.RequestBody != nil {
|
|
req, err = retryablehttp.NewRequest(
|
|
d.config.Method.String(), d.config.Src, bytes.NewBuffer(d.config.RequestBody),
|
|
)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
} else {
|
|
// If no request body is provided then create an empty request
|
|
req, err = retryablehttp.NewRequest(
|
|
d.config.Method.String(), d.config.Src, nil,
|
|
)
|
|
}
|
|
|
|
// Add query params if provided
|
|
if d.config.UrlQueryParams != nil {
|
|
q := req.URL.Query()
|
|
for k, v := range d.config.UrlQueryParams {
|
|
q.Add(k, v)
|
|
}
|
|
req.URL.RawQuery = q.Encode()
|
|
}
|
|
|
|
// Set headers
|
|
req.Header = d.config.Headers
|
|
// Add headers to redirects
|
|
client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
|
for key, val := range via[0].Header {
|
|
req.Header[key] = val
|
|
}
|
|
return err
|
|
}
|
|
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer resp.Body.Close()
|
|
data, err := ioutil.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
err = os.WriteFile(d.config.Dest, data, 0644)
|
|
return
|
|
}
|
|
|
|
var (
|
|
_ component.Downloader = (*Downloader)(nil)
|
|
_ component.Configurable = (*Downloader)(nil)
|
|
)
|