SRSIMDHelpers.m 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. //
  2. // Copyright (c) 2016-present, Facebook, Inc.
  3. // All rights reserved.
  4. //
  5. // This source code is licensed under the BSD-style license found in the
  6. // LICENSE file in the root directory of this source tree. An additional grant
  7. // of patent rights can be found in the PATENTS file in the same directory.
  8. //
  9. #import "SRSIMDHelpers.h"
  /**
   A vector of 32 `uint8_t` lanes (32 bytes in total), declared with the
   GCC/Clang `vector_size` extension so that operators such as `^` apply lane-wise.
   */
  typedef uint8_t __attribute__((vector_size(32))) uint8x32_t;
  /**
   XOR every byte of a buffer, in place, with a repeating 4-byte key.

   @param bytes   The buffer to mask in place.
   @param length  The number of bytes in `bytes` to mask.
   @param maskKey The key; byte `i` of the buffer is XORed with `maskKey[i % 4]`.
   */
  static void SRMaskBytesManual(uint8_t *bytes, size_t length, uint8_t *maskKey) {
      // The key length is a power of two, so masking the index with (length - 1)
      // selects the same key byte as taking the index modulo the key length.
      const size_t keyIndexMask = sizeof(uint32_t) - 1;

      size_t offset = 0;
      while (offset < length) {
          const uint8_t keyByte = maskKey[offset & keyIndexMask];
          bytes[offset] = bytes[offset] ^ keyByte;
          offset += 1;
      }
  }
  16. /**
  17. Right-shift the elements of a vector, circularly.
  18. @param vector The vector to circular shift.
  19. @param by The number of elements to shift by.
  20. @return A shifted vector.
  21. */
  22. static uint8x32_t SRShiftVector(uint8x32_t vector, size_t by) {
  23. uint8x32_t vectorCopy = vector;
  24. by = by % _Alignof(uint8x32_t);
  25. uint8_t *vectorPointer = (uint8_t *)&vector;
  26. uint8_t *vectorCopyPointer = (uint8_t *)&vectorCopy;
  27. memmove(vectorPointer + by, vectorPointer, sizeof(vector) - by);
  28. memcpy(vectorPointer, vectorCopyPointer + (sizeof(vector) - by), by);
  29. return vector;
  30. }
  31. void SRMaskBytesSIMD(uint8_t *bytes, size_t length, uint8_t *maskKey) {
  32. size_t alignmentBytes = _Alignof(uint8x32_t) - ((uintptr_t)bytes % _Alignof(uint8x32_t));
  33. if (alignmentBytes == _Alignof(uint8x32_t)) {
  34. alignmentBytes = 0;
  35. }
  36. // If the number of bytes that can be processed after aligning is
  37. // less than the number of bytes we can put into a vector,
  38. // then there's no work to do with SIMD, just call the manual version.
  39. if (alignmentBytes > length || (length - alignmentBytes) < sizeof(uint8x32_t)) {
  40. SRMaskBytesManual(bytes, length, maskKey);
  41. return;
  42. }
  43. size_t vectorLength = (length - alignmentBytes) / sizeof(uint8x32_t);
  44. size_t manualStartOffset = alignmentBytes + (vectorLength * sizeof(uint8x32_t));
  45. size_t manualLength = length - manualStartOffset;
  46. uint8x32_t *vector = (uint8x32_t *)(bytes + alignmentBytes);
  47. uint8x32_t maskVector = { };
  48. memset_pattern4(&maskVector, maskKey, sizeof(uint8x32_t));
  49. maskVector = SRShiftVector(maskVector, alignmentBytes);
  50. SRMaskBytesManual(bytes, alignmentBytes, maskKey);
  51. for (size_t vectorIndex = 0; vectorIndex < vectorLength; vectorIndex++) {
  52. vector[vectorIndex] = vector[vectorIndex] ^ maskVector;
  53. }
  54. // Use the shifted mask for the final manual part.
  55. SRMaskBytesManual(bytes + manualStartOffset, manualLength, (uint8_t *) &maskVector);
  56. }