package SunnyProxy

import (
	"Sunny/src/crypto/pkcs"
	"bytes"
)
import (
	"Sunny/src/crypto/tls"
	"crypto/rand"
	"crypto/rsa"
	TLS "crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/pem"
	"errors"
	"io/ioutil"
	"math/big"
	"net"
	"os"
	"strings"
	"sync"
	"time"
)

var (
	rootCa  *x509.Certificate // parsed root CA certificate; assigned by loadRootCa
	rootKey *rsa.PrivateKey   // parsed root CA private key; assigned by loadRootKey
)

// certCache is created by Init and used by GetCertificate to look up or build
// the certificate for a host.
var certCache *Cache
// root_Key holds the raw bytes of the root private key string passed to Init;
// loadRootKey PEM-decodes it.
var root_Key []byte
// root_Ca holds the raw bytes of the root certificate string passed to Init;
// loadRootCa PEM-decodes it and GetCaCert returns a filtered copy of it.
var root_Ca []byte

// Init records the PEM text of the root certificate and root private key,
// creates a fresh certificate cache, and then parses the certificate followed
// by the key.
//
// It panics with the loader's error if either one cannot be parsed; the key is
// not parsed when the certificate already failed.
func Init(rootCa, rootKey string) {
	// The parameters shadow the package-level rootCa/rootKey, so only the raw
	// byte copies are written here; the loaders fill in the parsed values.
	root_Ca, root_Key = []byte(rootCa), []byte(rootKey)
	certCache = NewCache()

	for _, load := range []func() error{loadRootCa, loadRootKey} {
		if err := load(); err != nil {
			panic(err)
		}
	}
}
// SplitHostPort splits an address of the form "host:port" into its host and
// port using net.SplitHostPort.
//
// Any input containing two or more colons is treated as an IPv6 address and
// rejected with an error, including the bracketed "[addr]:port" form; IPv6 is
// not supported here.
func SplitHostPort(ip string) (host, port string, err error) {
	if strings.Count(ip, ":") >= 2 {
		return "", "", errors.New(" 不支持 IPv6 ")
	}
	return net.SplitHostPort(ip)
}
// GetCertificate returns the certificate to present for xhost, which is
// expected in "host:port" form.
//
// Lookup order:
//  1. A certificate registered through AddP12Certificate, when that entry
//     carries a SunnyTLSCert.
//  2. The certificate cache, which on a miss generates a certificate for the
//     host signed with the loaded root CA (see generatePem).
//
// A nil certificate is only ever returned together with a non-nil error.
func GetCertificate(xhost string) (*tls.Certificate, error) {
	// Entries registered with cerType 1 leave SunnyTLSCert nil. Returning that
	// field unchecked would hand the caller (nil, nil), so such entries fall
	// through to the cache/generation path, as the closure below already does.
	if registered := LoadP12Certificate(xhost); registered != nil && registered.SunnyTLSCert != nil {
		return registered.SunnyTLSCert, nil
	}

	cached, err := certCache.GetOrStore(xhost, func() (interface{}, error) {
		host, _, err := SplitHostPort(xhost)
		if err != nil {
			return nil, err
		}
		if registered := LoadP12Certificate(host); registered != nil && registered.SunnyTLSCert != nil {
			return *registered.SunnyTLSCert, nil
		}
		certPEM, keyPEM, err := generatePem(host)
		if err != nil {
			return nil, err
		}
		pair, err := tls.X509KeyPair(certPEM, keyPEM)
		if err != nil {
			// Return an untyped nil so the cache never receives a zero-value
			// certificate boxed in a non-nil interface.
			return nil, err
		}
		return pair, nil
	})
	if cached == nil {
		if err == nil {
			err = errors.New("no certificate available for " + xhost)
		}
		return nil, err
	}

	// The closure only ever stores tls.Certificate values; the checked
	// assertion turns any unexpected cache content into an error rather than
	// a panic.
	cert, ok := cached.(tls.Certificate)
	if !ok {
		return nil, errors.New("certificate cache holds an unexpected value for " + xhost)
	}
	return &cert, err
}

