/*
 * Vhost User library
 *
 * Copyright IBM, Corp. 2007
 * Copyright (c) 2016 Red Hat, Inc.
 *
 * Authors:
 *  Anthony Liguori <aliguori@us.ibm.com>
 *  Marc-André Lureau <mlureau@redhat.com>
 *  Victor Kaplansky <victork@redhat.com>
 *
 * This work is licensed under the terms of the GNU GPL, version 2 or
 * later.  See the COPYING file in the top-level directory.
 */

#include <qemu/osdep.h>
#include <sys/eventfd.h>
#include <linux/vhost.h>

#include "qemu/atomic.h"

#include "libvhost-user.h"

#define VHOST_USER_HDR_SIZE offsetof(VhostUserMsg, payload.u64)

/* The version of the protocol we support */
#define VHOST_USER_VERSION 1
#define LIBVHOST_USER_DEBUG 0

#define DPRINT(...)                             \
    do {                                        \
        if (LIBVHOST_USER_DEBUG) {              \
            fprintf(stderr, __VA_ARGS__);        \
        }                                       \
    } while (0)

static const char *
vu_request_to_string(int req)
{
#define REQ(req) [req] = #req
    static const char *vu_request_str[] = {
        REQ(VHOST_USER_NONE),
        REQ(VHOST_USER_GET_FEATURES),
        REQ(VHOST_USER_SET_FEATURES),
        REQ(VHOST_USER_SET_OWNER),
        REQ(VHOST_USER_RESET_OWNER),
        REQ(VHOST_USER_SET_MEM_TABLE),
        REQ(VHOST_USER_SET_LOG_BASE),
        REQ(VHOST_USER_SET_LOG_FD),
        REQ(VHOST_USER_SET_VRING_NUM),
        REQ(VHOST_USER_SET_VRING_ADDR),
        REQ(VHOST_USER_SET_VRING_BASE),
        REQ(VHOST_USER_GET_VRING_BASE),
        REQ(VHOST_USER_SET_VRING_KICK),
        REQ(VHOST_USER_SET_VRING_CALL),
        REQ(VHOST_USER_SET_VRING_ERR),
        REQ(VHOST_USER_GET_PROTOCOL_FEATURES),
        REQ(VHOST_USER_SET_PROTOCOL_FEATURES),
        REQ(VHOST_USER_GET_QUEUE_NUM),
        REQ(VHOST_USER_SET_VRING_ENABLE),
        REQ(VHOST_USER_SEND_RARP),
        REQ(VHOST_USER_INPUT_GET_CONFIG),
        REQ(VHOST_USER_MAX),
    };
#undef REQ

    if (req < VHOST_USER_MAX) {
        return vu_request_str[req];
    } else {
        return "unknown";
    }
}

static void
vu_panic(VuDev *dev, const char *msg, ...)
{
    char *buf = NULL;
    va_list ap;

    va_start(ap, msg);
    buf = g_strdup_vprintf(msg, ap);
    va_end(ap);

    dev->broken = true;
    dev->panic(dev, buf);
    free(buf);

    /* FIXME: find a way to call virtio_error? */
}

/* Translate guest physical address to our virtual address.  */
void *
vu_gpa_to_va(VuDev *dev, uint64_t guest_addr)
{
    int i;

    /* Find matching memory region.  */
    for (i = 0; i < dev->nregions; i++) {
        VuDevRegion *r = &dev->regions[i];

        if ((guest_addr >= r->gpa) && (guest_addr < (r->gpa + r->size))) {
            return (void *)(uintptr_t)
                guest_addr - r->gpa + r->mmap_addr + r->mmap_offset;
        }
    }

    return NULL;
}
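
/*
 * Worked example with made-up numbers (illustration only): for a region
 * with gpa = 0x40000000, size = 0x10000, mmap_addr = 0x7f0000000000 and
 * mmap_offset = 0, guest address 0x40000800 translates to 0x7f0000000800,
 * i.e. mmap_addr + mmap_offset + (guest_addr - gpa).
 */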

/* Translate qemu virtual address to our virtual address.  */
static void *
qva_to_va(VuDev *dev, uint64_t qemu_addr)
{
    int i;

    /* Find matching memory region.  */
    for (i = 0; i < dev->nregions; i++) {
        VuDevRegion *r = &dev->regions[i];

        if ((qemu_addr >= r->qva) && (qemu_addr < (r->qva + r->size))) {
            return (void *)(uintptr_t)
                qemu_addr - r->qva + r->mmap_addr + r->mmap_offset;
        }
    }

    return NULL;
}

static void
vmsg_close_fds(VhostUserMsg *vmsg)
{
    int i;

    for (i = 0; i < vmsg->fd_num; i++) {
        close(vmsg->fds[i]);
    }
}

static bool
vu_message_read(VuDev *dev, int conn_fd, VhostUserMsg *vmsg)
{
    char control[CMSG_SPACE(VHOST_MEMORY_MAX_NREGIONS * sizeof(int))] = { };
    struct iovec iov = {
        .iov_base = (char *)vmsg,
        .iov_len = VHOST_USER_HDR_SIZE,
    };
    struct msghdr msg = {
        .msg_iov = &iov,
        .msg_iovlen = 1,
        .msg_control = control,
        .msg_controllen = sizeof(control),
    };
    size_t fd_size;
    struct cmsghdr *cmsg;
    int rc;

    do {
        rc = recvmsg(conn_fd, &msg, 0);
    } while (rc < 0 && (errno == EINTR || errno == EAGAIN));

    if (rc < 0) {
        vu_panic(dev, "Error while recvmsg: %s", strerror(errno));
        return false;
    }

    vmsg->fd_num = 0;
    for (cmsg = CMSG_FIRSTHDR(&msg);
         cmsg != NULL;
         cmsg = CMSG_NXTHDR(&msg, cmsg))
    {
        if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) {
            fd_size = cmsg->cmsg_len - CMSG_LEN(0);
            vmsg->fd_num = fd_size / sizeof(int);
            memcpy(vmsg->fds, CMSG_DATA(cmsg), fd_size);
            break;
        }
    }

    if (vmsg->size > sizeof(vmsg->payload)) {
        vu_panic(dev,
                 "Error: too big message request: %d, size: vmsg->size: %u, "
                 "while sizeof(vmsg->payload) = %zu\n",
                 vmsg->request, vmsg->size, sizeof(vmsg->payload));
        goto fail;
    }

    if (vmsg->size) {
        do {
            rc = read(conn_fd, &vmsg->payload, vmsg->size);
        } while (rc < 0 && (errno == EINTR || errno == EAGAIN));

        if (rc <= 0) {
            vu_panic(dev, "Error while reading: %s", strerror(errno));
            goto fail;
        }

        assert(rc == vmsg->size);
    }

    return true;

fail:
    vmsg_close_fds(vmsg);

    return false;
}

static bool
vu_message_write(VuDev *dev, int conn_fd, VhostUserMsg *vmsg)
{
    int rc;
    uint8_t *p = (uint8_t *)vmsg;

    /* Set the version in the flags when sending the reply */
    vmsg->flags &= ~VHOST_USER_VERSION_MASK;
    vmsg->flags |= VHOST_USER_VERSION;
    vmsg->flags |= VHOST_USER_REPLY_MASK;

    do {
        rc = write(conn_fd, p, VHOST_USER_HDR_SIZE);
    } while (rc < 0 && (errno == EINTR || errno == EAGAIN));

    do {
        if (vmsg->data) {
            rc = write(conn_fd, vmsg->data, vmsg->size);
        } else {
            rc = write(conn_fd, p + VHOST_USER_HDR_SIZE, vmsg->size);
        }
    } while (rc < 0 && (errno == EINTR || errno == EAGAIN));

    if (rc <= 0) {
        vu_panic(dev, "Error while writing: %s", strerror(errno));
        return false;
    }

    return true;
}

/* Kick the log_call_fd if required. */
static void
vu_log_kick(VuDev *dev)
{
    if (dev->log_call_fd != -1) {
        DPRINT("Kicking the QEMU's log...\n");
        if (eventfd_write(dev->log_call_fd, 1) < 0) {
            vu_panic(dev, "Error writing eventfd: %s", strerror(errno));
        }
    }
}

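/* Dirty-log layout: the log is a bitmap with one bit per VHOST_LOG_PAGE of
 * guest memory, so page N is recorded as bit (N % 8) of byte (N / 8). */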
static void
vu_log_page(uint8_t *log_table, uint64_t page)
{
    DPRINT("Logged dirty guest page: %"PRId64"\n", page);
    atomic_or(&log_table[page / 8], 1 << (page % 8));
}

static void
vu_log_write(VuDev *dev, uint64_t address, uint64_t length)
{
    uint64_t page;

    if (!(dev->features & (1ULL << VHOST_F_LOG_ALL)) ||
        !dev->log_table || !length) {
        return;
    }

    assert(dev->log_size > ((address + length - 1) / VHOST_LOG_PAGE / 8));

    page = address / VHOST_LOG_PAGE;
    while (page * VHOST_LOG_PAGE < address + length) {
        vu_log_page(dev->log_table, page);
        page += 1; /* "page" is a page index, not a byte offset */
    }

    vu_log_kick(dev);
}

static void
vu_kick_cb(VuDev *dev, int condition, void *data)
{
    int index = (intptr_t)data;
    VuVirtq *vq = &dev->vq[index];
    int sock = vq->kick_fd;
    eventfd_t kick_data;
    ssize_t rc;

    rc = eventfd_read(sock, &kick_data);
    if (rc == -1) {
        vu_panic(dev, "kick eventfd_read(): %s", strerror(errno));
        dev->remove_watch(dev, dev->vq[index].kick_fd);
    } else {
        DPRINT("Got kick_data: %016"PRIx64" handler:%p idx:%d\n",
               kick_data, vq->handler, index);
        if (vq->handler) {
            vq->handler(dev, index);
        }
    }
}

static bool
vu_get_features_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    vmsg->payload.u64 =
        1ULL << VHOST_F_LOG_ALL |
        1ULL << VHOST_USER_F_PROTOCOL_FEATURES;

    if (dev->iface->get_features) {
        vmsg->payload.u64 |= dev->iface->get_features(dev);
    }

    vmsg->size = sizeof(vmsg->payload.u64);

    DPRINT("Sending back to guest u64: 0x%016"PRIx64"\n", vmsg->payload.u64);

    return true;
}

static void
vu_set_enable_all_rings(VuDev *dev, bool enabled)
{
    int i;

    for (i = 0; i < VHOST_MAX_NR_VIRTQUEUE; i++) {
        dev->vq[i].enable = enabled;
    }
}

static bool
vu_set_features_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);

    dev->features = vmsg->payload.u64;

    if (!(dev->features & (1ULL << VHOST_USER_F_PROTOCOL_FEATURES))) {
        vu_set_enable_all_rings(dev, true);
    }

    if (dev->iface->set_features) {
        dev->iface->set_features(dev, dev->features);
    }

    return false;
}

static bool
vu_set_owner_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    return false;
}

static void
vu_close_log(VuDev *dev)
{
    if (dev->log_table) {
        if (munmap(dev->log_table, dev->log_size) != 0) {
            perror("close log munmap() error");
        }

        dev->log_table = NULL;
    }
    if (dev->log_call_fd != -1) {
        close(dev->log_call_fd);
        dev->log_call_fd = -1;
    }
}

static bool
vu_reset_device_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    vu_set_enable_all_rings(dev, false);

    return false;
}

static bool
vu_set_mem_table_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    int i;
    VhostUserMemory *memory = &vmsg->payload.memory;
    dev->nregions = memory->nregions;

    DPRINT("Nregions: %d\n", memory->nregions);
    for (i = 0; i < dev->nregions; i++) {
        void *mmap_addr;
        VhostUserMemoryRegion *msg_region = &memory->regions[i];
        VuDevRegion *dev_region = &dev->regions[i];

        DPRINT("Region %d\n", i);
        DPRINT("    guest_phys_addr: 0x%016"PRIx64"\n",
               msg_region->guest_phys_addr);
        DPRINT("    memory_size:     0x%016"PRIx64"\n",
               msg_region->memory_size);
        DPRINT("    userspace_addr   0x%016"PRIx64"\n",
               msg_region->userspace_addr);
        DPRINT("    mmap_offset      0x%016"PRIx64"\n",
               msg_region->mmap_offset);

        dev_region->gpa = msg_region->guest_phys_addr;
        dev_region->size = msg_region->memory_size;
        dev_region->qva = msg_region->userspace_addr;
        dev_region->mmap_offset = msg_region->mmap_offset;

        /* We don't use the offset argument of mmap() since the
         * mapped address has to be page aligned, and we use huge
         * pages. */
        mmap_addr = mmap(0, dev_region->size + dev_region->mmap_offset,
                         PROT_READ | PROT_WRITE, MAP_SHARED,
                         vmsg->fds[i], 0);

        if (mmap_addr == MAP_FAILED) {
            vu_panic(dev, "region mmap error: %s", strerror(errno));
        } else {
            dev_region->mmap_addr = (uint64_t)(uintptr_t)mmap_addr;
            DPRINT("    mmap_addr:       0x%016"PRIx64"\n",
                   dev_region->mmap_addr);
        }

        close(vmsg->fds[i]);
    }

    return false;
}

static bool
vu_set_log_base_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    int fd;
    uint64_t log_mmap_size, log_mmap_offset;
    void *rc;

    if (vmsg->fd_num != 1 ||
        vmsg->size != sizeof(vmsg->payload.log)) {
        vu_panic(dev, "Invalid log_base message");
        return true;
    }

    fd = vmsg->fds[0];
    log_mmap_offset = vmsg->payload.log.mmap_offset;
    log_mmap_size = vmsg->payload.log.mmap_size;
    DPRINT("Log mmap_offset: %"PRId64"\n", log_mmap_offset);
    DPRINT("Log mmap_size:   %"PRId64"\n", log_mmap_size);

    rc = mmap(0, log_mmap_size, PROT_READ | PROT_WRITE, MAP_SHARED, fd,
              log_mmap_offset);
    if (rc == MAP_FAILED) {
        perror("log mmap error");
    }
    dev->log_table = rc;
    dev->log_size = log_mmap_size;

    vmsg->size = sizeof(vmsg->payload.u64);

    return true;
}

static bool
vu_set_log_fd_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    if (vmsg->fd_num != 1) {
        vu_panic(dev, "Invalid log_fd message");
        return false;
    }

    if (dev->log_call_fd != -1) {
        close(dev->log_call_fd);
    }
    dev->log_call_fd = vmsg->fds[0];
    DPRINT("Got log_call_fd: %d\n", vmsg->fds[0]);

    return false;
}

static bool
vu_set_vring_num_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    unsigned int index = vmsg->payload.state.index;
    unsigned int num = vmsg->payload.state.num;

    DPRINT("State.index: %d\n", index);
    DPRINT("State.num:   %d\n", num);
    dev->vq[index].vring.num = num;

    return false;
}

static bool
vu_set_vring_addr_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    struct vhost_vring_addr *vra = &vmsg->payload.addr;
    unsigned int index = vra->index;
    VuVirtq *vq = &dev->vq[index];

    DPRINT("vhost_vring_addr:\n");
    DPRINT("    index:  %d\n", vra->index);
    DPRINT("    flags:  %d\n", vra->flags);
    DPRINT("    desc_user_addr:   0x%016llx\n", vra->desc_user_addr);
    DPRINT("    used_user_addr:   0x%016llx\n", vra->used_user_addr);
    DPRINT("    avail_user_addr:  0x%016llx\n", vra->avail_user_addr);
    DPRINT("    log_guest_addr:   0x%016llx\n", vra->log_guest_addr);

    vq->vring.flags = vra->flags;
    vq->vring.desc = qva_to_va(dev, vra->desc_user_addr);
    vq->vring.used = qva_to_va(dev, vra->used_user_addr);
    vq->vring.avail = qva_to_va(dev, vra->avail_user_addr);
    vq->vring.log_guest_addr = vra->log_guest_addr;

    DPRINT("Setting virtq addresses:\n");
    DPRINT("    vring_desc  at %p\n", vq->vring.desc);
    DPRINT("    vring_used  at %p\n", vq->vring.used);
    DPRINT("    vring_avail at %p\n", vq->vring.avail);

    if (!(vq->vring.desc && vq->vring.used && vq->vring.avail)) {
        vu_panic(dev, "Invalid vring_addr message");
        return false;
    }

    vq->used_idx = vq->vring.used->idx;

    return false;
}

static bool
vu_set_vring_base_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    unsigned int index = vmsg->payload.state.index;
    unsigned int num = vmsg->payload.state.num;

    DPRINT("State.index: %d\n", index);
    DPRINT("State.num:   %d\n", num);
    dev->vq[index].shadow_avail_idx = dev->vq[index].last_avail_idx = num;

    return false;
}

static bool
vu_get_vring_base_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    unsigned int index = vmsg->payload.state.index;

    DPRINT("State.index: %d\n", index);
    vmsg->payload.state.num = dev->vq[index].last_avail_idx;
    vmsg->size = sizeof(vmsg->payload.state);

    dev->vq[index].started = false;
    if (dev->iface->queue_set_started) {
        dev->iface->queue_set_started(dev, index, false);
    }

    if (dev->vq[index].call_fd != -1) {
        close(dev->vq[index].call_fd);
        dev->vq[index].call_fd = -1;
    }
    if (dev->vq[index].kick_fd != -1) {
        dev->remove_watch(dev, dev->vq[index].kick_fd);
        close(dev->vq[index].kick_fd);
        dev->vq[index].kick_fd = -1;
    }

    return true;
}

static bool
vu_check_queue_msg_file(VuDev *dev, VhostUserMsg *vmsg)
{
    int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;

    if (index >= VHOST_MAX_NR_VIRTQUEUE) {
        vmsg_close_fds(vmsg);
        vu_panic(dev, "Invalid queue index: %u", index);
        return false;
    }

    if (vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK ||
        vmsg->fd_num != 1) {
        vmsg_close_fds(vmsg);
        vu_panic(dev, "Invalid fds in request: %d", vmsg->request);
        return false;
    }

    return true;
}

static bool
vu_set_vring_kick_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;

    DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);

    if (!vu_check_queue_msg_file(dev, vmsg)) {
        return false;
    }

    if (dev->vq[index].kick_fd != -1) {
        dev->remove_watch(dev, dev->vq[index].kick_fd);
        close(dev->vq[index].kick_fd);
        dev->vq[index].kick_fd = -1;
    }

    if (!(vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK)) {
        dev->vq[index].kick_fd = vmsg->fds[0];
        DPRINT("Got kick_fd: %d for vq: %d\n", vmsg->fds[0], index);
    }

    dev->vq[index].started = true;
    if (dev->iface->queue_set_started) {
        dev->iface->queue_set_started(dev, index, true);
    }

    if (dev->vq[index].kick_fd != -1 && dev->vq[index].handler) {
        dev->set_watch(dev, dev->vq[index].kick_fd, VU_WATCH_IN,
                       vu_kick_cb, (void *)(long)index);

        DPRINT("Waiting for kicks on fd: %d for vq: %d\n",
               dev->vq[index].kick_fd, index);
    }

    return false;
}

void vu_set_queue_handler(VuDev *dev, VuVirtq *vq,
                          vu_queue_handler_cb handler)
{
    int qidx = vq - dev->vq;

    vq->handler = handler;
    if (vq->kick_fd >= 0) {
        if (handler) {
            dev->set_watch(dev, vq->kick_fd, VU_WATCH_IN,
                           vu_kick_cb, (void *)(long)qidx);
        } else {
            dev->remove_watch(dev, vq->kick_fd);
        }
    }
}

static bool
vu_set_vring_call_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;

    DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);

    if (!vu_check_queue_msg_file(dev, vmsg)) {
        return false;
    }

    if (dev->vq[index].call_fd != -1) {
        close(dev->vq[index].call_fd);
        dev->vq[index].call_fd = -1;
    }

    if (!(vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK)) {
        dev->vq[index].call_fd = vmsg->fds[0];
    }

    DPRINT("Got call_fd: %d for vq: %d\n", vmsg->fds[0], index);

    return false;
}

static bool
vu_set_vring_err_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;

    DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);

    if (!vu_check_queue_msg_file(dev, vmsg)) {
        return false;
    }

    if (dev->vq[index].err_fd != -1) {
        close(dev->vq[index].err_fd);
        dev->vq[index].err_fd = -1;
    }

    if (!(vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK)) {
        dev->vq[index].err_fd = vmsg->fds[0];
    }

    return false;
}

static bool
vu_get_protocol_features_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    uint64_t features = 1ULL << VHOST_USER_PROTOCOL_F_LOG_SHMFD;

    if (dev->iface->get_protocol_features) {
        features |= dev->iface->get_protocol_features(dev);
    }

    vmsg->payload.u64 = features;
    vmsg->size = sizeof(vmsg->payload.u64);

    return true;
}

static bool
vu_set_protocol_features_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    uint64_t features = vmsg->payload.u64;

    DPRINT("u64: 0x%016"PRIx64"\n", features);

    dev->protocol_features = vmsg->payload.u64;

    if (dev->iface->set_protocol_features) {
        dev->iface->set_protocol_features(dev, features);
    }

    return false;
}

static bool
vu_get_queue_num_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    DPRINT("Function %s() not implemented yet.\n", __func__);
    return false;
}

static bool
vu_set_vring_enable_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    unsigned int index = vmsg->payload.state.index;
    unsigned int enable = vmsg->payload.state.num;

    DPRINT("State.index: %d\n", index);
    DPRINT("State.enable:   %d\n", enable);

    if (index >= VHOST_MAX_NR_VIRTQUEUE) {
        vu_panic(dev, "Invalid vring_enable index: %u", index);
        return false;
    }

    dev->vq[index].enable = enable;
    return false;
}

static bool
vu_process_message(VuDev *dev, VhostUserMsg *vmsg)
{
    int do_reply = 0;

    /* Print out generic part of the request. */
    DPRINT("================ Vhost user message ================\n");
    DPRINT("Request: %s (%d)\n", vu_request_to_string(vmsg->request),
           vmsg->request);
    DPRINT("Flags:   0x%x\n", vmsg->flags);
    DPRINT("Size:    %d\n", vmsg->size);

    if (vmsg->fd_num) {
        int i;
        DPRINT("Fds:");
        for (i = 0; i < vmsg->fd_num; i++) {
            DPRINT(" %d", vmsg->fds[i]);
        }
        DPRINT("\n");
    }

    if (dev->iface->process_msg &&
        dev->iface->process_msg(dev, vmsg, &do_reply)) {
        return do_reply;
    }

    switch (vmsg->request) {
    case VHOST_USER_GET_FEATURES:
        return vu_get_features_exec(dev, vmsg);
    case VHOST_USER_SET_FEATURES:
        return vu_set_features_exec(dev, vmsg);
    case VHOST_USER_GET_PROTOCOL_FEATURES:
        return vu_get_protocol_features_exec(dev, vmsg);
    case VHOST_USER_SET_PROTOCOL_FEATURES:
        return vu_set_protocol_features_exec(dev, vmsg);
    case VHOST_USER_SET_OWNER:
        return vu_set_owner_exec(dev, vmsg);
    case VHOST_USER_RESET_OWNER:
        return vu_reset_device_exec(dev, vmsg);
    case VHOST_USER_SET_MEM_TABLE:
        return vu_set_mem_table_exec(dev, vmsg);
    case VHOST_USER_SET_LOG_BASE:
        return vu_set_log_base_exec(dev, vmsg);
    case VHOST_USER_SET_LOG_FD:
        return vu_set_log_fd_exec(dev, vmsg);
    case VHOST_USER_SET_VRING_NUM:
        return vu_set_vring_num_exec(dev, vmsg);
    case VHOST_USER_SET_VRING_ADDR:
        return vu_set_vring_addr_exec(dev, vmsg);
    case VHOST_USER_SET_VRING_BASE:
        return vu_set_vring_base_exec(dev, vmsg);
    case VHOST_USER_GET_VRING_BASE:
        return vu_get_vring_base_exec(dev, vmsg);
    case VHOST_USER_SET_VRING_KICK:
        return vu_set_vring_kick_exec(dev, vmsg);
    case VHOST_USER_SET_VRING_CALL:
        return vu_set_vring_call_exec(dev, vmsg);
    case VHOST_USER_SET_VRING_ERR:
        return vu_set_vring_err_exec(dev, vmsg);
    case VHOST_USER_GET_QUEUE_NUM:
        return vu_get_queue_num_exec(dev, vmsg);
    case VHOST_USER_SET_VRING_ENABLE:
        return vu_set_vring_enable_exec(dev, vmsg);
    case VHOST_USER_NONE:
        break;
    default:
        vmsg_close_fds(vmsg);
        vu_panic(dev, "Unhandled request: %d", vmsg->request);
    }

    return false;
}

bool
vu_dispatch(VuDev *dev)
{
    VhostUserMsg vmsg = { 0, };
    int reply_requested;
    bool success = false;

    if (!vu_message_read(dev, dev->sock, &vmsg)) {
        goto end;
    }

    reply_requested = vu_process_message(dev, &vmsg);
    if (!reply_requested) {
        success = true;
        goto end;
    }

    if (!vu_message_write(dev, dev->sock, &vmsg)) {
        goto end;
    }

    success = true;

end:
    g_free(vmsg.data);
    return success;
}

void
vu_deinit(VuDev *dev)
{
    int i;

    for (i = 0; i < dev->nregions; i++) {
        VuDevRegion *r = &dev->regions[i];
        void *m = (void *) (uintptr_t) r->mmap_addr;
        if (m != MAP_FAILED) {
            munmap(m, r->size + r->mmap_offset);
        }
    }
    dev->nregions = 0;

    for (i = 0; i < VHOST_MAX_NR_VIRTQUEUE; i++) {
        VuVirtq *vq = &dev->vq[i];

        if (vq->call_fd != -1) {
            close(vq->call_fd);
            vq->call_fd = -1;
        }

        if (vq->kick_fd != -1) {
            close(vq->kick_fd);
            vq->kick_fd = -1;
        }

        if (vq->err_fd != -1) {
            close(vq->err_fd);
            vq->err_fd = -1;
        }
    }

    vu_close_log(dev);

    if (dev->sock != -1) {
        close(dev->sock);
    }
}

void
vu_init(VuDev *dev,
        int socket,
        vu_panic_cb panic,
        vu_set_watch_cb set_watch,
        vu_remove_watch_cb remove_watch,
        const VuDevIface *iface)
{
    int i;

    assert(socket >= 0);
    assert(set_watch);
    assert(remove_watch);
    assert(iface);
    assert(panic);

    memset(dev, 0, sizeof(*dev));

    dev->sock = socket;
    dev->panic = panic;
    dev->set_watch = set_watch;
    dev->remove_watch = remove_watch;
    dev->iface = iface;
    dev->log_call_fd = -1;
    for (i = 0; i < VHOST_MAX_NR_VIRTQUEUE; i++) {
        dev->vq[i] = (VuVirtq) {
            .call_fd = -1, .kick_fd = -1, .err_fd = -1,
            .notification = true,
        };
    }
}
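
/*
 * Usage sketch (not part of the library): a backend typically wires
 * vu_init() into its own event loop.  The callback names and my_iface
 * below are illustrative placeholders, not APIs defined in this file.
 *
 *     static void my_panic(VuDev *dev, const char *msg)
 *     {
 *         fprintf(stderr, "vhost-user: %s\n", msg);
 *     }
 *
 *     // my_set_watch()/my_remove_watch() register fds with the
 *     // application's poll loop and invoke the given callback when ready.
 *
 *     VuDev dev;
 *     vu_init(&dev, conn_fd, my_panic, my_set_watch, my_remove_watch,
 *             &my_iface);
 *     // On socket readability: if (!vu_dispatch(&dev)) { handle error }
 *     // On teardown: vu_deinit(&dev);
 */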

VuVirtq *
vu_get_queue(VuDev *dev, int qidx)
{
    assert(qidx < VHOST_MAX_NR_VIRTQUEUE);
    return &dev->vq[qidx];
}

bool
vu_queue_enabled(VuDev *dev, VuVirtq *vq)
{
    return vq->enable;
}

static inline uint16_t
vring_avail_flags(VuVirtq *vq)
{
    return vq->vring.avail->flags;
}

static inline uint16_t
vring_avail_idx(VuVirtq *vq)
{
    vq->shadow_avail_idx = vq->vring.avail->idx;

    return vq->shadow_avail_idx;
}

static inline uint16_t
vring_avail_ring(VuVirtq *vq, int i)
{
    return vq->vring.avail->ring[i];
}

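/* With VIRTIO_RING_F_EVENT_IDX, the driver stores its "used event" index
 * in the slot just past the avail ring (ring[num]); symmetrically, the
 * device stores its "avail event" past the used ring (see
 * vring_set_avail_event() below). */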
static inline uint16_t
vring_get_used_event(VuVirtq *vq)
{
    return vring_avail_ring(vq, vq->vring.num);
}

static int
virtqueue_num_heads(VuDev *dev, VuVirtq *vq, unsigned int idx)
{
    uint16_t num_heads = vring_avail_idx(vq) - idx;

    /* Check it isn't doing very strange things with descriptor numbers. */
    if (num_heads > vq->vring.num) {
        vu_panic(dev, "Guest moved used index from %u to %u",
                 idx, vq->shadow_avail_idx);
        return -1;
    }
    if (num_heads) {
        /* On success, callers read a descriptor at vq->last_avail_idx.
         * Make sure descriptor read does not bypass avail index read. */
        smp_rmb();
    }

    return num_heads;
}

static bool
virtqueue_get_head(VuDev *dev, VuVirtq *vq,
                   unsigned int idx, unsigned int *head)
{
    /* Grab the next descriptor number they're advertising, and increment
     * the index we've seen. */
    *head = vring_avail_ring(vq, idx % vq->vring.num);

    /* If their number is silly, that's a fatal mistake. */
    if (*head >= vq->vring.num) {
        vu_panic(dev, "Guest says index %u is available", head);
        return false;
    }

    return true;
}

enum {
    VIRTQUEUE_READ_DESC_ERROR = -1,
    VIRTQUEUE_READ_DESC_DONE = 0,   /* end of chain */
    VIRTQUEUE_READ_DESC_MORE = 1,   /* more buffers in chain */
};

static int
virtqueue_read_next_desc(VuDev *dev, struct vring_desc *desc,
                         int i, unsigned int max, unsigned int *next)
{
    /* If this descriptor says it doesn't chain, we're done. */
    if (!(desc[i].flags & VRING_DESC_F_NEXT)) {
        return VIRTQUEUE_READ_DESC_DONE;
    }

    /* Check they're not leading us off the end of the descriptors. */
    *next = desc[i].next;
    /* Make sure compiler knows to grab that: we don't want it changing! */
    smp_wmb();

    if (*next >= max) {
        vu_panic(dev, "Desc next is %u", next);
        return VIRTQUEUE_READ_DESC_ERROR;
    }

    return VIRTQUEUE_READ_DESC_MORE;
}

void
vu_queue_get_avail_bytes(VuDev *dev, VuVirtq *vq, unsigned int *in_bytes,
                         unsigned int *out_bytes,
                         unsigned max_in_bytes, unsigned max_out_bytes)
{
    unsigned int idx;
    unsigned int total_bufs, in_total, out_total;
    int rc;

    idx = vq->last_avail_idx;

    total_bufs = in_total = out_total = 0;
    if (unlikely(dev->broken) ||
        unlikely(!vq->vring.avail)) {
        goto done;
    }

    while ((rc = virtqueue_num_heads(dev, vq, idx)) > 0) {
        unsigned int max, num_bufs, indirect = 0;
        struct vring_desc *desc;
        unsigned int i;

        max = vq->vring.num;
        num_bufs = total_bufs;
        if (!virtqueue_get_head(dev, vq, idx++, &i)) {
            goto err;
        }
        desc = vq->vring.desc;

        if (desc[i].flags & VRING_DESC_F_INDIRECT) {
            if (desc[i].len % sizeof(struct vring_desc)) {
                vu_panic(dev, "Invalid size for indirect buffer table");
                goto err;
            }

            /* If we've got too many, that implies a descriptor loop. */
            if (num_bufs >= max) {
                vu_panic(dev, "Looped descriptor");
                goto err;
            }

            /* loop over the indirect descriptor table */
            indirect = 1;
            max = desc[i].len / sizeof(struct vring_desc);
            desc = vu_gpa_to_va(dev, desc[i].addr);
            num_bufs = i = 0;
        }

        do {
            /* If we've got too many, that implies a descriptor loop. */
            if (++num_bufs > max) {
                vu_panic(dev, "Looped descriptor");
                goto err;
            }

            if (desc[i].flags & VRING_DESC_F_WRITE) {
                in_total += desc[i].len;
            } else {
                out_total += desc[i].len;
            }
            if (in_total >= max_in_bytes && out_total >= max_out_bytes) {
                goto done;
            }
            rc = virtqueue_read_next_desc(dev, desc, i, max, &i);
        } while (rc == VIRTQUEUE_READ_DESC_MORE);

        if (rc == VIRTQUEUE_READ_DESC_ERROR) {
            goto err;
        }

        if (!indirect) {
            total_bufs = num_bufs;
        } else {
            total_bufs++;
        }
    }
    if (rc < 0) {
        goto err;
    }
done:
    if (in_bytes) {
        *in_bytes = in_total;
    }
    if (out_bytes) {
        *out_bytes = out_total;
    }
    return;

err:
    in_total = out_total = 0;
    goto done;
}

bool
vu_queue_avail_bytes(VuDev *dev, VuVirtq *vq, unsigned int in_bytes,
                     unsigned int out_bytes)
{
    unsigned int in_total, out_total;

    vu_queue_get_avail_bytes(dev, vq, &in_total, &out_total,
                             in_bytes, out_bytes);

    return in_bytes <= in_total && out_bytes <= out_total;
}

/* Fetch avail_idx from VQ memory only when we really need to know if
 * guest has added some buffers. */
bool
vu_queue_empty(VuDev *dev, VuVirtq *vq)
{
    if (unlikely(dev->broken) ||
        unlikely(!vq->vring.avail)) {
        return true;
    }

    if (vq->shadow_avail_idx != vq->last_avail_idx) {
        return false;
    }

    return vring_avail_idx(vq) == vq->last_avail_idx;
}

static inline
bool has_feature(uint64_t features, unsigned int fbit)
{
    assert(fbit < 64);
    return !!(features & (1ULL << fbit));
}

static inline
bool vu_has_feature(VuDev *dev,
                    unsigned int fbit)
{
    return has_feature(dev->features, fbit);
}

static bool
vring_notify(VuDev *dev, VuVirtq *vq)
{
    uint16_t old, new;
    bool v;

    /* We need to expose used array entries before checking used event. */
    smp_mb();

    /* Always notify when the queue is empty (if the feature was acknowledged) */
    if (vu_has_feature(dev, VIRTIO_F_NOTIFY_ON_EMPTY) &&
        !vq->inuse && vu_queue_empty(dev, vq)) {
        return true;
    }

    if (!vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
        return !(vring_avail_flags(vq) & VRING_AVAIL_F_NO_INTERRUPT);
    }

    v = vq->signalled_used_valid;
    vq->signalled_used_valid = true;
    old = vq->signalled_used;
    new = vq->signalled_used = vq->used_idx;
    return !v || vring_need_event(vring_get_used_event(vq), new, old);
}

void
vu_queue_notify(VuDev *dev, VuVirtq *vq)
{
    if (unlikely(dev->broken) ||
        unlikely(!vq->vring.avail)) {
        return;
    }

    if (!vring_notify(dev, vq)) {
        DPRINT("skipped notify...\n");
        return;
    }

    if (eventfd_write(vq->call_fd, 1) < 0) {
        vu_panic(dev, "Error writing eventfd: %s", strerror(errno));
    }
}

static inline void
vring_used_flags_set_bit(VuVirtq *vq, int mask)
{
    uint16_t *flags;

    flags = (uint16_t *)((char*)vq->vring.used +
                         offsetof(struct vring_used, flags));
    *flags |= mask;
}

static inline void
vring_used_flags_unset_bit(VuVirtq *vq, int mask)
{
    uint16_t *flags;

    flags = (uint16_t *)((char*)vq->vring.used +
                         offsetof(struct vring_used, flags));
    *flags &= ~mask;
}

static inline void
vring_set_avail_event(VuVirtq *vq, uint16_t val)
{
    if (!vq->notification) {
        return;
    }

    *((uint16_t *) &vq->vring.used->ring[vq->vring.num]) = val;
}

void
vu_queue_set_notification(VuDev *dev, VuVirtq *vq, int enable)
{
    vq->notification = enable;
    if (vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
        vring_set_avail_event(vq, vring_avail_idx(vq));
    } else if (enable) {
        vring_used_flags_unset_bit(vq, VRING_USED_F_NO_NOTIFY);
    } else {
        vring_used_flags_set_bit(vq, VRING_USED_F_NO_NOTIFY);
    }
    if (enable) {
        /* Expose avail event/used flags before caller checks the avail idx. */
        smp_mb();
    }
}

static void
virtqueue_map_desc(VuDev *dev,
                   unsigned int *p_num_sg, struct iovec *iov,
                   unsigned int max_num_sg, bool is_write,
                   uint64_t pa, size_t sz)
{
    unsigned num_sg = *p_num_sg;

    assert(num_sg <= max_num_sg);

    if (!sz) {
        vu_panic(dev, "virtio: zero sized buffers are not allowed");
        return;
    }

    iov[num_sg].iov_base = vu_gpa_to_va(dev, pa);
    iov[num_sg].iov_len = sz;
    num_sg++;

    *p_num_sg = num_sg;
}

/* Round number down to multiple */
#define ALIGN_DOWN(n, m) ((n) / (m) * (m))

/* Round number up to multiple */
#define ALIGN_UP(n, m) ALIGN_DOWN((n) + (m) - 1, (m))
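/* For example, ALIGN_DOWN(13, 8) == 8 and ALIGN_UP(13, 8) == 16. */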

static void *
virtqueue_alloc_element(size_t sz,
                        unsigned out_num, unsigned in_num)
{
    VuVirtqElement *elem;
    size_t in_sg_ofs = ALIGN_UP(sz, __alignof__(elem->in_sg[0]));
    size_t out_sg_ofs = in_sg_ofs + in_num * sizeof(elem->in_sg[0]);
    size_t out_sg_end = out_sg_ofs + out_num * sizeof(elem->out_sg[0]);

    assert(sz >= sizeof(VuVirtqElement));
    elem = malloc(out_sg_end);
    elem->out_num = out_num;
    elem->in_num = in_num;
    elem->in_sg = (void *)elem + in_sg_ofs;
    elem->out_sg = (void *)elem + out_sg_ofs;
    return elem;
}

void *
vu_queue_pop(VuDev *dev, VuVirtq *vq, size_t sz)
{
    unsigned int i, head, max;
    VuVirtqElement *elem;
    unsigned out_num, in_num;
    struct iovec iov[VIRTQUEUE_MAX_SIZE];
    struct vring_desc *desc;
    int rc;

    if (unlikely(dev->broken) ||
        unlikely(!vq->vring.avail)) {
        return NULL;
    }

    if (vu_queue_empty(dev, vq)) {
        return NULL;
    }
    /* Needed after vu_queue_empty(), see the comment in
     * virtqueue_num_heads(). */
    smp_rmb();

    /* When we start there are no input or output buffers. */
    out_num = in_num = 0;

    max = vq->vring.num;
    if (vq->inuse >= vq->vring.num) {
        vu_panic(dev, "Virtqueue size exceeded");
        return NULL;
    }

    if (!virtqueue_get_head(dev, vq, vq->last_avail_idx++, &head)) {
        return NULL;
    }

    if (vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
        vring_set_avail_event(vq, vq->last_avail_idx);
    }

    i = head;
    desc = vq->vring.desc;
    if (desc[i].flags & VRING_DESC_F_INDIRECT) {
        if (desc[i].len % sizeof(struct vring_desc)) {
            vu_panic(dev, "Invalid size for indirect buffer table");
        }

        /* loop over the indirect descriptor table */
        max = desc[i].len / sizeof(struct vring_desc);
        desc = vu_gpa_to_va(dev, desc[i].addr);
        i = 0;
    }

    /* Collect all the descriptors */
    do {
        if (desc[i].flags & VRING_DESC_F_WRITE) {
            virtqueue_map_desc(dev, &in_num, iov + out_num,
                               VIRTQUEUE_MAX_SIZE - out_num, true,
                               desc[i].addr, desc[i].len);
        } else {
            if (in_num) {
                vu_panic(dev, "Incorrect order for descriptors");
                return NULL;
            }
            virtqueue_map_desc(dev, &out_num, iov,
                               VIRTQUEUE_MAX_SIZE, false,
                               desc[i].addr, desc[i].len);
        }

        /* If we've got too many, that implies a descriptor loop. */
        if ((in_num + out_num) > max) {
            vu_panic(dev, "Looped descriptor");
        }
        rc = virtqueue_read_next_desc(dev, desc, i, max, &i);
    } while (rc == VIRTQUEUE_READ_DESC_MORE);

    if (rc == VIRTQUEUE_READ_DESC_ERROR) {
        return NULL;
    }

    /* Now copy what we have collected and mapped */
    elem = virtqueue_alloc_element(sz, out_num, in_num);
    elem->index = head;
    for (i = 0; i < out_num; i++) {
        elem->out_sg[i] = iov[i];
    }
    for (i = 0; i < in_num; i++) {
        elem->in_sg[i] = iov[out_num + i];
    }

    vq->inuse++;

    return elem;
}
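
/*
 * A sketch of a typical request loop built on vu_queue_pop();
 * process_buffers is a hypothetical, caller-defined helper:
 *
 *     VuVirtqElement *elem;
 *
 *     while ((elem = vu_queue_pop(dev, vq, sizeof(*elem)))) {
 *         // Read the request from elem->out_sg, write any response into
 *         // elem->in_sg, returning the number of bytes written.
 *         unsigned int len = process_buffers(elem->out_sg, elem->out_num,
 *                                            elem->in_sg, elem->in_num);
 *         vu_queue_push(dev, vq, elem, len);
 *         free(elem);  // elements come from malloc() in
 *                      // virtqueue_alloc_element()
 *     }
 *     vu_queue_notify(dev, vq);
 */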

bool
vu_queue_rewind(VuDev *dev, VuVirtq *vq, unsigned int num)
{
    if (num > vq->inuse) {
        return false;
    }
    vq->last_avail_idx -= num;
    vq->inuse -= num;
    return true;
}

static inline
void vring_used_write(VuDev *dev, VuVirtq *vq,
                      struct vring_used_elem *uelem, int i)
{
    struct vring_used *used = vq->vring.used;

    used->ring[i] = *uelem;
    vu_log_write(dev, vq->vring.log_guest_addr +
                 offsetof(struct vring_used, ring[i]),
                 sizeof(used->ring[i]));
}


static void
vu_log_queue_fill(VuDev *dev, VuVirtq *vq,
                  const VuVirtqElement *elem,
                  unsigned int len)
{
    struct vring_desc *desc = vq->vring.desc;
    unsigned int i, max, min;
    unsigned num_bufs = 0;

    max = vq->vring.num;
    i = elem->index;

    if (desc[i].flags & VRING_DESC_F_INDIRECT) {
        if (desc[i].len % sizeof(struct vring_desc)) {
            vu_panic(dev, "Invalid size for indirect buffer table");
        }

        /* loop over the indirect descriptor table */
        max = desc[i].len / sizeof(struct vring_desc);
        desc = vu_gpa_to_va(dev, desc[i].addr);
        i = 0;
    }

    do {
        if (++num_bufs > max) {
            vu_panic(dev, "Looped descriptor");
            return;
        }

        if (desc[i].flags & VRING_DESC_F_WRITE) {
            min = MIN(desc[i].len, len);
            vu_log_write(dev, desc[i].addr, min);
            len -= min;
        }

    } while (len > 0 &&
             (virtqueue_read_next_desc(dev, desc, i, max, &i)
              == VIRTQUEUE_READ_DESC_MORE));
}

void
vu_queue_fill(VuDev *dev, VuVirtq *vq,
              const VuVirtqElement *elem,
              unsigned int len, unsigned int idx)
{
    struct vring_used_elem uelem;

    if (unlikely(dev->broken) ||
        unlikely(!vq->vring.avail)) {
        return;
    }

    vu_log_queue_fill(dev, vq, elem, len);

    idx = (idx + vq->used_idx) % vq->vring.num;

    uelem.id = elem->index;
    uelem.len = len;
    vring_used_write(dev, vq, &uelem, idx);
}

static inline
void vring_used_idx_set(VuDev *dev, VuVirtq *vq, uint16_t val)
{
    vq->vring.used->idx = val;
    vu_log_write(dev,
                 vq->vring.log_guest_addr + offsetof(struct vring_used, idx),
                 sizeof(vq->vring.used->idx));

    vq->used_idx = val;
}

void
vu_queue_flush(VuDev *dev, VuVirtq *vq, unsigned int count)
{
    uint16_t old, new;

    if (unlikely(dev->broken) ||
        unlikely(!vq->vring.avail)) {
        return;
    }

    /* Make sure buffer is written before we update index. */
    smp_wmb();

    old = vq->used_idx;
    new = old + count;
    vring_used_idx_set(dev, vq, new);
    vq->inuse -= count;
    if (unlikely((int16_t)(new - vq->signalled_used) < (uint16_t)(new - old))) {
        vq->signalled_used_valid = false;
    }
}

void
vu_queue_push(VuDev *dev, VuVirtq *vq,
              const VuVirtqElement *elem, unsigned int len)
{
    vu_queue_fill(dev, vq, elem, len, 0);
    vu_queue_flush(dev, vq, 1);
}