Dispose of NPObject in npobject_identity_test.cc
[chromium-blink-merge.git] / chrome_frame / http_negotiate.cc
blob7ca3cead1b087ab92f729dcff720280534e8248c
1 // Copyright (c) 2011 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "chrome_frame/http_negotiate.h"
7 #include <atlbase.h>
8 #include <atlcom.h>
9 #include <htiframe.h>
11 #include "base/logging.h"
12 #include "base/memory/scoped_ptr.h"
13 #include "base/string_util.h"
14 #include "base/stringprintf.h"
15 #include "base/utf_string_conversions.h"
16 #include "chrome_frame/bho.h"
17 #include "chrome_frame/exception_barrier.h"
18 #include "chrome_frame/html_utils.h"
19 #include "chrome_frame/urlmon_moniker.h"
20 #include "chrome_frame/urlmon_url_request.h"
21 #include "chrome_frame/utils.h"
22 #include "chrome_frame/vtable_patch_manager.h"
23 #include "net/http/http_response_headers.h"
24 #include "net/http/http_util.h"
26 bool HttpNegotiatePatch::modify_user_agent_ = true;
27 const char kUACompatibleHttpHeader[] = "x-ua-compatible";
28 const char kLowerCaseUserAgent[] = "user-agent";
30 // From the latest urlmon.h. Symbol name prepended with LOCAL_ to
31 // avoid conflict (and therefore build errors) for those building with
32 // a newer Windows SDK.
33 // TODO(robertshield): Remove this once we update our SDK version.
34 const int LOCAL_BINDSTATUS_SERVER_MIMETYPEAVAILABLE = 54;
36 static const int kHttpNegotiateBeginningTransactionIndex = 3;
38 BEGIN_VTABLE_PATCHES(IHttpNegotiate)
39 VTABLE_PATCH_ENTRY(kHttpNegotiateBeginningTransactionIndex,
40 HttpNegotiatePatch::BeginningTransaction)
41 END_VTABLE_PATCHES()
43 namespace {
45 class SimpleBindStatusCallback : public CComObjectRootEx<CComSingleThreadModel>,
46 public IBindStatusCallback {
47 public:
48 BEGIN_COM_MAP(SimpleBindStatusCallback)
49 COM_INTERFACE_ENTRY(IBindStatusCallback)
50 END_COM_MAP()
52 // IBindStatusCallback implementation
53 STDMETHOD(OnStartBinding)(DWORD reserved, IBinding* binding) {
54 return E_NOTIMPL;
57 STDMETHOD(GetPriority)(LONG* priority) {
58 return E_NOTIMPL;
60 STDMETHOD(OnLowResource)(DWORD reserved) {
61 return E_NOTIMPL;
64 STDMETHOD(OnProgress)(ULONG progress, ULONG max_progress,
65 ULONG status_code, LPCWSTR status_text) {
66 return E_NOTIMPL;
68 STDMETHOD(OnStopBinding)(HRESULT result, LPCWSTR error) {
69 return E_NOTIMPL;
72 STDMETHOD(GetBindInfo)(DWORD* bind_flags, BINDINFO* bind_info) {
73 return E_NOTIMPL;
76 STDMETHOD(OnDataAvailable)(DWORD flags, DWORD size, FORMATETC* formatetc,
77 STGMEDIUM* storage) {
78 return E_NOTIMPL;
80 STDMETHOD(OnObjectAvailable)(REFIID iid, IUnknown* object) {
81 return E_NOTIMPL;
84 } // end namespace
86 std::string AppendCFUserAgentString(LPCWSTR headers,
87 LPCWSTR additional_headers) {
88 using net::HttpUtil;
90 std::string ascii_headers;
91 if (additional_headers) {
92 ascii_headers = WideToASCII(additional_headers);
95 // Extract "User-Agent" from |additional_headers| or |headers|.
96 HttpUtil::HeadersIterator headers_iterator(ascii_headers.begin(),
97 ascii_headers.end(), "\r\n");
98 std::string user_agent_value;
99 if (headers_iterator.AdvanceTo(kLowerCaseUserAgent)) {
100 user_agent_value = headers_iterator.values();
101 } else if (headers != NULL) {
102 // See if there's a user-agent header specified in the original headers.
103 std::string original_headers(WideToASCII(headers));
104 HttpUtil::HeadersIterator original_it(original_headers.begin(),
105 original_headers.end(), "\r\n");
106 if (original_it.AdvanceTo(kLowerCaseUserAgent))
107 user_agent_value = original_it.values();
110 // Use the default "User-Agent" if none was provided.
111 if (user_agent_value.empty())
112 user_agent_value = http_utils::GetDefaultUserAgent();
114 // Now add chromeframe to it.
115 user_agent_value = http_utils::AddChromeFrameToUserAgentValue(
116 user_agent_value);
118 // Build new headers, skip the existing user agent value from
119 // existing headers.
120 std::string new_headers;
121 headers_iterator.Reset();
122 while (headers_iterator.GetNext()) {
123 std::string name(headers_iterator.name());
124 if (!LowerCaseEqualsASCII(name, kLowerCaseUserAgent)) {
125 new_headers += name + ": " + headers_iterator.values() + "\r\n";
129 new_headers += "User-Agent: " + user_agent_value;
130 new_headers += "\r\n";
131 return new_headers;
134 std::string ReplaceOrAddUserAgent(LPCWSTR headers,
135 const std::string& user_agent_value) {
136 using net::HttpUtil;
138 std::string new_headers;
139 if (headers) {
140 std::string ascii_headers(WideToASCII(headers));
142 // Extract "User-Agent" from the headers.
143 HttpUtil::HeadersIterator headers_iterator(ascii_headers.begin(),
144 ascii_headers.end(), "\r\n");
146 // Build new headers, skip the existing user agent value from
147 // existing headers.
148 while (headers_iterator.GetNext()) {
149 std::string name(headers_iterator.name());
150 if (!LowerCaseEqualsASCII(name, kLowerCaseUserAgent)) {
151 new_headers += name + ": " + headers_iterator.values() + "\r\n";
155 new_headers += "User-Agent: " + user_agent_value;
156 new_headers += "\r\n";
157 return new_headers;
160 HttpNegotiatePatch::HttpNegotiatePatch() {
163 HttpNegotiatePatch::~HttpNegotiatePatch() {
166 // static
167 bool HttpNegotiatePatch::Initialize() {
168 if (IS_PATCHED(IHttpNegotiate)) {
169 DLOG(WARNING) << __FUNCTION__ << " called more than once.";
170 return true;
172 // Use our SimpleBindStatusCallback class as we need a temporary object that
173 // implements IBindStatusCallback.
174 CComObjectStackEx<SimpleBindStatusCallback> request;
175 base::win::ScopedComPtr<IBindCtx> bind_ctx;
176 HRESULT hr = CreateAsyncBindCtx(0, &request, NULL, bind_ctx.Receive());
177 DCHECK(SUCCEEDED(hr)) << "CreateAsyncBindCtx";
178 if (bind_ctx) {
179 base::win::ScopedComPtr<IUnknown> bscb_holder;
180 bind_ctx->GetObjectParam(L"_BSCB_Holder_", bscb_holder.Receive());
181 if (bscb_holder) {
182 hr = PatchHttpNegotiate(bscb_holder);
183 } else {
184 NOTREACHED() << "Failed to get _BSCB_Holder_";
185 hr = E_UNEXPECTED;
187 bind_ctx.Release();
190 return SUCCEEDED(hr);
193 // static
194 void HttpNegotiatePatch::Uninitialize() {
195 vtable_patch::UnpatchInterfaceMethods(IHttpNegotiate_PatchInfo);
198 // static
199 HRESULT HttpNegotiatePatch::PatchHttpNegotiate(IUnknown* to_patch) {
200 DCHECK(to_patch);
201 DCHECK_IS_NOT_PATCHED(IHttpNegotiate);
203 base::win::ScopedComPtr<IHttpNegotiate> http;
204 HRESULT hr = http.QueryFrom(to_patch);
205 if (FAILED(hr)) {
206 hr = DoQueryService(IID_IHttpNegotiate, to_patch, http.Receive());
209 if (http) {
210 hr = vtable_patch::PatchInterfaceMethods(http, IHttpNegotiate_PatchInfo);
211 DLOG_IF(ERROR, FAILED(hr))
212 << base::StringPrintf("HttpNegotiate patch failed 0x%08X", hr);
213 } else {
214 DLOG(WARNING)
215 << base::StringPrintf("IHttpNegotiate not supported 0x%08X", hr);
217 return hr;
220 // static
221 HRESULT HttpNegotiatePatch::BeginningTransaction(
222 IHttpNegotiate_BeginningTransaction_Fn original, IHttpNegotiate* me,
223 LPCWSTR url, LPCWSTR headers, DWORD reserved, LPWSTR* additional_headers) {
224 DVLOG(1) << __FUNCTION__ << " " << url << " headers:\n" << headers;
226 HRESULT hr = original(me, url, headers, reserved, additional_headers);
228 if (FAILED(hr)) {
229 DLOG(WARNING) << __FUNCTION__ << " Delegate returned an error";
230 return hr;
232 if (modify_user_agent_) {
233 std::string updated_headers;
234 if (IsGcfDefaultRenderer() &&
235 RendererTypeForUrl(url) == RENDERER_TYPE_CHROME_DEFAULT_RENDERER) {
236 // Replace the user-agent header with Chrome's.
237 updated_headers = ReplaceOrAddUserAgent(*additional_headers,
238 http_utils::GetChromeUserAgent());
239 } else {
240 updated_headers = AppendCFUserAgentString(headers, *additional_headers);
242 *additional_headers = reinterpret_cast<wchar_t*>(::CoTaskMemRealloc(
243 *additional_headers,
244 (updated_headers.length() + 1) * sizeof(wchar_t)));
245 lstrcpyW(*additional_headers, ASCIIToWide(updated_headers).c_str());
246 } else {
247 // TODO(erikwright): Remove the user agent if it is present (i.e., because
248 // of PostPlatform setting in the registry).
250 return S_OK;