summaryrefslogtreecommitdiff
path: root/hw/virtio/vhost-user.c
diff options
context:
space:
mode:
Diffstat (limited to 'hw/virtio/vhost-user.c')
-rw-r--r--hw/virtio/vhost-user.c67
1 files changed, 65 insertions, 2 deletions
diff --git a/hw/virtio/vhost-user.c b/hw/virtio/vhost-user.c
index b6757ebae3..1603d70bea 100644
--- a/hw/virtio/vhost-user.c
+++ b/hw/virtio/vhost-user.c
@@ -174,6 +174,7 @@ struct vhost_user {
int slave_fd;
NotifierWithReturn postcopy_notifier;
struct PostCopyFD postcopy_fd;
+ uint64_t postcopy_client_bases[VHOST_MEMORY_MAX_NREGIONS];
/* True once we've entered postcopy_listen */
bool postcopy_listen;
};
@@ -343,12 +344,15 @@ static int vhost_user_set_log_base(struct vhost_dev *dev, uint64_t base,
static int vhost_user_set_mem_table_postcopy(struct vhost_dev *dev,
struct vhost_memory *mem)
{
+ struct vhost_user *u = dev->opaque;
int fds[VHOST_MEMORY_MAX_NREGIONS];
int i, fd;
size_t fd_num = 0;
bool reply_supported = virtio_has_feature(dev->protocol_features,
VHOST_USER_PROTOCOL_F_REPLY_ACK);
- /* TODO: Add actual postcopy differences */
+ VhostUserMsg msg_reply;
+ int region_i, msg_i;
+
VhostUserMsg msg = {
.hdr.request = VHOST_USER_SET_MEM_TABLE,
.hdr.flags = VHOST_USER_VERSION,
@@ -395,6 +399,64 @@ static int vhost_user_set_mem_table_postcopy(struct vhost_dev *dev,
return -1;
}
+ if (vhost_user_read(dev, &msg_reply) < 0) {
+ return -1;
+ }
+
+ if (msg_reply.hdr.request != VHOST_USER_SET_MEM_TABLE) {
+ error_report("%s: Received unexpected msg type."
+ "Expected %d received %d", __func__,
+ VHOST_USER_SET_MEM_TABLE, msg_reply.hdr.request);
+ return -1;
+ }
+ /* We're using the same structure, just reusing one of the
+ * fields, so it should be the same size.
+ */
+ if (msg_reply.hdr.size != msg.hdr.size) {
+ error_report("%s: Unexpected size for postcopy reply "
+ "%d vs %d", __func__, msg_reply.hdr.size, msg.hdr.size);
+ return -1;
+ }
+
+ memset(u->postcopy_client_bases, 0,
+ sizeof(uint64_t) * VHOST_MEMORY_MAX_NREGIONS);
+
+ /* They're in the same order as the regions that were sent
+ * but some of the regions were skipped (above) if they
+ * didn't have fd's
+ */
+ for (msg_i = 0, region_i = 0;
+ region_i < dev->mem->nregions;
+ region_i++) {
+ if (msg_i < fd_num &&
+ msg_reply.payload.memory.regions[msg_i].guest_phys_addr ==
+ dev->mem->regions[region_i].guest_phys_addr) {
+ u->postcopy_client_bases[region_i] =
+ msg_reply.payload.memory.regions[msg_i].userspace_addr;
+ trace_vhost_user_set_mem_table_postcopy(
+ msg_reply.payload.memory.regions[msg_i].userspace_addr,
+ msg.payload.memory.regions[msg_i].userspace_addr,
+ msg_i, region_i);
+ msg_i++;
+ }
+ }
+ if (msg_i != fd_num) {
+ error_report("%s: postcopy reply not fully consumed "
+ "%d vs %zd",
+ __func__, msg_i, fd_num);
+ return -1;
+ }
+ /* Now we've registered this with the postcopy code, we ack to the client,
+ * because now we're in the position to be able to deal with any faults
+ * it generates.
+ */
+ /* TODO: Use this for failure cases as well with a bad value */
+ msg.hdr.size = sizeof(msg.payload.u64);
+ msg.payload.u64 = 0; /* OK */
+ if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
+ return -1;
+ }
+
if (reply_supported) {
return process_message_reply(dev, &msg);
}
@@ -411,7 +473,8 @@ static int vhost_user_set_mem_table(struct vhost_dev *dev,
size_t fd_num = 0;
bool do_postcopy = u->postcopy_listen && u->postcopy_fd.handler;
bool reply_supported = virtio_has_feature(dev->protocol_features,
- VHOST_USER_PROTOCOL_F_REPLY_ACK);
+ VHOST_USER_PROTOCOL_F_REPLY_ACK) &&
+ !do_postcopy;
if (do_postcopy) {
/* Postcopy has enough differences that it's best done in it's own