// Cers groups the certificate values stored for one CertificateMap entry.
//
// AddP12Certificate decides which fields are populated from its cerType
// argument: 0 sets only SunnyTLSCert, 1 sets CryptoTLSCert and SunnyTLSCert2,
// and any other value sets all three.
type Cers struct {
	// SunnyTLSCert is the key pair parsed with the project's tls package; it
	// is what GetCertificate returns for a registered host.
	SunnyTLSCert  *tls.Certificate
	// CryptoTLSCert is the same key pair parsed with the standard library's
	// crypto/tls package (imported here as TLS).
	CryptoTLSCert *TLS.Certificate
	// SunnyTLSCert2 points at the same project-tls key pair as SunnyTLSCert
	// when both are set.
	// NOTE(review): its consumers are not visible in this file — confirm how
	// it differs in use from SunnyTLSCert.
	SunnyTLSCert2 *tls.Certificate
	// NOTE(review): cerType is not assigned anywhere in the visible code, so
	// it keeps its zero value — confirm whether AddP12Certificate should
	// record its cerType argument here.
	cerType       int
}

// CertificateMap holds the certificates registered through AddP12Certificate.
// Keys are produced by AddP12Certificate: the last two dot-separated labels of
// the host name with a single ":port" suffix removed, or the host name itself
// when it contains no dot.
var CertificateMap = make(map[string]*Cers)
// CertificateMapLock is held by LoadP12Certificate, DelP12Certificate and
// AddP12Certificate while they access CertificateMap.
var CertificateMapLock sync.Mutex

// LoadP12Certificate looks up the certificate entry registered for HostName
// and returns nil when there is none.
//
// A name containing at least one dot is reduced to its last two dot-separated
// labels, and if that remainder splits into exactly two parts on ":" only the
// part before the colon is used as the key. A name without a dot is looked up
// unchanged.
func LoadP12Certificate(HostName string) *Cers {
	key := HostName
	if labels := strings.Split(HostName, "."); len(labels) > 1 {
		key = labels[len(labels)-2] + "." + labels[len(labels)-1]
		if parts := strings.Split(key, ":"); len(parts) == 2 {
			key = parts[0]
		}
	}

	CertificateMapLock.Lock()
	defer CertificateMapLock.Unlock()
	// A missing key yields the zero value of *Cers, i.e. nil.
	return CertificateMap[key]
}
// DelP12Certificate removes the certificate entry registered for HostName.
//
// AddP12Certificate and LoadP12Certificate key CertificateMap by a reduced
// form of the host: the last two dot-separated labels, with the part after a
// single ":" dropped. The same reduction is applied here so that a host can be
// removed by the name it was registered under; deleting by the raw name alone
// would miss entries such as "www.example.com", which is stored under
// "example.com". An entry stored under the literal HostName is removed too.
func DelP12Certificate(HostName string) {
	key := HostName
	if labels := strings.Split(HostName, "."); len(labels) >= 2 {
		key = labels[len(labels)-2] + "." + labels[len(labels)-1]
		if parts := strings.Split(key, ":"); len(parts) == 2 {
			key = parts[0]
		}
	}

	CertificateMapLock.Lock()
	defer CertificateMapLock.Unlock()
	delete(CertificateMap, key)
	if key != HostName {
		// Keeps the previous behaviour for entries written directly into the
		// exported map under the unreduced name.
		delete(CertificateMap, HostName)
	}
}
// AddP12Certificate loads the key material from the file privateKeyName
// (decrypted with privatePassword) and registers it in CertificateMap for
// HostName. It reports whether the certificate was registered.
//
// cerType selects which fields of the stored Cers are filled:
//   - 0: SunnyTLSCert only
//   - 1: CryptoTLSCert and SunnyTLSCert2
//   - any other value: all three
//
// The map key is the last two dot-separated labels of HostName with a single
// ":port" suffix removed, or HostName itself when it contains no dot.
func AddP12Certificate(HostName, privateKeyName, privatePassword string, cerType int) bool {
	// The error is dropped on purpose: a nil block list is the only failure
	// signal this function acts on.
	blocks, _ := getPrivateKey(privateKeyName, privatePassword)
	if blocks == nil {
		return false
	}

	// Concatenate every block into one PEM document; it is passed as both the
	// certificate and the key input of X509KeyPair.
	var pemBytes []byte
	for _, blk := range blocks {
		pemBytes = append(pemBytes, pem.EncodeToMemory(blk)...)
	}

	sunnyPair, err := tls.X509KeyPair(pemBytes, pemBytes)
	if err != nil {
		return false
	}
	stdPair, err := TLS.X509KeyPair(pemBytes, pemBytes)
	if err != nil {
		return false
	}

	entry := new(Cers)
	switch cerType {
	case 0:
		entry.SunnyTLSCert = &sunnyPair
	case 1:
		entry.CryptoTLSCert = &stdPair
		entry.SunnyTLSCert2 = &sunnyPair
	default:
		entry.SunnyTLSCert = &sunnyPair
		entry.SunnyTLSCert2 = &sunnyPair
		entry.CryptoTLSCert = &stdPair
	}

	key := HostName
	if labels := strings.Split(HostName, "."); len(labels) >= 2 {
		key = labels[len(labels)-2] + "." + labels[len(labels)-1]
		if parts := strings.Split(key, ":"); len(parts) == 2 {
			key = parts[0]
		}
	}

	CertificateMapLock.Lock()
	defer CertificateMapLock.Unlock()
	CertificateMap[key] = entry
	return true
}

