From 9f9cd3fec18c755379d48e1aa68c9ee886ff89e7 Mon Sep 17 00:00:00 2001 From: fuxiao Date: Mon, 16 Aug 2021 18:30:13 +0800 Subject: [PATCH] first commit --- client.go | 196 ++++++++++++++++++++++++++++++ client_test.go | 81 ++++++++++++ go.mod | 1 + internal/utils.go | 304 ++++++++++++++++++++++++++++++++++++++++++++++ middleware.go | 32 +++++ request.go | 264 ++++++++++++++++++++++++++++++++++++++++ response.go | 116 ++++++++++++++++++ 7 files changed, 994 insertions(+) create mode 100644 client.go create mode 100644 client_test.go create mode 100644 go.mod create mode 100644 internal/utils.go create mode 100644 middleware.go create mode 100644 request.go create mode 100644 response.go diff --git a/client.go b/client.go new file mode 100644 index 0000000..58b08db --- /dev/null +++ b/client.go @@ -0,0 +1,196 @@ +/** + * @Author: fuxiao + * @Email: 576101059@qq.com + * @Date: 2021/8/14 4:11 下午 + * @Desc: TODO + */ + +package http + +import ( + "context" + "crypto/tls" + "encoding/base64" + "net/http" + "net/http/cookiejar" + "time" +) + +type Client struct { + http.Client + headers map[string]string + cookies map[string]string + ctx context.Context + baseUrl string + retryCount int + retryInterval time.Duration + middlewares []MiddlewareFunc +} + +const ( + defaultUserAgent = "DobyteHttpClient" + + HeaderUserAgent = "User-Agent" + HeaderContentType = "Content-Type" + HeaderAuthorization = "Authorization" + HeaderCookie = "Cookie" + HeaderHost = "Host" + + ContentTypeJson = "application/json" + ContentTypeXml = "application/xml" + ContentTypeFormData = "form-data" + ContentTypeFormUrlEncoded = "application/x-www-form-urlencoded" +) + +func NewClient() *Client { + client := &Client{ + Client: http.Client{ + Transport: &http.Transport{ + DisableKeepAlives: true, + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + }, + }, + headers: make(map[string]string), + cookies: make(map[string]string), + middlewares: make([]MiddlewareFunc, 0), + } + client.headers[HeaderUserAgent] = defaultUserAgent + + return client +} + +// Set a header for the request. +func (c *Client) SetHeader(key, value string) *Client { + c.headers[key] = value + return c +} + +// Set multiple headers for the request. +func (c *Client) SetHeaders(headers map[string]string) *Client { + for key, value := range headers { + c.headers[key] = value + } + return c +} + +// Set a cookie for the request. +func (c *Client) SetCookie(key, value string) *Client { + c.cookies[key] = value + return c +} + +// Set multiple cookies for the request. +func (c *Client) SetCookies(cookies map[string]string) *Client { + for key, value := range cookies { + c.cookies[key] = value + } + return c +} + +// Set User-Agent for the request. +func (c *Client) SetUserAgent(agent string) *Client { + c.headers[HeaderUserAgent] = agent + return c +} + +// Set Content-Type for the request. +func (c *Client) SetContentType(contentType string) *Client { + c.headers[HeaderContentType] = contentType + return c +} + +// Enable browser mode for the request. +func (c *Client) SetBrowserMode() *Client { + jar, _ := cookiejar.New(nil) + c.Jar = jar + return c +} + +// +func (c *Client) SetBaseUrl(baseUrl string) *Client { + c.baseUrl = baseUrl + return c +} + +// SetBasicAuth set HTTP basic authentication information for the request. +func (c *Client) SetBasicAuth(username, password string) *Client { + c.headers[HeaderAuthorization] = "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password)) + return c +} + +// SetBearerToken set HTTP Bearer-Token authentication information for the request. +func (c *Client) SetBearerToken(token string) *Client { + c.headers[HeaderAuthorization] = "Bearer " + token + return c +} + +// SetContext set context for the request. +func (c *Client) SetContext(ctx context.Context) *Client { + c.ctx = ctx + return c +} + +// SetTimeOut sets the request timeout for the client. +func (c *Client) SetTimeout(timeout time.Duration) *Client { + c.Client.Timeout = timeout + return c +} + +// SetRetry sets count and interval of retry for the request. +func (c *Client) SetRetry(retryCount int, retryInterval time.Duration) *Client { + c.retryCount = retryCount + c.retryInterval = retryInterval + return c +} + +func (c *Client) SetKeepAlive(enable bool) { + //c.Transport. +} + +// Use sets middleware for the request. +func (c *Client) Use(middlewares ...MiddlewareFunc) *Client { + c.middlewares = append(c.middlewares, middlewares...) + return c +} + +func (c *Client) Request(method, url string, data ...interface{}) (*Response, error) { + return NewRequest(c).request(method, url, data...) +} + +func (c *Client) Get(url string, data ...interface{}) (*Response, error) { + return c.Request(MethodGet, url, data...) +} + +func (c *Client) Post(url string, data ...interface{}) (*Response, error) { + return c.Request(MethodPost, url, data...) +} + +func (c *Client) Put(url string, data ...interface{}) (*Response, error) { + return c.Request(MethodPut, url, data...) +} + +func (c *Client) Patch(url string, data ...interface{}) (*Response, error) { + return c.Request(MethodPatch, url, data...) +} + +func (c *Client) Delete(url string, data ...interface{}) (*Response, error) { + return c.Request(MethodDelete, url, data...) +} + +func (c *Client) Head(url string, data ...interface{}) (*Response, error) { + return c.Request(MethodHead, url, data...) +} + +func (c *Client) Options(url string, data ...interface{}) (*Response, error) { + return c.Request(MethodOptions, url, data...) +} + +func (c *Client) Connect(url string, data ...interface{}) (*Response, error) { + return c.Request(MethodConnect, url, data...) +} + +func (c *Client) Trace(url string, data ...interface{}) (*Response, error) { + return c.Request(MethodTrace, url, data...) +} diff --git a/client_test.go b/client_test.go new file mode 100644 index 0000000..f1bb7e0 --- /dev/null +++ b/client_test.go @@ -0,0 +1,81 @@ +/** + * @Author: fuxiao + * @Email: 576101059@qq.com + * @Date: 2021/8/16 2:54 下午 + * @Desc: TODO + */ + +package http_test + +import ( + "errors" + "testing" + + "github.com/dobyte/http" +) + +const token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJlc2MiOjE2MjgwNDAzMjYxNTQ2MzIwMDAsImV4cCI6MTYyODIyMDMyNiwiaWF0IjoxNjI4MDQwMzI2LCJpZCI6MX0.KM19c6URIih-5SyycYIjNAdSiPKxMQEz3DoROm0N3nw" + +func TestClient_Request(t *testing.T) { + client := http.NewClient() + client.SetBaseUrl("http://127.0.0.1:8199").Use(func(r *http.Request) (*http.Response, error) { + return r.Next() + }).Use(func(r *http.Request) (*http.Response, error) { + return nil, errors.New("Invalid params.") + }) + + resp, err := client.Request(http.MethodGet, "/common/regions") + if err != nil { + t.Error(err) + return + } + + t.Log(resp.Response.Status) +} + +func TestClient_Post(t *testing.T) { + client := http.NewClient() + client.SetBaseUrl("http://127.0.0.1:8199") + client.SetBearerToken(token) + client.SetContentType(http.ContentTypeJson) + client.Use(func(r *http.Request) (*http.Response, error) { + r.Request.Header.Set("Client-Type", "2") + return r.Next() + }) + + type updateRegionArg struct { + Id int `json:"id"` + Pid int `json:"pid"` + Code string `json:"code"` + Name string `json:"name"` + Sort int `json:"sort"` + } + + data := updateRegionArg{ + Id: 1, + Pid: 0, + Code: "110000", + Name: "北京市", + Sort: 0, + } + + //data := map[string]interface{}{ + // "id": 1, + // "pid": 0, + // "code": "110000", + // "name": "北京市", + // "sort": 0, + //} + + if resp, err := client.Put("/backend/region/update-region", data); err != nil { + t.Error(err) + return + } else { + t.Log(resp.Response.Status) + t.Log(resp.Response.Header) + t.Log(resp.Bytes()) + t.Log(resp.String()) + t.Log(resp.GetHeaders()) + t.Log(resp.GetCookies()) + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..6e28890 --- /dev/null +++ b/go.mod @@ -0,0 +1 @@ +module "github.com/dobyte/http" \ No newline at end of file diff --git a/internal/utils.go b/internal/utils.go new file mode 100644 index 0000000..affe7c7 --- /dev/null +++ b/internal/utils.go @@ -0,0 +1,304 @@ +/** + * @Author: fuxiao + * @Email: 576101059@qq.com + * @Date: 2021/8/16 3:47 下午 + * @Desc: TODO + */ + +package internal + +import ( + "encoding" + "encoding/json" + "fmt" + "net/url" + "os" + "reflect" + "strconv" + "strings" + "time" + "unsafe" +) + +const fileUploadingKey = "@file:" + +func Exists(path string) bool { + if stat, err := os.Stat(path); stat != nil && !os.IsNotExist(err) { + return true + } + return false +} + +func BuildParams(params interface{}) string { + switch v := params.(type) { + case string: + return v + case []byte: + return string(v) + case []interface{}: + if len(v) > 0 { + params = v[0] + } else { + params = nil + } + } + + m := make(map[string]interface{}) + + if params != nil { + if b, err := json.Marshal(params); err != nil { + return String(params) + } else if err = json.Unmarshal(b, &m); err != nil { + return String(params) + } + } else { + return "" + } + + urlEncode := true + + if len(m) == 0 { + return String(params) + } + + for k, v := range m { + if strings.Contains(k, fileUploadingKey) || strings.Contains(String(v), fileUploadingKey) { + urlEncode = false + break + } + } + + var ( + s = "" + str = "" + ) + + for k, v := range m { + if len(str) > 0 { + str += "&" + } + s = String(v) + if urlEncode && len(s) > len(fileUploadingKey) && strings.Compare(s[0:len(fileUploadingKey)], fileUploadingKey) != 0 { + s = url.QueryEscape(s) + } + str += k + "=" + s + } + + return str +} + +func String(any interface{}) string { + switch v := any.(type) { + case nil: + return "" + case string: + return v + case int: + return strconv.Itoa(v) + case int8: + return strconv.Itoa(int(v)) + case int16: + return strconv.Itoa(int(v)) + case int32: + return strconv.Itoa(int(v)) + case int64: + return strconv.FormatInt(v, 10) + case uint: + return strconv.FormatUint(uint64(v), 10) + case uint8: + return strconv.FormatUint(uint64(v), 10) + case uint16: + return strconv.FormatUint(uint64(v), 10) + case uint64: + return strconv.FormatUint(v, 10) + case float32: + return strconv.FormatFloat(float64(v), 'f', -1, 32) + case float64: + return strconv.FormatFloat(v, 'f', -1, 64) + case bool: + return strconv.FormatBool(v) + case []byte: + return string(v) + case time.Time: + return v.String() + case *time.Time: + if v == nil { + return "" + } + return v.String() + default: + if v == nil { + return "" + } + + if i, ok := v.(stringInterface); ok { + return i.String() + } + + if i, ok := v.(errorInterface); ok { + return i.Error() + } + + var ( + rv = reflect.ValueOf(v) + kind = rv.Kind() + ) + + switch kind { + case reflect.Chan, + reflect.Map, + reflect.Slice, + reflect.Func, + reflect.Ptr, + reflect.Interface, + reflect.UnsafePointer: + if rv.IsNil() { + return "" + } + case reflect.String: + return rv.String() + } + + if kind == reflect.Ptr { + return String(rv.Elem().Interface()) + } + + if b, e := json.Marshal(v); e != nil { + return fmt.Sprint(v) + } else { + return string(b) + } + } +} + +func Scan(b []byte, any interface{}) error { + switch v := any.(type) { + case nil: + return fmt.Errorf("cache: Scan(nil)") + case *string: + *v = String(b) + return nil + case *[]byte: + *v = b + return nil + case *int: + var err error + *v, err = strconv.Atoi(String(b)) + return err + case *int8: + n, err := strconv.ParseInt(String(b), 10, 8) + if err != nil { + return err + } + *v = int8(n) + return nil + case *int16: + n, err := strconv.ParseInt(String(b), 10, 16) + if err != nil { + return err + } + *v = int16(n) + return nil + case *int32: + n, err := strconv.ParseInt(String(b), 10, 32) + if err != nil { + return err + } + *v = int32(n) + return nil + case *int64: + n, err := strconv.ParseInt(String(b), 10, 64) + if err != nil { + return err + } + *v = n + return nil + case *uint: + n, err := strconv.ParseUint(String(b), 10, 64) + if err != nil { + return err + } + *v = uint(n) + return nil + case *uint8: + n, err := strconv.ParseUint(String(b), 10, 8) + if err != nil { + return err + } + *v = uint8(n) + return nil + case *uint16: + n, err := strconv.ParseUint(String(b), 10, 16) + if err != nil { + return err + } + *v = uint16(n) + return nil + case *uint32: + n, err := strconv.ParseUint(String(b), 10, 32) + if err != nil { + return err + } + *v = uint32(n) + return nil + case *uint64: + n, err := strconv.ParseUint(String(b), 10, 64) + if err != nil { + return err + } + *v = n + return nil + case *float32: + n, err := strconv.ParseFloat(String(b), 32) + if err != nil { + return err + } + *v = float32(n) + return err + case *float64: + var err error + *v, err = strconv.ParseFloat(String(b), 64) + return err + case *bool: + *v = len(b) == 1 && b[0] == '1' + return nil + case *time.Time: + var err error + *v, err = time.Parse(time.RFC3339Nano, String(b)) + return err + case encoding.BinaryUnmarshaler: + return v.UnmarshalBinary(b) + default: + var ( + rv = reflect.ValueOf(v) + kind = rv.Kind() + ) + + if kind != reflect.Ptr { + return fmt.Errorf("can't unmarshal %T", v) + } + + switch kind = rv.Elem().Kind(); kind { + case reflect.Array, reflect.Slice, reflect.Map, reflect.Struct: + return json.Unmarshal(b, v) + } + + return fmt.Errorf("can't unmarshal %T", v) + } +} + +type stringInterface interface { + String() string +} + +type errorInterface interface { + Error() string +} + +func UnsafeStringToBytes(s string) []byte { + return *(*[]byte)(unsafe.Pointer(&s)) +} + +func UnsafeBytesToString(b []byte) string { + return *(*string)(unsafe.Pointer(&b)) +} diff --git a/middleware.go b/middleware.go new file mode 100644 index 0000000..a890f78 --- /dev/null +++ b/middleware.go @@ -0,0 +1,32 @@ +/** + * @Author: fuxiao + * @Email: 576101059@qq.com + * @Date: 2021/8/16 9:47 上午 + * @Desc: request's middleware + */ + +package http + +type MiddlewareFunc = func(r *Request) (*Response, error) + +const middlewareKey = "__httpClientMiddlewareKey" + +type middleware struct { + err error + req *Request + resp *Response + index int + handlers []MiddlewareFunc +} + +// Next exec the next middleware. +func (m *middleware) Next() (*Response, error) { + if m.index < len(m.handlers) { + m.index++ + if m.resp, m.err = m.handlers[m.index](m.req); m.err != nil { + return m.resp, m.err + } + } + + return m.resp, m.err +} diff --git a/request.go b/request.go new file mode 100644 index 0000000..25080fa --- /dev/null +++ b/request.go @@ -0,0 +1,264 @@ +/** + * @Author: fuxiao + * @Email: 576101059@qq.com + * @Date: 2021/8/16 9:40 上午 + * @Desc: TODO + */ + +package http + +import ( + "bytes" + "context" + "encoding/json" + "encoding/xml" + "errors" + "fmt" + "io" + "log" + "mime/multipart" + "net/http" + "os" + "path/filepath" + "regexp" + "strings" + "time" + + "github.com/dobyte/http/internal" +) + +const ( + MethodGet = http.MethodGet + MethodHead = http.MethodHead + MethodPost = http.MethodPost + MethodPut = http.MethodPut + MethodPatch = http.MethodPatch + MethodDelete = http.MethodDelete + MethodConnect = http.MethodConnect + MethodOptions = http.MethodOptions + MethodTrace = http.MethodTrace +) + +const fileUploadingKey = "@file:" + +type Request struct { + client *Client + retryCount int + retryInterval time.Duration + Request *http.Request +} + +func NewRequest(c *Client) *Request { + return &Request{ + client: c, + retryCount: c.retryCount, + retryInterval: c.retryInterval, + } +} + +func (r *Request) Next() (*Response, error) { + if v := r.Request.Context().Value(middlewareKey); v != nil { + if m, ok := v.(*middleware); ok { + return m.Next() + } + } + return r.call() +} + +func (r *Request) request(method, url string, data ...interface{}) (resp *Response, err error) { + r.Request, err = r.prepare(method, url, data...) + if err != nil { + return nil, err + } + + if count := len(r.client.middlewares); count > 0 { + handlers := make([]MiddlewareFunc, 0, count+1) + handlers = append(handlers, r.client.middlewares...) + handlers = append(handlers, func(r *Request) (*Response, error) { + return r.call() + }) + r.Request = r.Request.WithContext(context.WithValue(r.Request.Context(), middlewareKey, &middleware{ + req: r, + handlers: handlers, + index: -1, + })) + resp, err = r.Next() + } else { + resp, err = r.call() + } + + return resp, err +} + +// prepare build a http request. +func (r *Request) prepare(method, url string, data ...interface{}) (req *http.Request, err error) { + method = strings.ToUpper(method) + url = r.client.baseUrl + url + + var params string + if len(data) > 0 { + switch data[0].(type) { + case string: + params = data[0].(string) + case []byte: + params = string(data[0].([]byte)) + default: + switch r.client.headers[HeaderContentType] { + case ContentTypeJson: + if b, err := json.Marshal(data[0]); err != nil { + return nil, err + } else { + params = string(b) + } + case ContentTypeXml: + if b, err := xml.Marshal(data[0]); err != nil { + return nil, err + } else { + params = string(b) + } + default: + params = internal.BuildParams(data[0]) + } + } + } + + if method == MethodGet { + buffer := bytes.NewBuffer(nil) + + if params != "" { + switch r.client.headers[HeaderContentType] { + case ContentTypeJson, ContentTypeXml: + buffer = bytes.NewBuffer([]byte(params)) + default: + if strings.Contains(url, "?") { + url = url + "&" + params + } else { + url = url + "?" + params + } + } + } + + if req, err = http.NewRequest(method, url, buffer); err != nil { + return nil, err + } + } else { + if strings.Contains(params, fileUploadingKey) { + var ( + buffer = bytes.NewBuffer(nil) + writer = multipart.NewWriter(buffer) + ) + + for _, item := range strings.Split(params, "&") { + array := strings.Split(item, "=") + if len(array[1]) > 6 && strings.Compare(array[1][0:6], fileUploadingKey) == 0 { + path := array[1][6:] + if !internal.Exists(path) { + return nil, errors.New(fmt.Sprintf(`"%s" does not exist`, path)) + } + if file, err := writer.CreateFormFile(array[0], filepath.Base(path)); err == nil { + if f, err := os.Open(path); err == nil { + if _, err = io.Copy(file, f); err != nil { + if err := f.Close(); err != nil { + log.Printf(`%+v`, err) + } + return nil, err + } + if err := f.Close(); err != nil { + log.Printf(`%+v`, err) + } + } else { + return nil, err + } + } else { + return nil, err + } + } else { + if err = writer.WriteField(array[0], array[1]); err != nil { + return nil, err + } + } + } + + if err = writer.Close(); err != nil { + return nil, err + } + + if req, err = http.NewRequest(method, url, buffer); err != nil { + return nil, err + } else { + req.Header.Set(HeaderContentType, writer.FormDataContentType()) + } + } else { + paramBytes := []byte(params) + if req, err = http.NewRequest(method, url, bytes.NewReader(paramBytes)); err != nil { + return nil, err + } else { + if v, ok := r.client.headers[HeaderContentType]; ok { + req.Header.Set(HeaderContentType, v) + } else if len(paramBytes) > 0 { + if (paramBytes[0] == '[' || paramBytes[0] == '{') && json.Valid(paramBytes) { + req.Header.Set(HeaderContentType, ContentTypeJson) + } else if matched, _ := regexp.Match(`^[\w\[\]]+=.+`, paramBytes); matched { + req.Header.Set(HeaderContentType, ContentTypeFormUrlEncoded) + } + } + } + } + } + + if r.client.ctx != nil { + req = req.WithContext(r.client.ctx) + } else { + req = req.WithContext(context.Background()) + } + + if len(r.client.headers) > 0 { + for key, value := range r.client.headers { + if key != "" { + req.Header.Set(key, value) + } + } + } + + if len(r.client.cookies) > 0 { + var cookies = make([]string, 0) + for key, value := range r.client.cookies { + if key != "" { + cookies = append(cookies, key+"="+value) + } + } + req.Header.Set(HeaderCookie, strings.Join(cookies, ";")) + } + + if host := req.Header.Get(HeaderHost); host != "" { + req.Host = host + } + + return req, nil +} + +// call nitiate an HTTP request and return the response data. +func (r *Request) call() (resp *Response, err error) { + resp = &Response{Request: r.Request} + + for { + if resp.Response, err = r.client.Do(r.Request); err != nil { + if resp.Response != nil { + if err := resp.Response.Body.Close(); err != nil { + log.Printf(`%+v`, err) + } + } + + if r.retryCount > 0 { + r.retryCount-- + time.Sleep(r.retryInterval) + } else { + break + } + } else { + break + } + } + + return resp, err +} diff --git a/response.go b/response.go new file mode 100644 index 0000000..793072f --- /dev/null +++ b/response.go @@ -0,0 +1,116 @@ +/** + * @Author: fuxiao + * @Email: 576101059@qq.com + * @Date: 2021/8/15 4:56 下午 + * @Desc: TODO + */ + +package http + +import ( + "io/ioutil" + "net/http" + + "github.com/dobyte/http/internal" +) + +type Response struct { + *http.Response + Request *http.Request + body []byte + cookies map[string]string +} + +// Bytes retrieves and returns the response content as []byte. +func (r *Response) Bytes() []byte { + if r == nil || r.Response == nil { + return []byte{} + } + + if r.body == nil { + var err error + if r.body, err = ioutil.ReadAll(r.Response.Body); err != nil { + return nil + } + } + + return r.body +} + +// String retrieves and returns the response content as string. +func (r *Response) String() string { + return internal.UnsafeBytesToString(r.Bytes()) +} + +// Scan convert the response into a complex data structure. +func (r *Response) Scan(any interface{}) error { + return internal.Scan(r.Bytes(), any) +} + +// Close closes the response when it will never be used. +func (r *Response) Close() error { + if r == nil || r.Response == nil || r.Response.Close { + return nil + } + r.Response.Close = true + return r.Response.Body.Close() +} + +// HasHeader Determine if a header exists in the cache. +func (r *Response) HasHeader(key string) bool { + for k, _ := range r.Header { + if k == key { + return true + } + } + + return false +} + +// GetHeader Retrieve header's value from the response. +func (r *Response) GetHeader(key string) string { + return r.Header.Get(key) +} + +// GetHeader Retrieve all header's value from the response. +func (r *Response) GetHeaders() map[string]interface{} { + headers := make(map[string]interface{}) + for k, v := range r.Header { + if len(v) > 1 { + headers[k] = v + } else { + headers[k] = v[0] + } + } + + return headers +} + +// HasCookie Determine if a cookie exists in the cache. +func (r *Response) HasCookie(key string) bool { + if r.cookies == nil { + r.cookies = r.GetCookies() + } + _, ok := r.cookies[key] + + return ok +} + +// GetCookie Retrieve cookie's value from the response. +func (r *Response) GetCookie(key string) string { + if r.cookies == nil { + r.cookies = r.GetCookies() + } + return r.cookies[key] +} + +// GetCookies Retrieve all cookie's value from the response. +func (r *Response) GetCookies() map[string]string { + cookies := make(map[string]string) + if r != nil && r.Response != nil { + for _, cookie := range r.Cookies() { + cookies[cookie.Name] = cookie.Value + } + } + return cookies +}