Coverage for tropicsquare / ports / micropython / aesgcm.py: 0%

76 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-03-27 21:24 +0000

1import ucryptolib 

2 

class AESGCM:
    """AES-GCM authenticated encryption on top of MicroPython's ``ucryptolib``.

    Only the raw AES block function (``ucryptolib.aes`` in ECB mode) comes
    from the platform; GCM's counter mode and GHASH are implemented here in
    pure Python.  ``encrypt`` returns ``ciphertext || tag`` with a full
    16-byte tag, and ``decrypt`` expects the same layout.
    """

    def __init__(self, key):
        """Create a cipher for ``key``.

        ``key`` is passed unchanged to ``ucryptolib.aes``, which decides which
        key sizes it accepts.

        Attributes:
            key: the key exactly as given.
            H: the GHASH subkey, i.e. AES_K(0^128), as 16 bytes.
        """
        self.key = key
        self._aes = ucryptolib.aes(key, 1)  # mode 1 == ECB: raw block encryption
        self.H = self._encrypt_block(b'\x00' * 16)

    def encrypt(self, nonce, data, associated_data):
        """Encrypt ``data`` and authenticate it together with ``associated_data``.

        Args:
            nonce: 12-byte nonce.
            data: plaintext bytes of any length (may be empty).
            associated_data: bytes that are authenticated but not encrypted,
                or ``None`` for no associated data.

        Returns:
            ``ciphertext + tag``; the ciphertext has the same length as
            ``data`` and the tag is 16 bytes.

        Raises:
            ValueError: if ``nonce`` is not 12 bytes long.
        """
        j0 = self._initial_counter(nonce)
        # Payload keystream starts at inc32(J0); J0 itself is reserved for the tag.
        ciphertext = self._gctr(self._inc32(j0), data)
        return ciphertext + self._tag(j0, associated_data, ciphertext)

    def decrypt(self, nonce, data, associated_data):
        """Verify and decrypt ``data`` (``ciphertext + 16-byte tag``).

        The tag is checked before any plaintext is produced.

        Args:
            nonce: 12-byte nonce used for encryption.
            data: ``ciphertext + tag`` as returned by ``encrypt``.
            associated_data: the same associated data given to ``encrypt``,
                or ``None`` for none.

        Returns:
            The recovered plaintext as bytes.

        Raises:
            ValueError: if ``nonce`` is not 12 bytes long, or if ``data`` is
                too short to contain a tag or fails authentication.
        """
        j0 = self._initial_counter(nonce)
        if len(data) < 16:
            # No room for a 16-byte tag, so authentication cannot succeed.
            raise ValueError("Invalid tag! Authentication failed.")
        ciphertext, tag = data[:-16], data[-16:]
        expected = self._tag(j0, associated_data, ciphertext)
        if not self._ct_equal(expected, tag):
            raise ValueError("Invalid tag! Authentication failed.")
        return self._gctr(self._inc32(j0), ciphertext)

    def _initial_counter(self, nonce):
        """Validate ``nonce`` and return J0 = nonce || 0x00000001 (96-bit IV case)."""
        if len(nonce) != 12:
            raise ValueError("Nonce must be 12 bytes")
        return nonce + b'\x00\x00\x00\x01'

    def _gctr(self, counter, data):
        """XOR ``data`` with the AES-CTR keystream that starts at ``counter``."""
        out = bytearray()
        for i in range(0, len(data), 16):
            keystream = self._encrypt_block(counter)
            # zip() stops at the shorter input, so a final partial block only
            # consumes as many keystream bytes as it needs.
            out += bytes(a ^ b for a, b in zip(data[i:i + 16], keystream))
            counter = self._inc32(counter)
        return bytes(out)

    def _tag(self, j0, associated_data, ciphertext):
        """Return the 16-byte tag: AES_K(J0) XOR GHASH(associated_data, ciphertext)."""
        s = self._ghash(associated_data, ciphertext)
        return bytes(a ^ b for a, b in zip(self._encrypt_block(j0), s))

    @staticmethod
    def _ct_equal(a, b):
        """Compare two byte strings without an early exit on the first mismatch.

        A plain ``==`` on bytes can stop at the first differing byte, which
        leaks how much of a forged tag was correct through timing.  Lengths
        are public, so a length mismatch may return immediately.
        """
        if len(a) != len(b):
            return False
        diff = 0
        for x, y in zip(a, b):
            diff |= x ^ y
        return diff == 0

    def _encrypt_block(self, block):
        """Encrypt exactly one 16-byte block with the raw AES cipher."""
        if len(block) != 16:
            raise ValueError("Block must be 16 bytes")
        return self._aes.encrypt(block)

    def _gf_mult(self, X, Y):
        """Multiply two GF(2^128) elements in GCM's bit-reflected convention.

        ``X`` and ``Y`` are 128-bit ints whose most significant bit is GCM
        bit 0, as produced by ``int.from_bytes(block, "big")``.
        """
        # 0xe1 || 0^120: the reduction constant for x^128 + x^7 + x^2 + x + 1
        # in GCM's reflected bit order.
        R = 0xe1000000000000000000000000000000
        Z = 0
        V = Y
        for i in range(128):
            if (X >> (127 - i)) & 1:
                Z ^= V
            # Multiply V by x: shift right, reducing when a bit falls off.
            if V & 1:
                V = (V >> 1) ^ R
            else:
                V >>= 1
        return Z

    def _ghash_update(self, x, data, h):
        """Fold ``data`` into GHASH state ``x``, zero-padding the last block."""
        for i in range(0, len(data), 16):
            block = bytes(data[i:i + 16])
            if len(block) < 16:
                block += b'\x00' * (16 - len(block))
            x = self._gf_mult(x ^ int.from_bytes(block, "big"), h)
        return x

    def _ghash(self, aad, ciphertext):
        """Compute GHASH_H over padded AAD, padded ciphertext and the length block.

        ``aad`` may be ``None``, which is treated as empty associated data.
        Returns 16 bytes.
        """
        if aad is None:
            aad = b""
        h = int.from_bytes(self.H, "big")
        x = self._ghash_update(0, aad, h)
        x = self._ghash_update(x, ciphertext, h)
        # Final block: 64-bit big-endian bit lengths of AAD and ciphertext.
        lengths = (len(aad) * 8).to_bytes(8, "big") + (len(ciphertext) * 8).to_bytes(8, "big")
        x = self._gf_mult(x ^ int.from_bytes(lengths, "big"), h)
        return x.to_bytes(16, "big")

    def _inc32(self, block):
        """Increment the last 32 bits of ``block`` modulo 2^32, keeping the first 12 bytes."""
        counter = int.from_bytes(block[12:], "big")
        counter = (counter + 1) & 0xffffffff
        return block[:12] + counter.to_bytes(4, "big")