1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "remoting/host/win/wts_terminal_monitor.h"
11 #include "base/basictypes.h"
12 #include "base/files/file_path.h"
13 #include "base/lazy_instance.h"
14 #include "base/native_library.h"
15 #include "base/scoped_native_library.h"
16 #include "base/utf_string_conversions.h"
17 #include "net/base/ip_endpoint.h"
21 // Used to query the endpoint of an attached RDP client.
22 const WINSTATIONINFOCLASS kWinStationRemoteAddress
=
23 static_cast<WINSTATIONINFOCLASS
>(29);
// The WinStationRemoteAddress information class returns the following
// structure. Note that its layout differs from sockaddr_in/sockaddr_in6: in
// particular, both the |ipv4| and |ipv6| structures are 4-byte aligned, so
// there are 2 bytes of padding after |sin_family|.
29 struct RemoteAddress
{
30 unsigned short sin_family
;
46 // Loads winsta.dll dynamically and resolves the address of
47 // the winsta!WinStationQueryInformationW() function.
53 // Returns the address and port of the RDP client attached to |session_id|.
54 bool GetRemoteAddress(uint32 session_id
, RemoteAddress
* address
);
57 // Handle of dynamically loaded winsta.dll.
58 base::ScopedNativeLibrary winsta_
;
60 // Points to winsta!WinStationQueryInformationW().
61 PWINSTATIONQUERYINFORMATIONW win_station_query_information_
;
63 DISALLOW_COPY_AND_ASSIGN(WinstaLoader
);
66 static base::LazyInstance
<WinstaLoader
> g_winsta
= LAZY_INSTANCE_INITIALIZER
;
68 WinstaLoader::WinstaLoader() :
69 winsta_(base::FilePath(base::GetNativeLibraryName(UTF8ToUTF16("winsta")))) {
71 // Resolve the function pointer.
72 win_station_query_information_
=
73 static_cast<PWINSTATIONQUERYINFORMATIONW
>(
74 winsta_
.GetFunctionPointer("WinStationQueryInformationW"));
77 WinstaLoader::~WinstaLoader() {
80 bool WinstaLoader::GetRemoteAddress(uint32 session_id
, RemoteAddress
* address
) {
82 return win_station_query_information_(WTS_CURRENT_SERVER_HANDLE
,
84 kWinStationRemoteAddress
,
94 // Session id that does not represent any session.
95 const uint32 kInvalidSessionId
= 0xffffffffu
;
97 WtsTerminalMonitor::~WtsTerminalMonitor() {
101 bool WtsTerminalMonitor::GetEndpointForSessionId(uint32 session_id
,
102 net::IPEndPoint
* endpoint
) {
103 // Fast path for the case when |session_id| is currently attached to
104 // the physical console.
105 if (session_id
== WTSGetActiveConsoleSessionId()) {
106 *endpoint
= net::IPEndPoint();
110 RemoteAddress address
;
111 // WinStationQueryInformationW() fails if no RDP client is attached to
113 if (!g_winsta
.Get().GetRemoteAddress(session_id
, &address
))
116 // Convert the RemoteAddress structure into sockaddr_in/sockaddr_in6.
117 switch (address
.sin_family
) {
119 sockaddr_in ipv4
= { 0 };
120 ipv4
.sin_family
= AF_INET
;
121 ipv4
.sin_port
= address
.ipv4
.sin_port
;
122 ipv4
.sin_addr
.S_un
.S_addr
= address
.ipv4
.in_addr
;
123 return endpoint
->FromSockAddr(
124 reinterpret_cast<struct sockaddr
*>(&ipv4
), sizeof(ipv4
));
128 sockaddr_in6 ipv6
= { 0 };
129 ipv6
.sin6_family
= AF_INET6
;
130 ipv6
.sin6_port
= address
.ipv6
.sin6_port
;
131 ipv6
.sin6_flowinfo
= address
.ipv6
.sin6_flowinfo
;
132 memcpy(&ipv6
.sin6_addr
, address
.ipv6
.sin6_addr
, sizeof(ipv6
.sin6_addr
));
133 ipv6
.sin6_scope_id
= address
.ipv6
.sin6_scope_id
;
134 return endpoint
->FromSockAddr(
135 reinterpret_cast<struct sockaddr
*>(&ipv6
), sizeof(ipv6
));
144 uint32
WtsTerminalMonitor::GetSessionIdForEndpoint(
145 const net::IPEndPoint
& client_endpoint
) {
146 // Use the fast path if the caller wants to get id of the session attached to
147 // the physical console.
148 if (client_endpoint
== net::IPEndPoint())
149 return WTSGetActiveConsoleSessionId();
151 // Enumerate all sessions and try to match the client endpoint.
152 WTS_SESSION_INFO
* session_info
;
153 DWORD session_info_count
;
154 if (!WTSEnumerateSessions(WTS_CURRENT_SERVER_HANDLE
, 0, 1, &session_info
,
155 &session_info_count
)) {
156 LOG_GETLASTERROR(ERROR
) << "Failed to enumerate all sessions";
157 return kInvalidSessionId
;
159 for (DWORD i
= 0; i
< session_info_count
; ++i
) {
160 net::IPEndPoint endpoint
;
161 if (GetEndpointForSessionId(session_info
[i
].SessionId
, &endpoint
) &&
162 endpoint
== client_endpoint
) {
163 uint32 session_id
= session_info
[i
].SessionId
;
164 WTSFreeMemory(session_info
);
169 // |client_endpoint| is not associated with any session.
170 WTSFreeMemory(session_info
);
171 return kInvalidSessionId
;
174 WtsTerminalMonitor::WtsTerminalMonitor() {
177 } // namespace remoting