Commit 6ad98405 by olly Committed by Oliver Woodman

Avoid providing invalid responses to MediaDrm

MediaDrm.provideXResponse methods only accept the response
corresponding to the most recent MediaDrm.getXRequest call.
Previously, our code allowed the following incorrect call
sequence:

a = getKeyRequest
b = getKeyRequest
provideKeyResponse(responseFor(a));

This would occur in the edge case of a second key request
being triggered whilst the first was still in flight. The
provideKeyResponse call would then fail.

This change fixes the problem by treating responseFor(a)
as stale. Note that a slightly better fix would be to
defer calling getKeyRequest the second time until after
processing the response corresponding to the first one,
however this is significantly harder to implement, and is
probably not worth it for what should be an edge case.

-------------
Created by MOE: https://github.com/google/moe
MOE_MIGRATED_REVID=203481685
parent a50d31a7
...@@ -97,6 +97,9 @@ import java.util.UUID; ...@@ -97,6 +97,9 @@ import java.util.UUID;
private byte[] sessionId; private byte[] sessionId;
private byte[] offlineLicenseKeySetId; private byte[] offlineLicenseKeySetId;
private Object currentKeyRequest;
private Object currentProvisionRequest;
/** /**
* Instantiates a new DRM session. * Instantiates a new DRM session.
* *
...@@ -171,6 +174,8 @@ import java.util.UUID; ...@@ -171,6 +174,8 @@ import java.util.UUID;
requestHandlerThread = null; requestHandlerThread = null;
mediaCrypto = null; mediaCrypto = null;
lastException = null; lastException = null;
currentKeyRequest = null;
currentProvisionRequest = null;
if (sessionId != null) { if (sessionId != null) {
mediaDrm.closeSession(sessionId); mediaDrm.closeSession(sessionId);
sessionId = null; sessionId = null;
...@@ -215,8 +220,8 @@ import java.util.UUID; ...@@ -215,8 +220,8 @@ import java.util.UUID;
// Provisioning implementation. // Provisioning implementation.
public void provision() { public void provision() {
ProvisionRequest request = mediaDrm.getProvisionRequest(); currentProvisionRequest = mediaDrm.getProvisionRequest();
postRequestHandler.obtainMessage(MSG_PROVISION, request, true).sendToTarget(); postRequestHandler.post(MSG_PROVISION, currentProvisionRequest, /* allowRetry= */ true);
} }
public void onProvisionCompleted() { public void onProvisionCompleted() {
...@@ -289,11 +294,12 @@ import java.util.UUID; ...@@ -289,11 +294,12 @@ import java.util.UUID;
return false; return false;
} }
private void onProvisionResponse(Object response) { private void onProvisionResponse(Object request, Object response) {
if (state != STATE_OPENING && !isOpen()) { if (request != currentProvisionRequest || (state != STATE_OPENING && !isOpen())) {
// This event is stale. // This event is stale.
return; return;
} }
currentProvisionRequest = null;
if (response instanceof Exception) { if (response instanceof Exception) {
provisioningManager.onProvisionError((Exception) response); provisioningManager.onProvisionError((Exception) response);
...@@ -383,20 +389,21 @@ import java.util.UUID; ...@@ -383,20 +389,21 @@ import java.util.UUID;
licenseServerUrl = schemeData.licenseServerUrl; licenseServerUrl = schemeData.licenseServerUrl;
} }
try { try {
KeyRequest request = KeyRequest mediaDrmKeyRequest =
mediaDrm.getKeyRequest(scope, initData, mimeType, type, optionalKeyRequestParameters); mediaDrm.getKeyRequest(scope, initData, mimeType, type, optionalKeyRequestParameters);
Pair<KeyRequest, String> arguments = Pair.create(request, licenseServerUrl); currentKeyRequest = Pair.create(mediaDrmKeyRequest, licenseServerUrl);
postRequestHandler.obtainMessage(MSG_KEYS, arguments, allowRetry).sendToTarget(); postRequestHandler.post(MSG_KEYS, currentKeyRequest, allowRetry);
} catch (Exception e) { } catch (Exception e) {
onKeysError(e); onKeysError(e);
} }
} }
private void onKeyResponse(Object response) { private void onKeyResponse(Object request, Object response) {
if (!isOpen()) { if (request != currentKeyRequest || !isOpen()) {
// This event is stale. // This event is stale.
return; return;
} }
currentKeyRequest = null;
if (response instanceof Exception) { if (response instanceof Exception) {
onKeysError((Exception) response); onKeysError((Exception) response);
...@@ -461,12 +468,15 @@ import java.util.UUID; ...@@ -461,12 +468,15 @@ import java.util.UUID;
@Override @Override
public void handleMessage(Message msg) { public void handleMessage(Message msg) {
Pair<?, ?> requestAndResponse = (Pair<?, ?>) msg.obj;
Object request = requestAndResponse.first;
Object response = requestAndResponse.second;
switch (msg.what) { switch (msg.what) {
case MSG_PROVISION: case MSG_PROVISION:
onProvisionResponse(msg.obj); onProvisionResponse(request, response);
break; break;
case MSG_KEYS: case MSG_KEYS:
onKeyResponse(msg.obj); onKeyResponse(request, response);
break; break;
default: default:
break; break;
...@@ -483,23 +493,27 @@ import java.util.UUID; ...@@ -483,23 +493,27 @@ import java.util.UUID;
super(backgroundLooper); super(backgroundLooper);
} }
Message obtainMessage(int what, Object object, boolean allowRetry) { void post(int what, Object request, boolean allowRetry) {
return obtainMessage(what, allowRetry ? 1 : 0 /* allow retry*/, 0 /* error count */, int allowRetryInt = allowRetry ? 1 : 0;
object); int errorCount = 0;
obtainMessage(what, allowRetryInt, errorCount, request).sendToTarget();
} }
@Override @Override
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public void handleMessage(Message msg) { public void handleMessage(Message msg) {
Object request = msg.obj;
Object response; Object response;
try { try {
switch (msg.what) { switch (msg.what) {
case MSG_PROVISION: case MSG_PROVISION:
response = callback.executeProvisionRequest(uuid, (ProvisionRequest) msg.obj); response = callback.executeProvisionRequest(uuid, (ProvisionRequest) request);
break; break;
case MSG_KEYS: case MSG_KEYS:
Pair<KeyRequest, String> arguments = (Pair<KeyRequest, String>) msg.obj; Pair<KeyRequest, String> keyRequest = (Pair<KeyRequest, String>) request;
response = callback.executeKeyRequest(uuid, arguments.first, arguments.second); KeyRequest mediaDrmKeyRequest = keyRequest.first;
String licenseServerUrl = keyRequest.second;
response = callback.executeKeyRequest(uuid, mediaDrmKeyRequest, licenseServerUrl);
break; break;
default: default:
throw new RuntimeException(); throw new RuntimeException();
...@@ -510,7 +524,7 @@ import java.util.UUID; ...@@ -510,7 +524,7 @@ import java.util.UUID;
} }
response = e; response = e;
} }
postResponseHandler.obtainMessage(msg.what, response).sendToTarget(); postResponseHandler.obtainMessage(msg.what, Pair.create(request, response)).sendToTarget();
} }
private boolean maybeRetryRequest(Message originalMsg) { private boolean maybeRetryRequest(Message originalMsg) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment