采用 FieldMasks 来减少负载

在改进了 UpdateTasksRequest 消息之后,我们现在可以开始关注 FieldMasks,进一步减少有效载荷的大小,但这次我们将重点关注 ListTasksResponse

首先,让我们理解什么是 FieldMasks。它是包含字段路径列表的对象,告诉 Protobuf 包含哪些字段,同时隐式地告诉它哪些字段不应该包含。以下是一个示例。假设我们有一个 Task 消息,如下所示:

message Task {
  uint64 id = 1;
  string description = 2;
  bool done = 3;
  google.protobuf.Timestamp due_date = 4;
}

如果我们只想选择 iddone 字段,我们可以有一个简单的 FieldMask,如下所示:

mask {
  paths: "id"
  paths: "done"
}

然后,我们可以将这个掩码应用于 Task 实例,它只会保留提到的字段的值。这在执行类似 GET 的操作时非常有用,我们不想获取过多不必要的数据(即过度获取数据)。

我们的 TODO API 中有一个这样的用例:ListTasks。为什么呢?因为如果用户只想获取部分信息,他们目前无法做到这一点。选择部分数据对于某些功能非常有用,比如将任务从本地存储同步到后端。如果后端有 ID 1、2 和 3,而本地存储有 1、2、3、4 和 5,我们希望能够计算需要上传的任务的差异。为了做到这一点,我们只需要列出 ID,因为获取 descriptiondonedue_date 的值将是浪费的。

改进 ListTasksRequest

ListTasksResponse 是一种服务器流式的 API。我们发送一个请求,然后接收 0 个或多个响应。提到这一点很重要,因为发送 FieldMask 并不是免费的。我们仍然需要在网络上传输字节。然而,在我们的情况下,使用掩码是有意义的,因为我们可以一次发送它,并将它应用于服务器返回的所有元素。

首先,我们需要声明一个 FieldMask。为此,我们导入 field_mask.proto 并在 ListTasksRequest 中添加一个字段:

import "google/protobuf/field_mask.proto";
//...
message ListTasksRequest {
  google.protobuf.FieldMask mask = 1;
}

接下来,我们可以在服务器端应用这个掩码到我们发送的所有响应。这需要使用反射和一些样板代码。首先,我们需要在服务器中添加一个依赖项,以便能够操作切片,并特别访问 Contains 函数:

$ go get golang.org/x/exp/slices

然后,我们可以使用反射。我们将遍历给定消息的所有字段,如果该字段的名称不在掩码的路径中,我们将删除它的值:

以下代码是一个简单的实现,用于根据掩码过滤消息中的字段,但它对于我们的用例已经足够了。实际上,FieldMasks 还提供了更强大的功能,比如过滤映射、列表和子消息。不幸的是,Go 实现的 Protobuf 不像其他实现那样提供这些功能,因此我们需要依赖自己编写代码或使用社区项目。

import (
  "google.golang.org/protobuf/proto"
  "google.golang.org/protobuf/reflect/protoreflect"
  "google.golang.org/protobuf/types/known/fieldmaskpb"
  "golang.org/x/exp/slices"
)
//...
func Filter(msg proto.Message, mask *fieldmaskpb.FieldMask) {
  if mask == nil || len(mask.Paths) == 0 {
    return
  }
  rft := msg.ProtoReflect()
  rft.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
    if !slices.Contains(mask.Paths, string(fd.Name())) {
      rft.Clear(fd)
    }
    return true
  })
}

有了这个,我们现在基本上可以在 ListTasks 实现中使用 Filter 来过滤 ListTasksResponse 中将发送的 Task 对象:

func (s *server) ListTasks(req *pb.ListTasksRequest, stream pb.TodoService_ListTasksServer) error {
  return s.d.getTasks(func(t interface{}) error {
    task := t.(*pb.Task)
    Filter(task, req.Mask)
    overdue := task.DueDate != nil && !task.Done && task.DueDate.AsTime().Before(time.Now().UTC())
    err := stream.Send(&pb.ListTasksResponse{
      Task: task,
      Overdue: overdue,
    })
    return err
  })
}

请注意,在计算 Overdue 之前调用了 Filter。这是因为如果我们没有在 FieldMask 中包括 due_date 字段,我们假设用户不关心是否逾期。最终,Overdue 将为 false,不会被序列化,因此不会通过网络发送。

接下来,我们需要看看如何在客户端使用它。在这个示例中,printTasks 将只打印 ID。我们将接收 FieldMask 作为 printTasks 的参数,并将其添加到 ListTasksRequest 中:

func printTasks(c pb.TodoServiceClient, fm *fieldmaskpb.FieldMask) {
  req := &pb.ListTasksRequest{
    Mask: fm,
  }
  //...
}

最后,通过 fieldmaskpb.New,我们首先创建一个包含 id 路径的 FieldMask。这个函数会检查 id 是否是我们提供的消息中的有效路径。如果没有错误,我们可以将 Mask 字段设置到 ListTasksRequest 实例中:

func main() {
  //...
  fm, err := fieldmaskpb.New(&pb.Task{}, "id")
  if err != nil {
    log.Fatalf("unexpected error: %v", err)
  }
  //...
  fmt.Println("--------LIST-------")
  printTasks(c, fm)
  fmt.Println("-------------------")
  //...
}

如果我们运行它,应该会得到以下输出:

--------LIST-------
id:1 overdue:false
id:2 overdue:false
id:3 overdue:false
-------------------

请注意,overdue 仍然显示为 false,但在我们的例子中,这可以忽略不计,因为我们在 printTasks 函数中打印了 overdue,而 overdue 的默认值(布尔类型)是 false