// NullBytes is an exported byte slice left at its zero value (nil).
// NOTE(review): no live code in the visible part of this file reads or writes
// it — confirm whether other files in the package depend on it.
var NullBytes []byte

// getPrivateKey reads the certificate file at privateKeyName and converts its
// contents to PEM blocks with pkcs.ToPEM, using privatePassword to decrypt it.
//
// It returns nil and the I/O error when the file cannot be opened or read.
// Otherwise it returns whatever pkcs.ToPEM produced; the block list is nil
// whenever the conversion yielded no blocks.
// NOTE(review): the order and types of the returned blocks are defined by
// pkcs.ToPEM, which is not visible here; AddP12Certificate simply concatenates
// all of them.
func getPrivateKey(privateKeyName, privatePassword string) ([]*pem.Block, error) {
	f, err := os.Open(privateKeyName)
	if err != nil {
		return nil, err
	}
	// Close the file on every return path so the descriptor is not leaked.
	defer f.Close()

	// Named data rather than bytes to avoid shadowing the imported package.
	data, err := ioutil.ReadAll(f)
	if err != nil {
		return nil, err
	}

	blocks, err := pkcs.ToPEM(data, privatePassword)
	if blocks == nil {
		return nil, err
	}
	return blocks, err
}

// generateKeyPair returns the private key used for generated host
// certificates. Despite its name it does not create a new key: it returns the
// package-level rootKey, so every certificate built by generatePem carries the
// root key's public key. The error result is always nil.
func generateKeyPair() (*rsa.PrivateKey, error) {
	return rootKey, nil
}

// max is the exclusive upper bound (1 << 128) for generated serial numbers.
var max = new(big.Int).Lsh(big.NewInt(1), 128)

