Skip to content

Commit ebb0e55

Browse files
authored
fix error happening when sqs message attributes are readonly (#8473)
* fix error happening when sqs message attributes are readonly * add test * same for batch requests * same for receive message requests * fix test assertion
1 parent ffe33cd commit ebb0e55

File tree

2 files changed

+92
-25
lines changed

2 files changed

+92
-25
lines changed

dd-java-agent/instrumentation/aws-java-sqs-1.0/src/main/java/datadog/trace/instrumentation/aws/v1/sqs/SqsInterceptor.java

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
import com.amazonaws.AmazonWebServiceRequest;
1414
import com.amazonaws.handlers.RequestHandler2;
15+
import com.amazonaws.services.sqs.model.MessageAttributeValue;
1516
import com.amazonaws.services.sqs.model.ReceiveMessageRequest;
1617
import com.amazonaws.services.sqs.model.SendMessageBatchRequest;
1718
import com.amazonaws.services.sqs.model.SendMessageBatchRequestEntry;
@@ -22,7 +23,11 @@
2223
import datadog.trace.api.datastreams.DataStreamsContext;
2324
import datadog.trace.bootstrap.ContextStore;
2425
import datadog.trace.bootstrap.instrumentation.api.AgentSpan;
26+
import java.util.ArrayList;
27+
import java.util.HashMap;
2528
import java.util.LinkedHashMap;
29+
import java.util.List;
30+
import java.util.Map;
2631

2732
public class SqsInterceptor extends RequestHandler2 {
2833

@@ -42,9 +47,14 @@ public AmazonWebServiceRequest beforeMarshalling(AmazonWebServiceRequest request
4247

4348
Propagator dsmPropagator = Propagators.forConcern(DSM_CONCERN);
4449
Context context = newContext(request, queueUrl);
50+
// making a copy of the MessageAttributes before modifying them because they can be stored in
51+
// a kind of ImmutableMap
52+
Map<String, MessageAttributeValue> messageAttributes =
53+
new HashMap<>(smRequest.getMessageAttributes());
54+
dsmPropagator.inject(context, messageAttributes, SETTER);
4555
// note: modifying message attributes has to be done before marshalling, otherwise the changes
4656
// are not reflected in the actual request (and the MD5 check on send will fail).
47-
dsmPropagator.inject(context, smRequest.getMessageAttributes(), SETTER);
57+
smRequest.setMessageAttributes(messageAttributes);
4858
} else if (request instanceof SendMessageBatchRequest) {
4959
SendMessageBatchRequest smbRequest = (SendMessageBatchRequest) request;
5060

@@ -54,13 +64,18 @@ public AmazonWebServiceRequest beforeMarshalling(AmazonWebServiceRequest request
5464
Propagator dsmPropagator = Propagators.forConcern(DSM_CONCERN);
5565
Context context = newContext(request, queueUrl);
5666
for (SendMessageBatchRequestEntry entry : smbRequest.getEntries()) {
57-
dsmPropagator.inject(context, entry.getMessageAttributes(), SETTER);
67+
Map<String, MessageAttributeValue> messageAttributes =
68+
new HashMap<>(entry.getMessageAttributes());
69+
dsmPropagator.inject(context, messageAttributes, SETTER);
70+
entry.setMessageAttributes(messageAttributes);
5871
}
5972
} else if (request instanceof ReceiveMessageRequest) {
6073
ReceiveMessageRequest rmRequest = (ReceiveMessageRequest) request;
6174
if (rmRequest.getMessageAttributeNames().size() < 10
6275
&& !rmRequest.getMessageAttributeNames().contains(DATADOG_KEY)) {
63-
rmRequest.getMessageAttributeNames().add(DATADOG_KEY);
76+
List<String> attributeNames = new ArrayList<>(rmRequest.getMessageAttributeNames());
77+
attributeNames.add(DATADOG_KEY);
78+
rmRequest.setMessageAttributeNames(attributeNames);
6479
}
6580
}
6681
return request;

dd-java-agent/instrumentation/aws-java-sqs-1.0/src/test/groovy/SqsClientTest.groovy

Lines changed: 74 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@ import com.amazonaws.client.builder.AwsClientBuilder
77
import com.amazonaws.services.sqs.AmazonSQSClientBuilder
88
import com.amazonaws.services.sqs.model.Message
99
import com.amazonaws.services.sqs.model.MessageAttributeValue
10+
import com.amazonaws.services.sqs.model.ReceiveMessageRequest
1011
import com.amazonaws.services.sqs.model.SendMessageRequest
12+
import com.google.common.collect.ImmutableMap
1113
import datadog.trace.agent.test.naming.VersionedNamingTestBase
1214
import datadog.trace.agent.test.utils.TraceUtils
1315
import datadog.trace.api.Config
@@ -87,9 +89,9 @@ abstract class SqsClientTest extends VersionedNamingTestBase {
8789
def "trace details propagated via SQS system message attributes"() {
8890
setup:
8991
def client = AmazonSQSClientBuilder.standard()
90-
.withEndpointConfiguration(endpoint)
91-
.withCredentials(credentialsProvider)
92-
.build()
92+
.withEndpointConfiguration(endpoint)
93+
.withCredentials(credentialsProvider)
94+
.build()
9395
def queueUrl = client.createQueue('somequeue').queueUrl
9496
TEST_WRITER.clear()
9597

@@ -188,6 +190,56 @@ abstract class SqsClientTest extends VersionedNamingTestBase {
188190
client.shutdown()
189191
}
190192

193+
@IgnoreIf({ !instance.isDataStreamsEnabled() })
194+
def "propagation even when message attributes are readonly"() {
195+
setup:
196+
def client = AmazonSQSClientBuilder.standard()
197+
.withEndpointConfiguration(endpoint)
198+
.withCredentials(credentialsProvider)
199+
.build()
200+
def queueUrl = client.createQueue('somequeue').queueUrl
201+
TEST_WRITER.clear()
202+
203+
when:
204+
TraceUtils.runUnderTrace('parent', {
205+
def my_attribute = new MessageAttributeValue()
206+
my_attribute.setStringValue("hello world")
207+
my_attribute.setDataType("String")
208+
def readonlyAttributes = ImmutableMap<String, MessageAttributeValue>.of("my_key", my_attribute)
209+
def req = new SendMessageRequest(queueUrl, 'sometext')
210+
req.setMessageAttributes(readonlyAttributes)
211+
client.sendMessage(req)
212+
})
213+
214+
TEST_DATA_STREAMS_WRITER.waitForGroups(1)
215+
216+
then:
217+
assertTraces(1) {
218+
trace(2) {
219+
basicSpan(it, "parent")
220+
span {
221+
serviceName expectedService("SQS", "SendMessage")
222+
operationName expectedOperation("SQS", "SendMessage")
223+
resourceName "SQS.SendMessage"
224+
spanType DDSpanTypes.HTTP_CLIENT
225+
errored false
226+
childOf(span(0))
227+
}
228+
}
229+
}
230+
231+
and:
232+
def recv = new ReceiveMessageRequest(queueUrl)
233+
recv.withMessageAttributeNames("my_key")
234+
def messages = client.receiveMessage(recv).messages
235+
236+
assert messages[0].messageAttributes.containsKey("my_key") // what we set initially
237+
assert messages[0].messageAttributes.containsKey("_datadog") // what was injected
238+
239+
cleanup:
240+
client.shutdown()
241+
}
242+
191243
@IgnoreIf({ instance.isDataStreamsEnabled() })
192244
def "trace details propagated via embedded SQS message attribute (string)"() {
193245
setup:
@@ -196,8 +248,8 @@ abstract class SqsClientTest extends VersionedNamingTestBase {
196248
when:
197249
def message = new Message()
198250
message.addMessageAttributesEntry('_datadog', new MessageAttributeValue().withDataType('String').withStringValue(
199-
"{\"x-datadog-trace-id\": \"4948377316357291421\", \"x-datadog-parent-id\": \"6746998015037429512\", \"x-datadog-sampling-priority\": \"1\"}"
200-
))
251+
"{\"x-datadog-trace-id\": \"4948377316357291421\", \"x-datadog-parent-id\": \"6746998015037429512\", \"x-datadog-sampling-priority\": \"1\"}"
252+
))
201253
def messages = new TracingList([message], "http://localhost:${address.port}/000000000000/somequeue")
202254

203255
messages.forEach {/* consume to create message spans */ }
@@ -237,8 +289,8 @@ abstract class SqsClientTest extends VersionedNamingTestBase {
237289
when:
238290
def message = new Message()
239291
message.addMessageAttributesEntry('_datadog', new MessageAttributeValue().withDataType('Binary').withBinaryValue(
240-
headerValue
241-
))
292+
headerValue
293+
))
242294
def messages = new TracingList([message], "http://localhost:${address.port}/000000000000/somequeue")
243295

244296
messages.forEach {/* consume to create message spans */ }
@@ -281,9 +333,9 @@ abstract class SqsClientTest extends VersionedNamingTestBase {
281333
def "trace details propagated from SQS to JMS"() {
282334
setup:
283335
def client = AmazonSQSClientBuilder.standard()
284-
.withEndpointConfiguration(endpoint)
285-
.withCredentials(credentialsProvider)
286-
.build()
336+
.withEndpointConfiguration(endpoint)
337+
.withCredentials(credentialsProvider)
338+
.build()
287339

288340
def connectionFactory = new SQSConnectionFactory(new ProviderConfiguration(), client)
289341
def connection = connectionFactory.createConnection()
@@ -295,12 +347,12 @@ abstract class SqsClientTest extends VersionedNamingTestBase {
295347

296348
when:
297349
def ddMsgAttribute = new MessageAttributeValue()
298-
.withBinaryValue(ByteBuffer.wrap("hello world".getBytes(Charset.defaultCharset())))
299-
.withDataType("Binary")
350+
.withBinaryValue(ByteBuffer.wrap("hello world".getBytes(Charset.defaultCharset())))
351+
.withDataType("Binary")
300352
connection.start()
301353
TraceUtils.runUnderTrace('parent') {
302354
client.sendMessage(new SendMessageRequest(queue.queueUrl, 'sometext')
303-
.withMessageAttributes([_datadog: ddMsgAttribute]))
355+
.withMessageAttributes([_datadog: ddMsgAttribute]))
304356
}
305357
def message = consumer.receive()
306358
consumer.receiveNoWait()
@@ -558,9 +610,9 @@ class SqsClientV1DataStreamsForkedTest extends SqsClientTest {
558610
def "Data streams context extracted from message body"() {
559611
setup:
560612
def client = AmazonSQSClientBuilder.standard()
561-
.withEndpointConfiguration(endpoint)
562-
.withCredentials(credentialsProvider)
563-
.build()
613+
.withEndpointConfiguration(endpoint)
614+
.withCredentials(credentialsProvider)
615+
.build()
564616
def queueUrl = client.createQueue('somequeue').queueUrl
565617
TEST_WRITER.clear()
566618

@@ -588,9 +640,9 @@ class SqsClientV1DataStreamsForkedTest extends SqsClientTest {
588640
def "Data streams context not extracted from message body when message attributes are not present"() {
589641
setup:
590642
def client = AmazonSQSClientBuilder.standard()
591-
.withEndpointConfiguration(endpoint)
592-
.withCredentials(credentialsProvider)
593-
.build()
643+
.withEndpointConfiguration(endpoint)
644+
.withCredentials(credentialsProvider)
645+
.build()
594646
def queueUrl = client.createQueue('somequeue').queueUrl
595647
TEST_WRITER.clear()
596648

@@ -619,9 +671,9 @@ class SqsClientV1DataStreamsForkedTest extends SqsClientTest {
619671
def "Data streams context not extracted from message body when message is not a Json"() {
620672
setup:
621673
def client = AmazonSQSClientBuilder.standard()
622-
.withEndpointConfiguration(endpoint)
623-
.withCredentials(credentialsProvider)
624-
.build()
674+
.withEndpointConfiguration(endpoint)
675+
.withCredentials(credentialsProvider)
676+
.build()
625677
def queueUrl = client.createQueue('somequeue').queueUrl
626678
TEST_WRITER.clear()
627679

0 commit comments

Comments
 (0)