问题:
最近用到了 Go 自带的 net/rpc,后来需要在 RPC 方法里获取调用方(远程客户端)的 IP。查了一下资料,大多建议改用第三方库;既然如此,不如自己动手改造标准库。

环境
go 1.19.2

准备工作:

package main

import (
    "net"
    "net/rpc"
    "net/rpc/jsonrpc"
)

// TestInterface is the sample RPC service used by this example; main
// registers it under the service name "TEST".
type TestInterface struct{}

// Test is the sample RPC method in its original net/rpc form,
// func(args, *reply) error. Nothing in that signature carries the caller's
// connection, so the client IP cannot be obtained here yet.
// It leaves *rep unchanged and always reports success.
func (*TestInterface) Test(number int, rep *string) error {
    return nil
}

// main registers TestInterface on rpc.DefaultServer under the service name
// "TEST", listens on TCP/IPv4 port 5588 and serves JSON-RPC on each accepted
// connection in its own goroutine. Setup errors panic; the first Accept error
// ends the accept loop and main returns.
func main() {
    if err := rpc.RegisterName("TEST", &TestInterface{}); err != nil {
        panic(err)
    }

    listener, err := net.Listen("tcp4", "0.0.0.0:5588")
    if err != nil {
        panic(err)
    }

    for {
        c, acceptErr := listener.Accept()
        if acceptErr != nil {
            return
        }
        // One goroutine per client. c is passed as an argument so every
        // goroutine owns its own value (loop variables were shared before Go 1.22).
        go func(c net.Conn) {
            defer c.Close()
            jsonrpc.ServeConn(c)
        }(c)
    }
}

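这是一个在 JSON-RPC 编码下使用 Go net/rpc 的简单示例。翻阅了一下标准库代码,发现改造起来并不复杂,需要修改的文件有:
net/rpc/server.go
net/rpc/jsonrpc/server.go(JSON-RPC 编解码器,是对 net/rpc 服务端的包装)

索性直接把这两个文件复制到自己的项目里再修改,这样不会影响 Go 自带的标准库。注意:复制后 jsonrpc 的副本要改为引用 rpc 的副本(下文代码中以 sRpcServer 作为该包的包名)。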

改造 jsonrpc/server.go:
//第一个

// serverCodec is the JSON-RPC server codec returned by NewServerCodec.
// Compared with the standard library version it additionally stores a
// net.Conn, which GetNetConn hands to the RPC layer.
type serverCodec struct {
    dec *json.Decoder // for reading JSON values
    enc *json.Encoder // for writing JSON values
    c   io.Closer

    // Added in this copy: the client's network connection, set from
    // NewServerCodec's conn2 argument and returned by GetNetConn.
    conn net.Conn

    // temporary work space
    req serverRequest

    // JSON-RPC clients can use arbitrary json values as request IDs.
    // Package rpc expects uint64 request IDs.
    // We assign uint64 sequence numbers to incoming requests
    // but save the original request ID in the pending map.
    // When rpc responds, we use the sequence number in
    // the response to find the original request ID.
    mutex   sync.Mutex // protects seq, pending
    seq     uint64
    pending map[uint64]*json.RawMessage
}

//第二个

// NewServerCodec returns a new ServerCodec (from the copied rpc package)
// that speaks JSON-RPC over conn.
//
// conn is the byte stream requests are decoded from and responses are
// encoded to; it is also stored as the codec's closer. conn2 is the network
// connection reported by GetNetConn, which the RPC layer passes to every
// method. If conn2 is nil and conn is itself a net.Conn, conn is used, so a
// caller holding only one value may pass nil instead of repeating it.
func NewServerCodec(conn io.ReadWriteCloser, conn2 net.Conn) sRpcServer.ServerCodec {
    if conn2 == nil {
        // Fall back to the stream itself; otherwise methods would get a nil conn.
        if nc, ok := conn.(net.Conn); ok {
            conn2 = nc
        }
    }
    return &serverCodec{
        dec:     json.NewDecoder(conn),
        enc:     json.NewEncoder(conn),
        c:       conn,
        conn:    conn2,
        pending: make(map[uint64]*json.RawMessage),
    }
}

//第三个 添加一个方法
// 改造

// GetNetConn reports the network connection held in c.conn (set by
// NewServerCodec). It implements the GetNetConn method that this modified
// copy adds to the rpc ServerCodec interface.
func (c *serverCodec) GetNetConn() (conn net.Conn) {
    conn = c.conn
    return
}

//第四个

// ServeConn runs the JSON-RPC server on a single connection.
// ServeConn blocks, serving the connection until the client hangs up.
// The caller typically invokes ServeConn in a go statement.
func ServeConn(conn net.Conn) {
    sRpcServer.ServeCodec(NewServerCodec(conn, conn))
}

改造 rpc/server.go:
//第一个

// ServerCodec is the server-side codec interface of this copied rpc package.
// Compared with net/rpc it has one extra method, GetNetConn.
type ServerCodec interface {
    ReadRequestHeader(*Request) error
    ReadRequestBody(any) error
    WriteResponse(*Response, any) error
    // Added in this copy: returns the client's network connection. The
    // server's call method passes the result to each RPC method as its
    // net.Conn argument.
    GetNetConn() net.Conn

    // Close can be called multiple times and must be idempotent.
    Close() error
}

//第二个

// gobServerCodec is the gob-based server codec copied from net/rpc, extended
// with the client's net.Conn.
type gobServerCodec struct {
    rwc    io.ReadWriteCloser
    dec    *gob.Decoder
    enc    *gob.Encoder
    encBuf *bufio.Writer
    closed bool
    conn   net.Conn // added in this copy; returned by GetNetConn
}

//第三个 添加一个GetNetConn方法
// 改造

// GetNetConn hands back the net.Conn recorded in the codec's conn field.
// NOTE(review): the code that constructs gobServerCodec is not shown here;
// if it never assigns conn, this returns nil.
func (c *gobServerCodec) GetNetConn() (conn net.Conn) {
    conn = c.conn
    return
}

//第四个

// ServeConn runs DefaultServer on a single connection: conn carries the RPC
// byte stream and conn2 is forwarded to DefaultServer.ServeConn.
// NOTE(review): conn2 presumably ends up as the codec's GetNetConn value;
// Server.ServeConn is not shown here, so confirm there.
// It returns only when DefaultServer.ServeConn returns.
func ServeConn(conn io.ReadWriteCloser, conn2 net.Conn) {
    srv := DefaultServer
    srv.ServeConn(conn, conn2)
}

//第五个

// HandleHTTP registers server as the HTTP handler for rpcPath on
// http.DefaultServeMux. Unlike net/rpc, no debug page is registered: the
// debugHTTP handler was not copied into this package, so debugPath is ignored.
func (server *Server) HandleHTTP(rpcPath, debugPath string) {
    _ = debugPath // intentionally unused; see doc comment
    var h http.Handler = server
    http.Handle(rpcPath, h)
}

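……(其余代码省略。注意:上面的 ServeConn 调用的是 DefaultServer.ServeConn(conn, conn2),因此 Server.ServeConn 也要相应增加一个 net.Conn 参数,并把它保存到所创建的 gobServerCodec 的 conn 字段中(标准库中 gobServerCodec 就是在 Server.ServeConn 里创建的),否则 GetNetConn 会返回 nil。)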

最后也是最关键的一步:我们要把 RPC 方法的签名由

func (*TestInterface) Test(number int, rep *string) error {
    // The standard signature carries only the argument and the reply
    // pointer, so the caller's address is out of reach here.
    return nil
}

改为

// Test is the sample RPC method in its modified form: the patched rpc
// package passes the client's connection as the first argument, and the
// method prints the caller's IP address. It leaves *rep unchanged and always
// reports success.
func (t *TestInterface) Test(conn net.Conn, number int, rep *string) error {
    clientIP := remoteIP(conn)
    fmt.Println(clientIP)
    return nil
}

// remoteIP returns the host part of conn's remote address.
//
// net.SplitHostPort is used instead of splitting on ":" so IPv6 addresses
// such as "[::1]:5588" yield "::1" rather than "[". An address that is not in
// host:port form is returned unchanged, and a nil conn or nil remote address
// yields "" instead of a nil-pointer panic.
func remoteIP(conn net.Conn) string {
    if conn == nil {
        return ""
    }
    addr := conn.RemoteAddr()
    if addr == nil {
        return ""
    }
    host, _, err := net.SplitHostPort(addr.String())
    if err != nil {
        // e.g. a unix socket path or pipe: no port to strip.
        return addr.String()
    }
    return host
}

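由于新增了一个 net.Conn 参数,必须同时修改两处:调用 RPC 方法的地方(call),以及注册 RPC 方法时的签名检查(register / suitableMethods)。另外,获取 IP 时建议用 net.SplitHostPort 解析 conn.RemoteAddr().String(),而不是按 ":" 切分,否则 IPv6 地址(如 [::1]:5588)会被解析错。

首先是 call: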

// call performs one RPC method invocation and sends its response.
//
// It increments the method's call counter and invokes the method via
// reflection as method(receiver, conn, argv, replyv), where conn is
// codec.GetNetConn(). It then sends replyv together with the method's error
// text (empty on success) through server.sendResponse and releases req with
// server.freeRequest. If wg is non-nil, wg.Done is called when call returns.
func (s *service) call(server *Server, sending *sync.Mutex, wg *sync.WaitGroup, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) {
    if wg != nil {
        defer wg.Done()
    }
    mtype.Lock()
    mtype.numCalls++
    mtype.Unlock()
    function := mtype.method.Func

    // reflect.ValueOf on a nil interface yields the zero Value, and
    // Value.Call panics when given one. If the codec has no connection, pass
    // a nil value of the method's declared conn type instead, so the method
    // sees conn == nil rather than the server goroutine crashing.
    connv := reflect.ValueOf(codec.GetNetConn())
    if !connv.IsValid() {
        connv = reflect.Zero(function.Type().In(1))
    }

    // Invoke the method, providing a new value for the reply.
    returnValues := function.Call([]reflect.Value{s.rcvr, connv, argv, replyv})
    // The return value for the method is an error.
    errInter := returnValues[0].Interface()
    errmsg := ""
    if errInter != nil {
        errmsg = errInter.(error).Error()
    }
    server.sendResponse(sending, req, replyv.Interface(), codec, errmsg)
    server.freeRequest(req)
}

//再改动 register:

// register validates rcvr and records it in server.serviceMap.
//
// The service name is name when useName is true. Otherwise it is the name of
// rcvr's dereferenced type, which must then be exported. Only methods accepted
// by suitableMethods are exposed. Every failure except a duplicate service
// name is also written to the standard logger before being returned.
func (server *Server) register(rcvr any, name string, useName bool) error {
    svc := &service{
        typ:  reflect.TypeOf(rcvr),
        rcvr: reflect.ValueOf(rcvr),
    }

    // fail logs msg and returns it as an error.
    fail := func(msg string) error {
        log.Print(msg)
        return errors.New(msg)
    }

    sname := name
    if !useName {
        sname = reflect.Indirect(svc.rcvr).Type().Name()
    }
    switch {
    case sname == "":
        return fail("rpc.Register: no service name for type " + svc.typ.String())
    case !useName && !token.IsExported(sname):
        return fail("rpc.Register: type " + sname + " is not exported")
    }
    svc.name = sname

    // Collect the exported methods whose signatures qualify for RPC.
    svc.method = suitableMethods(svc.typ, logRegisterError)
    if len(svc.method) == 0 {
        // If a pointer receiver would have worked, say so in the message.
        msg := "rpc.Register: type " + sname + " has no exported methods of suitable type"
        if len(suitableMethods(reflect.PointerTo(svc.typ), false)) != 0 {
            msg = "rpc.Register: type " + sname + " has no exported methods of suitable type (hint: pass a pointer to value of that type)"
        }
        return fail(msg)
    }

    if _, dup := server.serviceMap.LoadOrStore(sname, svc); dup {
        return errors.New("rpc: service already defined: " + sname)
    }
    return nil
}

//阅读代码可以发现,register 主要通过 suitableMethods 检查所注册类型上的各个方法是否符合 RPC 方法的签名要求,所以这里也要改。

// suitableMethods returns the methods of typ that qualify as RPC methods,
// keyed by method name. When logErr is true, the reason each exported method
// is rejected is written to the standard logger.
//
// In this modified copy a qualifying method has the form
//
//     func (t *T) Name(conn net.Conn, args ArgType, reply *ReplyType) error
//
// so its reflect.Type, which includes the receiver, has four inputs and one
// output. The conn parameter must be an interface type that net.Conn
// satisfies (normally net.Conn itself). A method declaring any other
// interface would make reflect.Value.Call panic when the server passes the
// client connection, so such methods are rejected here instead.
func suitableMethods(typ reflect.Type, logErr bool) map[string]*methodType {
    // Interface implemented by every connection handed to a method.
    netConnType := reflect.TypeOf((*net.Conn)(nil)).Elem()
    methods := make(map[string]*methodType)
    for m := 0; m < typ.NumMethod(); m++ {
        method := typ.Method(m)
        mtype := method.Type
        mname := method.Name
        // Method must be exported.
        if !method.IsExported() {
            continue
        }

        // Method needs four ins: receiver, net.Conn, args, *reply.
        if mtype.NumIn() != 4 {
            if logErr {
                log.Printf("rpc.Register: method %q has %d input parameters; needs exactly four\n", mname, mtype.NumIn())
            }
            continue
        }
        // First declared parameter: the client connection. It must be an
        // interface type...
        connType := mtype.In(1)
        if connType.Kind() != reflect.Interface {
            if logErr {
                log.Printf("rpc.Register: Conn type of method %q is not a interface: %q\n", mname, connType)
            }
            continue
        }
        // ...that net.Conn satisfies, or passing the connection would panic.
        if !netConnType.Implements(connType) {
            if logErr {
                log.Printf("rpc.Register: Conn type of method %q is not satisfied by net.Conn: %q\n", mname, connType)
            }
            continue
        }
        // ...and it must be exported (or builtin).
        if !isExportedOrBuiltinType(connType) {
            if logErr {
                log.Printf("rpc.Register: Conn type of method %q is not exported: %q\n", mname, connType)
            }
            continue
        }

        // Args parameter: need not be a pointer, but must be exported.
        argType := mtype.In(2)
        if !isExportedOrBuiltinType(argType) {
            if logErr {
                log.Printf("rpc.Register: argument type of method %q is not exported: %q\n", mname, argType)
            }
            continue
        }
        // Reply parameter: must be a pointer.
        replyType := mtype.In(3)
        if replyType.Kind() != reflect.Pointer {
            if logErr {
                log.Printf("rpc.Register: reply type of method %q is not a pointer: %q\n", mname, replyType)
            }
            continue
        }
        // Reply type must be exported.
        if !isExportedOrBuiltinType(replyType) {
            if logErr {
                log.Printf("rpc.Register: reply type of method %q is not exported: %q\n", mname, replyType)
            }
            continue
        }
        // Method needs one out.
        if mtype.NumOut() != 1 {
            if logErr {
                log.Printf("rpc.Register: method %q has %d output parameters; needs exactly one\n", mname, mtype.NumOut())
            }
            continue
        }
        // The return type of the method must be error.
        if returnType := mtype.Out(0); returnType != typeOfError {
            if logErr {
                log.Printf("rpc.Register: return type of method %q is %q, must be error\n", mname, returnType)
            }
            continue
        }
        methods[mname] = &methodType{method: method, ArgType: argType, ReplyType: replyType}
    }
    return methods
}

//至此改动完成,实测效果符合预期。

总结
该改的还是得改。
当然,还是希望以后 Go 能提供内置的 API,这样就不用自己改标准库了。

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注

此站点使用Akismet来减少垃圾评论。了解我们如何处理您的评论数据