diff options
Diffstat (limited to 'hw/virtio/vhost-user.c')
-rw-r--r-- | hw/virtio/vhost-user.c | 67 |
1 files changed, 65 insertions, 2 deletions
diff --git a/hw/virtio/vhost-user.c b/hw/virtio/vhost-user.c index b6757ebae3..1603d70bea 100644 --- a/hw/virtio/vhost-user.c +++ b/hw/virtio/vhost-user.c @@ -174,6 +174,7 @@ struct vhost_user { int slave_fd; NotifierWithReturn postcopy_notifier; struct PostCopyFD postcopy_fd; + uint64_t postcopy_client_bases[VHOST_MEMORY_MAX_NREGIONS]; /* True once we've entered postcopy_listen */ bool postcopy_listen; }; @@ -343,12 +344,15 @@ static int vhost_user_set_log_base(struct vhost_dev *dev, uint64_t base, static int vhost_user_set_mem_table_postcopy(struct vhost_dev *dev, struct vhost_memory *mem) { + struct vhost_user *u = dev->opaque; int fds[VHOST_MEMORY_MAX_NREGIONS]; int i, fd; size_t fd_num = 0; bool reply_supported = virtio_has_feature(dev->protocol_features, VHOST_USER_PROTOCOL_F_REPLY_ACK); - /* TODO: Add actual postcopy differences */ + VhostUserMsg msg_reply; + int region_i, msg_i; + VhostUserMsg msg = { .hdr.request = VHOST_USER_SET_MEM_TABLE, .hdr.flags = VHOST_USER_VERSION, @@ -395,6 +399,64 @@ static int vhost_user_set_mem_table_postcopy(struct vhost_dev *dev, return -1; } + if (vhost_user_read(dev, &msg_reply) < 0) { + return -1; + } + + if (msg_reply.hdr.request != VHOST_USER_SET_MEM_TABLE) { + error_report("%s: Received unexpected msg type." + "Expected %d received %d", __func__, + VHOST_USER_SET_MEM_TABLE, msg_reply.hdr.request); + return -1; + } + /* We're using the same structure, just reusing one of the + * fields, so it should be the same size. + */ + if (msg_reply.hdr.size != msg.hdr.size) { + error_report("%s: Unexpected size for postcopy reply " + "%d vs %d", __func__, msg_reply.hdr.size, msg.hdr.size); + return -1; + } + + memset(u->postcopy_client_bases, 0, + sizeof(uint64_t) * VHOST_MEMORY_MAX_NREGIONS); + + /* They're in the same order as the regions that were sent + * but some of the regions were skipped (above) if they + * didn't have fd's + */ + for (msg_i = 0, region_i = 0; + region_i < dev->mem->nregions; + region_i++) { + if (msg_i < fd_num && + msg_reply.payload.memory.regions[msg_i].guest_phys_addr == + dev->mem->regions[region_i].guest_phys_addr) { + u->postcopy_client_bases[region_i] = + msg_reply.payload.memory.regions[msg_i].userspace_addr; + trace_vhost_user_set_mem_table_postcopy( + msg_reply.payload.memory.regions[msg_i].userspace_addr, + msg.payload.memory.regions[msg_i].userspace_addr, + msg_i, region_i); + msg_i++; + } + } + if (msg_i != fd_num) { + error_report("%s: postcopy reply not fully consumed " + "%d vs %zd", + __func__, msg_i, fd_num); + return -1; + } + /* Now we've registered this with the postcopy code, we ack to the client, + * because now we're in the position to be able to deal with any faults + * it generates. + */ + /* TODO: Use this for failure cases as well with a bad value */ + msg.hdr.size = sizeof(msg.payload.u64); + msg.payload.u64 = 0; /* OK */ + if (vhost_user_write(dev, &msg, NULL, 0) < 0) { + return -1; + } + if (reply_supported) { return process_message_reply(dev, &msg); } @@ -411,7 +473,8 @@ static int vhost_user_set_mem_table(struct vhost_dev *dev, size_t fd_num = 0; bool do_postcopy = u->postcopy_listen && u->postcopy_fd.handler; bool reply_supported = virtio_has_feature(dev->protocol_features, - VHOST_USER_PROTOCOL_F_REPLY_ACK); + VHOST_USER_PROTOCOL_F_REPLY_ACK) && + !do_postcopy; if (do_postcopy) { /* Postcopy has enough differences that it's best done in it's own |