summaryrefslogtreecommitdiffstats
path: root/drivers/staging/hv/RndisFilter.c
diff options
context:
space:
mode:
authorHank Janssen2009-07-14 00:34:54 +0200
committerGreg Kroah-Hartman2009-09-15 21:01:43 +0200
commitfceaf24a943d8d50391f409ac7fb888cd1f36f32 (patch)
treebbde5063f0208adcb283e29a922774b6ae85a604 /drivers/staging/hv/RndisFilter.c
parentStaging: hv: add the Hyper-V virtual block driver (diff)
downloadkernel-qcow2-linux-fceaf24a943d8d50391f409ac7fb888cd1f36f32.tar.gz
kernel-qcow2-linux-fceaf24a943d8d50391f409ac7fb888cd1f36f32.tar.xz
kernel-qcow2-linux-fceaf24a943d8d50391f409ac7fb888cd1f36f32.zip
Staging: hv: add the Hyper-V virtual network driver
This is the virtual network driver when running Linux on top of Hyper-V. Signed-off-by: Hank Janssen <hjanssen@microsoft.com> Signed-off-by: Haiyang Zhang <haiyangz@microsoft.com> Signed-off-by: Greg Kroah-Hartman <gregkh@suse.de>
Diffstat (limited to 'drivers/staging/hv/RndisFilter.c')
-rw-r--r--drivers/staging/hv/RndisFilter.c1162
1 files changed, 1162 insertions, 0 deletions
diff --git a/drivers/staging/hv/RndisFilter.c b/drivers/staging/hv/RndisFilter.c
new file mode 100644
index 000000000000..57b828b12c14
--- /dev/null
+++ b/drivers/staging/hv/RndisFilter.c
@@ -0,0 +1,1162 @@
+/*
+ *
+ * Copyright (c) 2009, Microsoft Corporation.
+ *
+ * This program is free software; you can redistribute it and/or modify it
+ * under the terms and conditions of the GNU General Public License,
+ * version 2, as published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
+ * more details.
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * this program; if not, write to the Free Software Foundation, Inc., 59 Temple
+ * Place - Suite 330, Boston, MA 02111-1307 USA.
+ *
+ * Authors:
+ * Haiyang Zhang <haiyangz@microsoft.com>
+ * Hank Janssen <hjanssen@microsoft.com>
+ *
+ */
+
+
+#include "logging.h"
+
+#include "NetVscApi.h"
+#include "RndisFilter.h"
+
+//
+// Data types
+//
+
+typedef struct _RNDIS_FILTER_DRIVER_OBJECT {
+ // The original driver
+ NETVSC_DRIVER_OBJECT InnerDriver;
+
+} RNDIS_FILTER_DRIVER_OBJECT;
+
+typedef enum {
+ RNDIS_DEV_UNINITIALIZED = 0,
+ RNDIS_DEV_INITIALIZING,
+ RNDIS_DEV_INITIALIZED,
+ RNDIS_DEV_DATAINITIALIZED,
+} RNDIS_DEVICE_STATE;
+
+typedef struct _RNDIS_DEVICE {
+ NETVSC_DEVICE *NetDevice;
+
+ RNDIS_DEVICE_STATE State;
+ UINT32 LinkStatus;
+ UINT32 NewRequestId;
+
+ HANDLE RequestLock;
+ LIST_ENTRY RequestList;
+
+ UCHAR HwMacAddr[HW_MACADDR_LEN];
+} RNDIS_DEVICE;
+
+
+typedef struct _RNDIS_REQUEST {
+ LIST_ENTRY ListEntry;
+ HANDLE WaitEvent;
+
+ // FIXME: We assumed a fixed size response here. If we do ever need to handle a bigger response,
+ // we can either define a max response message or add a response buffer variable above this field
+ RNDIS_MESSAGE ResponseMessage;
+
+ // Simplify allocation by having a netvsc packet inline
+ NETVSC_PACKET Packet;
+ PAGE_BUFFER Buffer;
+ // FIXME: We assumed a fixed size request here.
+ RNDIS_MESSAGE RequestMessage;
+} RNDIS_REQUEST;
+
+
+typedef struct _RNDIS_FILTER_PACKET {
+ void *CompletionContext;
+ PFN_ON_SENDRECVCOMPLETION OnCompletion;
+
+ RNDIS_MESSAGE Message;
+} RNDIS_FILTER_PACKET;
+
+//
+// Internal routines
+//
+static int
+RndisFilterSendRequest(
+ RNDIS_DEVICE *Device,
+ RNDIS_REQUEST *Request
+ );
+
+static void
+RndisFilterReceiveResponse(
+ RNDIS_DEVICE *Device,
+ RNDIS_MESSAGE *Response
+ );
+
+static void
+RndisFilterReceiveIndicateStatus(
+ RNDIS_DEVICE *Device,
+ RNDIS_MESSAGE *Response
+ );
+
+static void
+RndisFilterReceiveData(
+ RNDIS_DEVICE *Device,
+ RNDIS_MESSAGE *Message,
+ NETVSC_PACKET *Packet
+ );
+
+static int
+RndisFilterOnReceive(
+ DEVICE_OBJECT *Device,
+ NETVSC_PACKET *Packet
+ );
+
+static int
+RndisFilterQueryDevice(
+ RNDIS_DEVICE *Device,
+ UINT32 Oid,
+ VOID *Result,
+ UINT32 *ResultSize
+ );
+
+static inline int
+RndisFilterQueryDeviceMac(
+ RNDIS_DEVICE *Device
+ );
+
+static inline int
+RndisFilterQueryDeviceLinkStatus(
+ RNDIS_DEVICE *Device
+ );
+
+static int
+RndisFilterSetPacketFilter(
+ RNDIS_DEVICE *Device,
+ UINT32 NewFilter
+ );
+
+static int
+RndisFilterInitDevice(
+ RNDIS_DEVICE *Device
+ );
+
+static int
+RndisFilterOpenDevice(
+ RNDIS_DEVICE *Device
+ );
+
+static int
+RndisFilterCloseDevice(
+ RNDIS_DEVICE *Device
+ );
+
+static int
+RndisFilterOnDeviceAdd(
+ DEVICE_OBJECT *Device,
+ void *AdditionalInfo
+ );
+
+static int
+RndisFilterOnDeviceRemove(
+ DEVICE_OBJECT *Device
+ );
+
+static void
+RndisFilterOnCleanup(
+ DRIVER_OBJECT *Driver
+ );
+
+static int
+RndisFilterOnOpen(
+ DEVICE_OBJECT *Device
+ );
+
+static int
+RndisFilterOnClose(
+ DEVICE_OBJECT *Device
+ );
+
+static int
+RndisFilterOnSend(
+ DEVICE_OBJECT *Device,
+ NETVSC_PACKET *Packet
+ );
+
+static void
+RndisFilterOnSendCompletion(
+ void *Context
+ );
+
+static void
+RndisFilterOnSendRequestCompletion(
+ void *Context
+ );
+
+//
+// Global var
+//
+
+// The one and only
+RNDIS_FILTER_DRIVER_OBJECT gRndisFilter;
+
+static inline RNDIS_DEVICE* GetRndisDevice(void)
+{
+ RNDIS_DEVICE *device;
+
+ device = MemAllocZeroed(sizeof(RNDIS_DEVICE));
+ if (!device)
+ {
+ return NULL;
+ }
+
+ device->RequestLock = SpinlockCreate();
+ if (!device->RequestLock)
+ {
+ MemFree(device);
+ return NULL;
+ }
+
+ INITIALIZE_LIST_HEAD(&device->RequestList);
+
+ device->State = RNDIS_DEV_UNINITIALIZED;
+
+ return device;
+}
+
+static inline void PutRndisDevice(RNDIS_DEVICE *Device)
+{
+ SpinlockClose(Device->RequestLock);
+ MemFree(Device);
+}
+
+static inline RNDIS_REQUEST* GetRndisRequest(RNDIS_DEVICE *Device, UINT32 MessageType, UINT32 MessageLength)
+{
+ RNDIS_REQUEST *request;
+ RNDIS_MESSAGE *rndisMessage;
+ RNDIS_SET_REQUEST *set;
+
+ request = MemAllocZeroed(sizeof(RNDIS_REQUEST));
+ if (!request)
+ {
+ return NULL;
+ }
+
+ request->WaitEvent = WaitEventCreate();
+ if (!request->WaitEvent)
+ {
+ MemFree(request);
+ return NULL;
+ }
+
+ rndisMessage = &request->RequestMessage;
+ rndisMessage->NdisMessageType = MessageType;
+ rndisMessage->MessageLength = MessageLength;
+
+ // Set the request id. This field is always after the rndis header for request/response packet types so
+ // we just used the SetRequest as a template
+ set = &rndisMessage->Message.SetRequest;
+ set->RequestId = InterlockedIncrement((int*)&Device->NewRequestId);
+
+ // Add to the request list
+ SpinlockAcquire(Device->RequestLock);
+ INSERT_TAIL_LIST(&Device->RequestList, &request->ListEntry);
+ SpinlockRelease(Device->RequestLock);
+
+ return request;
+}
+
+static inline void PutRndisRequest(RNDIS_DEVICE *Device, RNDIS_REQUEST *Request)
+{
+ SpinlockAcquire(Device->RequestLock);
+ REMOVE_ENTRY_LIST(&Request->ListEntry);
+ SpinlockRelease(Device->RequestLock);
+
+ WaitEventClose(Request->WaitEvent);
+ MemFree(Request);
+}
+
+static inline void DumpRndisMessage(RNDIS_MESSAGE *RndisMessage)
+{
+ switch (RndisMessage->NdisMessageType)
+ {
+ case REMOTE_NDIS_PACKET_MSG:
+ DPRINT_DBG(NETVSC, "REMOTE_NDIS_PACKET_MSG (len %u, data offset %u data len %u, # oob %u, oob offset %u, oob len %u, pkt offset %u, pkt len %u",
+ RndisMessage->MessageLength,
+ RndisMessage->Message.Packet.DataOffset,
+ RndisMessage->Message.Packet.DataLength,
+ RndisMessage->Message.Packet.NumOOBDataElements,
+ RndisMessage->Message.Packet.OOBDataOffset,
+ RndisMessage->Message.Packet.OOBDataLength,
+ RndisMessage->Message.Packet.PerPacketInfoOffset,
+ RndisMessage->Message.Packet.PerPacketInfoLength);
+ break;
+
+ case REMOTE_NDIS_INITIALIZE_CMPLT:
+ DPRINT_DBG(NETVSC, "REMOTE_NDIS_INITIALIZE_CMPLT (len %u, id 0x%x, status 0x%x, major %d, minor %d, device flags %d, max xfer size 0x%x, max pkts %u, pkt aligned %u)",
+ RndisMessage->MessageLength,
+ RndisMessage->Message.InitializeComplete.RequestId,
+ RndisMessage->Message.InitializeComplete.Status,
+ RndisMessage->Message.InitializeComplete.MajorVersion,
+ RndisMessage->Message.InitializeComplete.MinorVersion,
+ RndisMessage->Message.InitializeComplete.DeviceFlags,
+ RndisMessage->Message.InitializeComplete.MaxTransferSize,
+ RndisMessage->Message.InitializeComplete.MaxPacketsPerMessage,
+ RndisMessage->Message.InitializeComplete.PacketAlignmentFactor);
+ break;
+
+ case REMOTE_NDIS_QUERY_CMPLT:
+ DPRINT_DBG(NETVSC, "REMOTE_NDIS_QUERY_CMPLT (len %u, id 0x%x, status 0x%x, buf len %u, buf offset %u)",
+ RndisMessage->MessageLength,
+ RndisMessage->Message.QueryComplete.RequestId,
+ RndisMessage->Message.QueryComplete.Status,
+ RndisMessage->Message.QueryComplete.InformationBufferLength,
+ RndisMessage->Message.QueryComplete.InformationBufferOffset);
+ break;
+
+ case REMOTE_NDIS_SET_CMPLT:
+ DPRINT_DBG(NETVSC, "REMOTE_NDIS_SET_CMPLT (len %u, id 0x%x, status 0x%x)",
+ RndisMessage->MessageLength,
+ RndisMessage->Message.SetComplete.RequestId,
+ RndisMessage->Message.SetComplete.Status);
+ break;
+
+ case REMOTE_NDIS_INDICATE_STATUS_MSG:
+ DPRINT_DBG(NETVSC, "REMOTE_NDIS_INDICATE_STATUS_MSG (len %u, status 0x%x, buf len %u, buf offset %u)",
+ RndisMessage->MessageLength,
+ RndisMessage->Message.IndicateStatus.Status,
+ RndisMessage->Message.IndicateStatus.StatusBufferLength,
+ RndisMessage->Message.IndicateStatus.StatusBufferOffset);
+ break;
+
+ default:
+ DPRINT_DBG(NETVSC, "0x%x (len %u)",
+ RndisMessage->NdisMessageType,
+ RndisMessage->MessageLength);
+ break;
+ }
+}
+
+static int
+RndisFilterSendRequest(
+ RNDIS_DEVICE *Device,
+ RNDIS_REQUEST *Request
+ )
+{
+ int ret=0;
+ NETVSC_PACKET *packet;
+
+ DPRINT_ENTER(NETVSC);
+
+ // Setup the packet to send it
+ packet = &Request->Packet;
+
+ packet->IsDataPacket = FALSE;
+ packet->TotalDataBufferLength = Request->RequestMessage.MessageLength;
+ packet->PageBufferCount = 1;
+
+ packet->PageBuffers[0].Pfn = GetPhysicalAddress(&Request->RequestMessage) >> PAGE_SHIFT;
+ packet->PageBuffers[0].Length = Request->RequestMessage.MessageLength;
+ packet->PageBuffers[0].Offset = (ULONG_PTR)&Request->RequestMessage & (PAGE_SIZE -1);
+
+ packet->Completion.Send.SendCompletionContext = Request;//packet;
+ packet->Completion.Send.OnSendCompletion = RndisFilterOnSendRequestCompletion;
+ packet->Completion.Send.SendCompletionTid = (ULONG_PTR)Device;
+
+ ret = gRndisFilter.InnerDriver.OnSend(Device->NetDevice->Device, packet);
+ DPRINT_EXIT(NETVSC);
+ return ret;
+}
+
+
+static void
+RndisFilterReceiveResponse(
+ RNDIS_DEVICE *Device,
+ RNDIS_MESSAGE *Response
+ )
+{
+ LIST_ENTRY *anchor;
+ LIST_ENTRY *curr;
+ RNDIS_REQUEST *request=NULL;
+ BOOL found=FALSE;
+
+ DPRINT_ENTER(NETVSC);
+
+ SpinlockAcquire(Device->RequestLock);
+ ITERATE_LIST_ENTRIES(anchor, curr, &Device->RequestList)
+ {
+ request = CONTAINING_RECORD(curr, RNDIS_REQUEST, ListEntry);
+
+ // All request/response message contains RequestId as the 1st field
+ if (request->RequestMessage.Message.InitializeRequest.RequestId == Response->Message.InitializeComplete.RequestId)
+ {
+ DPRINT_DBG(NETVSC, "found rndis request for this response (id 0x%x req type 0x%x res type 0x%x)",
+ request->RequestMessage.Message.InitializeRequest.RequestId, request->RequestMessage.NdisMessageType, Response->NdisMessageType);
+
+ found = TRUE;
+ break;
+ }
+ }
+ SpinlockRelease(Device->RequestLock);
+
+ if (found)
+ {
+ if (Response->MessageLength <= sizeof(RNDIS_MESSAGE))
+ {
+ memcpy(&request->ResponseMessage, Response, Response->MessageLength);
+ }
+ else
+ {
+ DPRINT_ERR(NETVSC, "rndis response buffer overflow detected (size %u max %u)", Response->MessageLength, sizeof(RNDIS_FILTER_PACKET));
+
+ if (Response->NdisMessageType == REMOTE_NDIS_RESET_CMPLT) // does not have a request id field
+ {
+ request->ResponseMessage.Message.ResetComplete.Status = STATUS_BUFFER_OVERFLOW;
+ }
+ else
+ {
+ request->ResponseMessage.Message.InitializeComplete.Status = STATUS_BUFFER_OVERFLOW;
+ }
+ }
+
+ WaitEventSet(request->WaitEvent);
+ }
+ else
+ {
+ DPRINT_ERR(NETVSC, "no rndis request found for this response (id 0x%x res type 0x%x)",
+ Response->Message.InitializeComplete.RequestId, Response->NdisMessageType);
+ }
+
+ DPRINT_EXIT(NETVSC);
+}
+
+static void
+RndisFilterReceiveIndicateStatus(
+ RNDIS_DEVICE *Device,
+ RNDIS_MESSAGE *Response
+ )
+{
+ RNDIS_INDICATE_STATUS *indicate = &Response->Message.IndicateStatus;
+
+ if (indicate->Status == RNDIS_STATUS_MEDIA_CONNECT)
+ {
+ gRndisFilter.InnerDriver.OnLinkStatusChanged(Device->NetDevice->Device, 1);
+ }
+ else if (indicate->Status == RNDIS_STATUS_MEDIA_DISCONNECT)
+ {
+ gRndisFilter.InnerDriver.OnLinkStatusChanged(Device->NetDevice->Device, 0);
+ }
+ else
+ {
+ // TODO:
+ }
+}
+
+static void
+RndisFilterReceiveData(
+ RNDIS_DEVICE *Device,
+ RNDIS_MESSAGE *Message,
+ NETVSC_PACKET *Packet
+ )
+{
+ RNDIS_PACKET *rndisPacket;
+ UINT32 dataOffset;
+
+ DPRINT_ENTER(NETVSC);
+
+ // empty ethernet frame ??
+ ASSERT(Packet->PageBuffers[0].Length > RNDIS_MESSAGE_SIZE(RNDIS_PACKET));
+
+ rndisPacket = &Message->Message.Packet;
+
+ // FIXME: Handle multiple rndis pkt msgs that maybe enclosed in this
+ // netvsc packet (ie TotalDataBufferLength != MessageLength)
+
+ // Remove the rndis header and pass it back up the stack
+ dataOffset = RNDIS_HEADER_SIZE + rndisPacket->DataOffset;
+
+ Packet->TotalDataBufferLength -= dataOffset;
+ Packet->PageBuffers[0].Offset += dataOffset;
+ Packet->PageBuffers[0].Length -= dataOffset;
+
+ Packet->IsDataPacket = TRUE;
+
+ gRndisFilter.InnerDriver.OnReceiveCallback(Device->NetDevice->Device, Packet);
+
+ DPRINT_EXIT(NETVSC);
+}
+
+static int
+RndisFilterOnReceive(
+ DEVICE_OBJECT *Device,
+ NETVSC_PACKET *Packet
+ )
+{
+ NETVSC_DEVICE *netDevice = (NETVSC_DEVICE*)Device->Extension;
+ RNDIS_DEVICE *rndisDevice;
+ RNDIS_MESSAGE rndisMessage;
+ RNDIS_MESSAGE *rndisHeader;
+
+ DPRINT_ENTER(NETVSC);
+
+ ASSERT(netDevice);
+ //Make sure the rndis device state is initialized
+ if (!netDevice->Extension)
+ {
+ DPRINT_ERR(NETVSC, "got rndis message but no rndis device...dropping this message!");
+ DPRINT_EXIT(NETVSC);
+ return -1;
+ }
+
+ rndisDevice = (RNDIS_DEVICE*)netDevice->Extension;
+ if (rndisDevice->State == RNDIS_DEV_UNINITIALIZED)
+ {
+ DPRINT_ERR(NETVSC, "got rndis message but rndis device uninitialized...dropping this message!");
+ DPRINT_EXIT(NETVSC);
+ return -1;
+ }
+
+ rndisHeader = (RNDIS_MESSAGE*)PageMapVirtualAddress(Packet->PageBuffers[0].Pfn);
+
+ rndisHeader = (void*)((ULONG_PTR)rndisHeader + Packet->PageBuffers[0].Offset);
+
+ // Make sure we got a valid rndis message
+ // FIXME: There seems to be a bug in set completion msg where its MessageLength is 16 bytes but
+ // the ByteCount field in the xfer page range shows 52 bytes
+#if 0
+ if ( Packet->TotalDataBufferLength != rndisHeader->MessageLength )
+ {
+ PageUnmapVirtualAddress((void*)(ULONG_PTR)rndisHeader - Packet->PageBuffers[0].Offset);
+
+ DPRINT_ERR(NETVSC, "invalid rndis message? (expected %u bytes got %u)...dropping this message!",
+ rndisHeader->MessageLength, Packet->TotalDataBufferLength);
+ DPRINT_EXIT(NETVSC);
+ return -1;
+ }
+#endif
+
+ if ((rndisHeader->NdisMessageType != REMOTE_NDIS_PACKET_MSG) && (rndisHeader->MessageLength > sizeof(RNDIS_MESSAGE)))
+ {
+ DPRINT_ERR(NETVSC, "incoming rndis message buffer overflow detected (got %u, max %u)...marking it an error!",
+ rndisHeader->MessageLength, sizeof(RNDIS_MESSAGE));
+ }
+
+ memcpy(&rndisMessage, rndisHeader, (rndisHeader->MessageLength > sizeof(RNDIS_MESSAGE))?sizeof(RNDIS_MESSAGE):rndisHeader->MessageLength);
+
+ PageUnmapVirtualAddress((void*)(ULONG_PTR)rndisHeader - Packet->PageBuffers[0].Offset);
+
+ DumpRndisMessage(&rndisMessage);
+
+ switch (rndisMessage.NdisMessageType)
+ {
+ // data msg
+ case REMOTE_NDIS_PACKET_MSG:
+ RndisFilterReceiveData(rndisDevice, &rndisMessage, Packet);
+ break;
+
+ // completion msgs
+ case REMOTE_NDIS_INITIALIZE_CMPLT:
+ case REMOTE_NDIS_QUERY_CMPLT:
+ case REMOTE_NDIS_SET_CMPLT:
+ //case REMOTE_NDIS_RESET_CMPLT:
+ //case REMOTE_NDIS_KEEPALIVE_CMPLT:
+ RndisFilterReceiveResponse(rndisDevice, &rndisMessage);
+ break;
+
+ // notification msgs
+ case REMOTE_NDIS_INDICATE_STATUS_MSG:
+ RndisFilterReceiveIndicateStatus(rndisDevice, &rndisMessage);
+ break;
+ default:
+ DPRINT_ERR(NETVSC, "unhandled rndis message (type %u len %u)", rndisMessage.NdisMessageType, rndisMessage.MessageLength);
+ break;
+ }
+
+ DPRINT_EXIT(NETVSC);
+ return 0;
+}
+
+
+static int
+RndisFilterQueryDevice(
+ RNDIS_DEVICE *Device,
+ UINT32 Oid,
+ VOID *Result,
+ UINT32 *ResultSize
+ )
+{
+ RNDIS_REQUEST *request;
+ UINT32 inresultSize = *ResultSize;
+ RNDIS_QUERY_REQUEST *query;
+ RNDIS_QUERY_COMPLETE *queryComplete;
+ int ret=0;
+
+ DPRINT_ENTER(NETVSC);
+
+ ASSERT(Result);
+
+ *ResultSize = 0;
+ request = GetRndisRequest(Device, REMOTE_NDIS_QUERY_MSG, RNDIS_MESSAGE_SIZE(RNDIS_QUERY_REQUEST));
+ if (!request)
+ {
+ ret = -1;
+ goto Cleanup;
+ }
+
+ // Setup the rndis query
+ query = &request->RequestMessage.Message.QueryRequest;
+ query->Oid = Oid;
+ query->InformationBufferOffset = sizeof(RNDIS_QUERY_REQUEST);
+ query->InformationBufferLength = 0;
+ query->DeviceVcHandle = 0;
+
+ ret = RndisFilterSendRequest(Device, request);
+ if (ret != 0)
+ {
+ goto Cleanup;
+ }
+
+ WaitEventWait(request->WaitEvent);
+
+ // Copy the response back
+ queryComplete = &request->ResponseMessage.Message.QueryComplete;
+
+ if (queryComplete->InformationBufferLength > inresultSize)
+ {
+ ret = -1;
+ goto Cleanup;
+ }
+
+ memcpy(Result,
+ (void*)((ULONG_PTR)queryComplete + queryComplete->InformationBufferOffset),
+ queryComplete->InformationBufferLength);
+
+ *ResultSize = queryComplete->InformationBufferLength;
+
+Cleanup:
+ if (request)
+ {
+ PutRndisRequest(Device, request);
+ }
+ DPRINT_EXIT(NETVSC);
+
+ return ret;
+}
+
+static inline int
+RndisFilterQueryDeviceMac(
+ RNDIS_DEVICE *Device
+ )
+{
+ UINT32 size=HW_MACADDR_LEN;
+
+ return RndisFilterQueryDevice(Device,
+ RNDIS_OID_802_3_PERMANENT_ADDRESS,
+ Device->HwMacAddr,
+ &size);
+}
+
+static inline int
+RndisFilterQueryDeviceLinkStatus(
+ RNDIS_DEVICE *Device
+ )
+{
+ UINT32 size=sizeof(UINT32);
+
+ return RndisFilterQueryDevice(Device,
+ RNDIS_OID_GEN_MEDIA_CONNECT_STATUS,
+ &Device->LinkStatus,
+ &size);
+}
+
+static int
+RndisFilterSetPacketFilter(
+ RNDIS_DEVICE *Device,
+ UINT32 NewFilter
+ )
+{
+ RNDIS_REQUEST *request;
+ RNDIS_SET_REQUEST *set;
+ RNDIS_SET_COMPLETE *setComplete;
+ UINT32 status;
+ int ret;
+
+ DPRINT_ENTER(NETVSC);
+
+ ASSERT(RNDIS_MESSAGE_SIZE(RNDIS_SET_REQUEST) + sizeof(UINT32) <= sizeof(RNDIS_MESSAGE));
+
+ request = GetRndisRequest(Device, REMOTE_NDIS_SET_MSG, RNDIS_MESSAGE_SIZE(RNDIS_SET_REQUEST) + sizeof(UINT32));
+ if (!request)
+ {
+ ret = -1;
+ goto Cleanup;
+ }
+
+ // Setup the rndis set
+ set = &request->RequestMessage.Message.SetRequest;
+ set->Oid = RNDIS_OID_GEN_CURRENT_PACKET_FILTER;
+ set->InformationBufferLength = sizeof(UINT32);
+ set->InformationBufferOffset = sizeof(RNDIS_SET_REQUEST);
+
+ memcpy((void*)(ULONG_PTR)set + sizeof(RNDIS_SET_REQUEST), &NewFilter, sizeof(UINT32));
+
+ ret = RndisFilterSendRequest(Device, request);
+ if (ret != 0)
+ {
+ goto Cleanup;
+ }
+
+ ret = WaitEventWaitEx(request->WaitEvent, 2000/*2sec*/);
+ if (!ret)
+ {
+ ret = -1;
+ DPRINT_ERR(NETVSC, "timeout before we got a set response...");
+ // We cant deallocate the request since we may still receive a send completion for it.
+ goto Exit;
+ }
+ else
+ {
+ if (ret > 0)
+ {
+ ret = 0;
+ }
+ setComplete = &request->ResponseMessage.Message.SetComplete;
+ status = setComplete->Status;
+ }
+
+Cleanup:
+ if (request)
+ {
+ PutRndisRequest(Device, request);
+ }
+Exit:
+ DPRINT_EXIT(NETVSC);
+
+ return ret;
+}
+
+int
+RndisFilterInit(
+ NETVSC_DRIVER_OBJECT *Driver
+ )
+{
+ DPRINT_ENTER(NETVSC);
+
+ DPRINT_DBG(NETVSC, "sizeof(RNDIS_FILTER_PACKET) == %d", sizeof(RNDIS_FILTER_PACKET));
+
+ Driver->RequestExtSize = sizeof(RNDIS_FILTER_PACKET);
+ Driver->AdditionalRequestPageBufferCount = 1; // For rndis header
+
+ //Driver->Context = rndisDriver;
+
+ memset(&gRndisFilter, 0, sizeof(RNDIS_FILTER_DRIVER_OBJECT));
+
+ /*rndisDriver->Driver = Driver;
+
+ ASSERT(Driver->OnLinkStatusChanged);
+ rndisDriver->OnLinkStatusChanged = Driver->OnLinkStatusChanged;*/
+
+ // Save the original dispatch handlers before we override it
+ gRndisFilter.InnerDriver.Base.OnDeviceAdd = Driver->Base.OnDeviceAdd;
+ gRndisFilter.InnerDriver.Base.OnDeviceRemove = Driver->Base.OnDeviceRemove;
+ gRndisFilter.InnerDriver.Base.OnCleanup = Driver->Base.OnCleanup;
+
+ ASSERT(Driver->OnSend);
+ ASSERT(Driver->OnReceiveCallback);
+ gRndisFilter.InnerDriver.OnSend = Driver->OnSend;
+ gRndisFilter.InnerDriver.OnReceiveCallback = Driver->OnReceiveCallback;
+ gRndisFilter.InnerDriver.OnLinkStatusChanged = Driver->OnLinkStatusChanged;
+
+ // Override
+ Driver->Base.OnDeviceAdd = RndisFilterOnDeviceAdd;
+ Driver->Base.OnDeviceRemove = RndisFilterOnDeviceRemove;
+ Driver->Base.OnCleanup = RndisFilterOnCleanup;
+ Driver->OnSend = RndisFilterOnSend;
+ Driver->OnOpen = RndisFilterOnOpen;
+ Driver->OnClose = RndisFilterOnClose;
+ //Driver->QueryLinkStatus = RndisFilterQueryDeviceLinkStatus;
+ Driver->OnReceiveCallback = RndisFilterOnReceive;
+
+ DPRINT_EXIT(NETVSC);
+
+ return 0;
+}
+
+static int
+RndisFilterInitDevice(
+ RNDIS_DEVICE *Device
+ )
+{
+ RNDIS_REQUEST *request;
+ RNDIS_INITIALIZE_REQUEST *init;
+ RNDIS_INITIALIZE_COMPLETE *initComplete;
+ UINT32 status;
+ int ret;
+
+ DPRINT_ENTER(NETVSC);
+
+ request = GetRndisRequest(Device, REMOTE_NDIS_INITIALIZE_MSG, RNDIS_MESSAGE_SIZE(RNDIS_INITIALIZE_REQUEST));
+ if (!request)
+ {
+ ret = -1;
+ goto Cleanup;
+ }
+
+ // Setup the rndis set
+ init = &request->RequestMessage.Message.InitializeRequest;
+ init->MajorVersion = RNDIS_MAJOR_VERSION;
+ init->MinorVersion = RNDIS_MINOR_VERSION;
+ init->MaxTransferSize = 2048; // FIXME: Use 1536 - rounded ethernet frame size
+
+ Device->State = RNDIS_DEV_INITIALIZING;
+
+ ret = RndisFilterSendRequest(Device, request);
+ if (ret != 0)
+ {
+ Device->State = RNDIS_DEV_UNINITIALIZED;
+ goto Cleanup;
+ }
+
+ WaitEventWait(request->WaitEvent);
+
+ initComplete = &request->ResponseMessage.Message.InitializeComplete;
+ status = initComplete->Status;
+ if (status == RNDIS_STATUS_SUCCESS)
+ {
+ Device->State = RNDIS_DEV_INITIALIZED;
+ ret = 0;
+ }
+ else
+ {
+ Device->State = RNDIS_DEV_UNINITIALIZED;
+ ret = -1;
+ }
+
+Cleanup:
+ if (request)
+ {
+ PutRndisRequest(Device, request);
+ }
+ DPRINT_EXIT(NETVSC);
+
+ return ret;
+}
+
+static void
+RndisFilterHaltDevice(
+ RNDIS_DEVICE *Device
+ )
+{
+ RNDIS_REQUEST *request;
+ RNDIS_HALT_REQUEST *halt;
+
+ DPRINT_ENTER(NETVSC);
+
+ // Attempt to do a rndis device halt
+ request = GetRndisRequest(Device, REMOTE_NDIS_HALT_MSG, RNDIS_MESSAGE_SIZE(RNDIS_HALT_REQUEST));
+ if (!request)
+ {
+ goto Cleanup;
+ }
+
+ // Setup the rndis set
+ halt = &request->RequestMessage.Message.HaltRequest;
+ halt->RequestId = InterlockedIncrement((int*)&Device->NewRequestId);
+
+ // Ignore return since this msg is optional.
+ RndisFilterSendRequest(Device, request);
+
+ Device->State = RNDIS_DEV_UNINITIALIZED;
+
+Cleanup:
+ if (request)
+ {
+ PutRndisRequest(Device, request);
+ }
+ DPRINT_EXIT(NETVSC);
+ return;
+}
+
+
+static int
+RndisFilterOpenDevice(
+ RNDIS_DEVICE *Device
+ )
+{
+ int ret=0;
+
+ DPRINT_ENTER(NETVSC);
+
+ if (Device->State != RNDIS_DEV_INITIALIZED)
+ return 0;
+
+ ret = RndisFilterSetPacketFilter(Device, NDIS_PACKET_TYPE_BROADCAST|NDIS_PACKET_TYPE_DIRECTED);
+ if (ret == 0)
+ {
+ Device->State = RNDIS_DEV_DATAINITIALIZED;
+ }
+
+ DPRINT_EXIT(NETVSC);
+ return ret;
+}
+
+static int
+RndisFilterCloseDevice(
+ RNDIS_DEVICE *Device
+ )
+{
+ int ret;
+
+ DPRINT_ENTER(NETVSC);
+
+ if (Device->State != RNDIS_DEV_DATAINITIALIZED)
+ return 0;
+
+ ret = RndisFilterSetPacketFilter(Device, 0);
+ if (ret == 0)
+ {
+ Device->State = RNDIS_DEV_INITIALIZED;
+ }
+
+ DPRINT_EXIT(NETVSC);
+
+ return ret;
+}
+
+
+int
+RndisFilterOnDeviceAdd(
+ DEVICE_OBJECT *Device,
+ void *AdditionalInfo
+ )
+{
+ int ret;
+ NETVSC_DEVICE *netDevice;
+ RNDIS_DEVICE *rndisDevice;
+ NETVSC_DEVICE_INFO *deviceInfo = (NETVSC_DEVICE_INFO*)AdditionalInfo;
+
+ DPRINT_ENTER(NETVSC);
+
+ //rndisDevice = MemAlloc(sizeof(RNDIS_DEVICE));
+ rndisDevice = GetRndisDevice();
+ if (!rndisDevice)
+ {
+ DPRINT_EXIT(NETVSC);
+ return -1;
+ }
+
+ DPRINT_DBG(NETVSC, "rndis device object allocated - %p", rndisDevice);
+
+ // Let the inner driver handle this first to create the netvsc channel
+ // NOTE! Once the channel is created, we may get a receive callback
+ // (RndisFilterOnReceive()) before this call is completed
+ ret = gRndisFilter.InnerDriver.Base.OnDeviceAdd(Device, AdditionalInfo);
+ if (ret != 0)
+ {
+ PutRndisDevice(rndisDevice);
+ DPRINT_EXIT(NETVSC);
+ return ret;
+ }
+
+ //
+ // Initialize the rndis device
+ //
+ netDevice = (NETVSC_DEVICE*)Device->Extension;
+ ASSERT(netDevice);
+ ASSERT(netDevice->Device);
+
+ netDevice->Extension = rndisDevice;
+ rndisDevice->NetDevice = netDevice;
+
+ // Send the rndis initialization message
+ ret = RndisFilterInitDevice(rndisDevice);
+ if (ret != 0)
+ {
+ // TODO: If rndis init failed, we will need to shut down the channel
+ }
+
+ // Get the mac address
+ ret = RndisFilterQueryDeviceMac(rndisDevice);
+ if (ret != 0)
+ {
+ // TODO: shutdown rndis device and the channel
+ }
+
+ DPRINT_INFO(NETVSC, "Device 0x%p mac addr %02x%02x%02x%02x%02x%02x",
+ rndisDevice,
+ rndisDevice->HwMacAddr[0],
+ rndisDevice->HwMacAddr[1],
+ rndisDevice->HwMacAddr[2],
+ rndisDevice->HwMacAddr[3],
+ rndisDevice->HwMacAddr[4],
+ rndisDevice->HwMacAddr[5]);
+
+ memcpy(deviceInfo->MacAddr, rndisDevice->HwMacAddr, HW_MACADDR_LEN);
+
+ RndisFilterQueryDeviceLinkStatus(rndisDevice);
+
+ deviceInfo->LinkState = rndisDevice->LinkStatus;
+ DPRINT_INFO(NETVSC, "Device 0x%p link state %s", rndisDevice, ((deviceInfo->LinkState)?("down"):("up")));
+
+ DPRINT_EXIT(NETVSC);
+
+ return ret;
+}
+
+
+static int
+RndisFilterOnDeviceRemove(
+ DEVICE_OBJECT *Device
+ )
+{
+ NETVSC_DEVICE *netDevice = (NETVSC_DEVICE*)Device->Extension;
+ RNDIS_DEVICE *rndisDevice = (RNDIS_DEVICE*)netDevice->Extension;
+
+ DPRINT_ENTER(NETVSC);
+
+ // Halt and release the rndis device
+ RndisFilterHaltDevice(rndisDevice);
+
+ PutRndisDevice(rndisDevice);
+ netDevice->Extension = NULL;
+
+ // Pass control to inner driver to remove the device
+ gRndisFilter.InnerDriver.Base.OnDeviceRemove(Device);
+
+ DPRINT_EXIT(NETVSC);
+
+ return 0;
+}
+
+
+static void
+RndisFilterOnCleanup(
+ DRIVER_OBJECT *Driver
+ )
+{
+ DPRINT_ENTER(NETVSC);
+
+ DPRINT_EXIT(NETVSC);
+}
+
+static int
+RndisFilterOnOpen(
+ DEVICE_OBJECT *Device
+ )
+{
+ int ret;
+ NETVSC_DEVICE *netDevice = (NETVSC_DEVICE*)Device->Extension;
+
+ DPRINT_ENTER(NETVSC);
+
+ ASSERT(netDevice);
+ ret = RndisFilterOpenDevice((RNDIS_DEVICE*)netDevice->Extension);
+
+ DPRINT_EXIT(NETVSC);
+
+ return ret;
+}
+
+static int
+RndisFilterOnClose(
+ DEVICE_OBJECT *Device
+ )
+{
+ int ret;
+ NETVSC_DEVICE *netDevice = (NETVSC_DEVICE*)Device->Extension;
+
+ DPRINT_ENTER(NETVSC);
+
+ ASSERT(netDevice);
+ ret = RndisFilterCloseDevice((RNDIS_DEVICE*)netDevice->Extension);
+
+ DPRINT_EXIT(NETVSC);
+
+ return ret;
+}
+
+
+static int
+RndisFilterOnSend(
+ DEVICE_OBJECT *Device,
+ NETVSC_PACKET *Packet
+ )
+{
+ int ret=0;
+ RNDIS_FILTER_PACKET *filterPacket;
+ RNDIS_MESSAGE *rndisMessage;
+ RNDIS_PACKET *rndisPacket;
+ UINT32 rndisMessageSize;
+
+ DPRINT_ENTER(NETVSC);
+
+ // Add the rndis header
+ filterPacket = (RNDIS_FILTER_PACKET*)Packet->Extension;
+ ASSERT(filterPacket);
+
+ memset(filterPacket, 0, sizeof(RNDIS_FILTER_PACKET));
+
+ rndisMessage = &filterPacket->Message;
+ rndisMessageSize = RNDIS_MESSAGE_SIZE(RNDIS_PACKET);
+
+ rndisMessage->NdisMessageType = REMOTE_NDIS_PACKET_MSG;
+ rndisMessage->MessageLength = Packet->TotalDataBufferLength + rndisMessageSize;
+
+ rndisPacket = &rndisMessage->Message.Packet;
+ rndisPacket->DataOffset = sizeof(RNDIS_PACKET);
+ rndisPacket->DataLength = Packet->TotalDataBufferLength;
+
+ Packet->IsDataPacket = TRUE;
+ Packet->PageBuffers[0].Pfn = GetPhysicalAddress(rndisMessage) >> PAGE_SHIFT;
+ Packet->PageBuffers[0].Offset = (ULONG_PTR)rndisMessage & (PAGE_SIZE-1);
+ Packet->PageBuffers[0].Length = rndisMessageSize;
+
+ // Save the packet send completion and context
+ filterPacket->OnCompletion = Packet->Completion.Send.OnSendCompletion;
+ filterPacket->CompletionContext = Packet->Completion.Send.SendCompletionContext;
+
+ // Use ours
+ Packet->Completion.Send.OnSendCompletion = RndisFilterOnSendCompletion;
+ Packet->Completion.Send.SendCompletionContext = filterPacket;
+
+ ret = gRndisFilter.InnerDriver.OnSend(Device, Packet);
+ if (ret != 0)
+ {
+ // Reset the completion to originals to allow retries from above
+ Packet->Completion.Send.OnSendCompletion = filterPacket->OnCompletion;
+ Packet->Completion.Send.SendCompletionContext = filterPacket->CompletionContext;
+ }
+
+ DPRINT_EXIT(NETVSC);
+
+ return ret;
+}
+
+static void
+RndisFilterOnSendCompletion(
+ void *Context)
+{
+ RNDIS_FILTER_PACKET *filterPacket = (RNDIS_FILTER_PACKET *)Context;
+
+ DPRINT_ENTER(NETVSC);
+
+ // Pass it back to the original handler
+ filterPacket->OnCompletion(filterPacket->CompletionContext);
+
+ DPRINT_EXIT(NETVSC);
+}
+
+
+static void
+RndisFilterOnSendRequestCompletion(
+ void *Context
+ )
+{
+ DPRINT_ENTER(NETVSC);
+
+ // Noop
+ DPRINT_EXIT(NETVSC);
+}