package imitate import ( "errors" "io/ioutil" "net/http" "os" "reflect" "strings" yaml "gopkg.in/yaml.v2" ) // YamlCurls 为了自定义序列化函数 type YamlCurls []string // UnmarshalYAML YamlCurls反序列化函数 func (curls *YamlCurls) UnmarshalYAML(unmarshal func(interface{}) error) error { var buf interface{} err := unmarshal(&buf) if err != nil { return nil } switch tbuf := buf.(type) { case string: *curls = append(*curls, parseCurl(tbuf)) case []interface{}: for _, ifa := range tbuf { *curls = append(*curls, parseCurl(ifa.(string))) } default: return errors.New("read curls is error, " + reflect.TypeOf(buf).String()) } return nil } // MarshalYAML YamlCurls序列化函数 func (curls *YamlCurls) MarshalYAML() (interface{}, error) { content := "[" for _, curl := range []string(*curls) { content += "\"" + curl + "\"" + ", " } content = strings.TrimRight(content, ", ") content += "]" return content, nil } // YamlProxies 为了自定义序列化函数 type YamlProxies []string // UnmarshalYAML YamlProxies反序列化函数 func (proxies *YamlProxies) UnmarshalYAML(unmarshal func(interface{}) error) error { var buf interface{} err := unmarshal(&buf) if err != nil { return nil } switch tbuf := buf.(type) { case string: *proxies = append(*proxies, tbuf) case []interface{}: for _, ifa := range tbuf { *proxies = append(*proxies, ifa.(string)) } default: return errors.New("read curls is error, " + reflect.TypeOf(buf).String()) } return nil } // MarshalYAML YamlProxies 序列化函数 func (proxies *YamlProxies) MarshalYAML() (interface{}, error) { content := "[" for _, curl := range []string(*proxies) { content += "\"" + curl + "\"" + ", " } content = strings.TrimRight(content, ", ") content += "]" return content, nil } // Config 任务加载的默认配置 type Config struct { // Session int `yaml:"session"` Mode int `yaml:"mode"` Proxies YamlProxies `yaml:"proxies"` Retry int `yaml:"retry"` Priority int `yaml:"priority"` Curls YamlCurls `yaml:"curls"` Crontab string `yaml:"crontab"` Device string `yaml:"device"` Platform string `yaml:"platform"` AreaCC string `yaml:"area_cc"` Channel int `yaml:"channel"` Media int `yaml:"media"` SpiderID int `yaml:"spider_id"` CatchAccountID int `yaml:"catch_account_id"` } // newDefaultConfig create a default config func newDefaultConfig() *Config { conf := &Config{ // Session: 1, Mode: 0, Retry: 0, Priority: 10000, Crontab: "", Device: "", Platform: "", AreaCC: "", Channel: -1, Media: -1, SpiderID: -1, CatchAccountID: -1, } return conf } // NewConfig 加载并返回Config func NewConfig(p string) *Config { f, err := os.Open(p) defer f.Close() if err != nil { panic(err) } conf := newDefaultConfig() err = yaml.NewDecoder(f).Decode(conf) if err != nil { panic(err) } return conf } func parseCurl(curl string) string { switch curl[0] { case '@': curlfile, err := os.Open(curl[1:]) defer curlfile.Close() if err != nil { panic(err) } curldata, err := ioutil.ReadAll(curlfile) return strings.Trim(string(curldata), "\r\n ") case '#': resp, err := http.Get(curl[1:]) if err != nil { panic(err) } curldata, err := ioutil.ReadAll(resp.Body) if err != nil { panic(err) } return strings.Trim(string(curldata), "\r\n ") } return strings.Trim(curl, "\r\n ") }