// generatePem creates a server certificate for host, signed by the loaded
// root CA, and returns the PEM-encoded certificate followed by the PEM-encoded
// PKCS#1 private key.
//
// host is placed in the subject common name and in the subject alternative
// names: as an IP address when it parses as one, otherwise as a DNS name. The
// certificate is valid from one day before now until 365 days after now.
//
// An error is returned when the root certificate or key has not been loaded,
// when no random serial number can be drawn, or when signing fails.
func generatePem(host string) ([]byte, []byte, error) {
	priKey, err := generateKeyPair()
	if err != nil {
		return nil, nil, err
	}
	// Without this guard a call made before the root material is loaded
	// would dereference a nil key or certificate and panic.
	if priKey == nil || rootCa == nil || rootKey == nil {
		return nil, nil, errors.New("root certificate or private key is not loaded")
	}

	// Uniformly random value in [0, max), standing in for a CA-issued serial.
	serialNumber, err := rand.Int(rand.Reader, max)
	if err != nil {
		return nil, nil, err
	}

	now := time.Now()
	template := x509.Certificate{
		SerialNumber: serialNumber,
		Subject: pkix.Name{
			CommonName: host,
		},
		NotBefore: now.AddDate(0, 0, -1),
		NotAfter:  now.AddDate(0, 0, 365),
		// KeyUsage and ExtKeyUsage mark the certificate for server authentication.
		KeyUsage:       x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		ExtKeyUsage:    []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		EmailAddresses: []string{"forward.nice.cp@gmail.com"},
	}

	if ip := net.ParseIP(host); ip != nil {
		template.IPAddresses = []net.IP{ip}
	} else {
		template.DNSNames = []string{host}
	}

	der, err := x509.CreateCertificate(rand.Reader, &template, rootCa, &priKey.PublicKey, rootKey)
	if err != nil {
		return nil, nil, err
	}

	certPEM := pem.EncodeToMemory(&pem.Block{
		Type:  "CERTIFICATE",
		Bytes: der,
	})
	keyPEM := pem.EncodeToMemory(&pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(priKey),
	})
	return certPEM, keyPEM, nil
}

// loadRootCa PEM-decodes root_Ca, parses the first block as an X.509
// certificate and stores the result in rootCa.
//
// It returns an error when root_Ca contains no PEM block or when the block is
// not a valid certificate.
func loadRootCa() error {
	block, _ := pem.Decode(root_Ca)
	// pem.Decode yields a nil block when no PEM data is found; reading
	// block.Bytes in that case would panic.
	if block == nil {
		return errors.New("CA证书解析失败")
	}

	var err error
	rootCa, err = x509.ParseCertificate(block.Bytes)
	if err != nil {
		return errors.New("CA证书解析失败")
	}
	return nil
}

// loadRootKey PEM-decodes root_Key, parses the first block as an RSA private
// key and stores the result in rootKey.
//
// The block is tried as PKCS#1 first and then as PKCS#8 holding an RSA key.
// An error is returned, and rootKey is reset to nil, when neither form
// parses; an input with no PEM block at all returns the same error.
func loadRootKey() error {
	block, _ := pem.Decode(root_Key)
	// pem.Decode yields a nil block when no PEM data is found; reading
	// block.Bytes in that case would panic.
	if block == nil {
		return errors.New("Key证书解析失败")
	}

	if key, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil {
		rootKey = key
		return nil
	}
	// Keys exported as "PRIVATE KEY" are PKCS#8 wrapped; accept them as long
	// as the wrapped key is RSA, which is what the rest of this file requires.
	if parsed, err := x509.ParsePKCS8PrivateKey(block.Bytes); err == nil {
		if key, ok := parsed.(*rsa.PrivateKey); ok {
			rootKey = key
			return nil
		}
	}

	rootKey = nil
	return errors.New("Key证书解析失败")
}

// GetCaCert returns the root certificate text held in root_Ca with carriage
// returns removed, empty lines dropped, every line containing ": " dropped,
// and each remaining line terminated by "\r\n".
func GetCaCert() []byte {
	normalized := strings.ReplaceAll(string(root_Ca), "\r", "")

	var out bytes.Buffer
	for _, line := range strings.Split(normalized, "\n") {
		if line == "" || strings.Contains(line, ": ") {
			continue
		}
		out.WriteString(line)
		out.WriteString("\r\n")
	}
	return out.Bytes()
}
