mirror of
https://github.com/ARMmbed/mbedtls.git
synced 2025-10-24 11:43:21 +08:00
Add knowledge of algorithms
Determine the category of operations supported by an algorithm based on its name. Signed-off-by: Gilles Peskine <Gilles.Peskine@arm.com>
This commit is contained in:
@@ -18,11 +18,21 @@ This module is entirely based on the PSA API.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import enum
|
||||
import re
|
||||
from typing import Dict, Iterable, Optional, Pattern, Tuple
|
||||
|
||||
from mbedtls_dev.asymmetric_key_data import ASYMMETRIC_KEY_DATA
|
||||
|
||||
|
||||
BLOCK_MAC_MODES = frozenset(['CBC_MAC', 'CMAC'])
|
||||
BLOCK_CIPHER_MODES = frozenset([
|
||||
'CTR', 'CFB', 'OFB', 'XTS', 'CCM_STAR_NO_TAG',
|
||||
'ECB_NO_PADDING', 'CBC_NO_PADDING', 'CBC_PKCS7',
|
||||
])
|
||||
BLOCK_AEAD_MODES = frozenset(['CCM', 'GCM'])
|
||||
|
||||
|
||||
class KeyType:
|
||||
"""Knowledge about a PSA key type."""
|
||||
|
||||
@@ -151,3 +161,144 @@ class KeyType:
|
||||
"""
|
||||
# This is just temporaly solution for the implicit usage flags.
|
||||
return re.match(self.KEY_TYPE_FOR_SIGNATURE[usage], self.name) is not None
|
||||
|
||||
|
||||
class AlgorithmCategory(enum.Enum):
|
||||
"""PSA algorithm categories."""
|
||||
# The numbers are aligned with the category bits in numerical values of
|
||||
# algorithms.
|
||||
HASH = 2
|
||||
MAC = 3
|
||||
CIPHER = 4
|
||||
AEAD = 5
|
||||
SIGN = 6
|
||||
ASYMMETRIC_ENCRYPTION = 7
|
||||
KEY_DERIVATION = 8
|
||||
KEY_AGREEMENT = 9
|
||||
PAKE = 10
|
||||
|
||||
def requires_key(self) -> bool:
|
||||
return self not in {self.HASH, self.KEY_DERIVATION}
|
||||
|
||||
|
||||
class AlgorithmNotRecognized(Exception):
|
||||
def __init__(self, expr: str) -> None:
|
||||
super().__init__('Algorithm not recognized: ' + expr)
|
||||
self.expr = expr
|
||||
|
||||
|
||||
class Algorithm:
|
||||
"""Knowledge about a PSA algorithm."""
|
||||
|
||||
@staticmethod
|
||||
def determine_base(expr: str) -> str:
|
||||
"""Return an expression for the "base" of the algorithm.
|
||||
|
||||
This strips off variants of algorithms such as MAC truncation.
|
||||
|
||||
This function does not attempt to detect invalid inputs.
|
||||
"""
|
||||
m = re.match(r'PSA_ALG_(?:'
|
||||
r'(?:TRUNCATED|AT_LEAST_THIS_LENGTH)_MAC|'
|
||||
r'AEAD_WITH_(?:SHORTENED|AT_LEAST_THIS_LENGTH)_TAG'
|
||||
r')\((.*),[^,]+\)\Z', expr)
|
||||
if m:
|
||||
expr = m.group(1)
|
||||
return expr
|
||||
|
||||
@staticmethod
|
||||
def determine_head(expr: str) -> str:
|
||||
"""Return the head of an algorithm expression.
|
||||
|
||||
The head is the first (outermost) constructor, without its PSA_ALG_
|
||||
prefix, and with some normalization of similar algorithms.
|
||||
"""
|
||||
m = re.match(r'PSA_ALG_(?:DETERMINISTIC_)?(\w+)', expr)
|
||||
if not m:
|
||||
raise AlgorithmNotRecognized(expr)
|
||||
head = m.group(1)
|
||||
if head == 'KEY_AGREEMENT':
|
||||
m = re.match(r'PSA_ALG_KEY_AGREEMENT\s*\(\s*PSA_ALG_(\w+)', expr)
|
||||
if not m:
|
||||
raise AlgorithmNotRecognized(expr)
|
||||
head = m.group(1)
|
||||
head = re.sub(r'_ANY\Z', r'', head)
|
||||
if re.match(r'ED[0-9]+PH\Z', head):
|
||||
head = 'EDDSA_PREHASH'
|
||||
return head
|
||||
|
||||
CATEGORY_FROM_HEAD = {
|
||||
'SHA': AlgorithmCategory.HASH,
|
||||
'SHAKE256_512': AlgorithmCategory.HASH,
|
||||
'MD': AlgorithmCategory.HASH,
|
||||
'RIPEMD': AlgorithmCategory.HASH,
|
||||
'ANY_HASH': AlgorithmCategory.HASH,
|
||||
'HMAC': AlgorithmCategory.MAC,
|
||||
'STREAM_CIPHER': AlgorithmCategory.CIPHER,
|
||||
'CHACHA20_POLY1305': AlgorithmCategory.AEAD,
|
||||
'DSA': AlgorithmCategory.SIGN,
|
||||
'ECDSA': AlgorithmCategory.SIGN,
|
||||
'EDDSA': AlgorithmCategory.SIGN,
|
||||
'PURE_EDDSA': AlgorithmCategory.SIGN,
|
||||
'RSA_PSS': AlgorithmCategory.SIGN,
|
||||
'RSA_PKCS1V15_SIGN': AlgorithmCategory.SIGN,
|
||||
'RSA_PKCS1V15_CRYPT': AlgorithmCategory.ASYMMETRIC_ENCRYPTION,
|
||||
'RSA_OAEP': AlgorithmCategory.ASYMMETRIC_ENCRYPTION,
|
||||
'HKDF': AlgorithmCategory.KEY_DERIVATION,
|
||||
'TLS12_PRF': AlgorithmCategory.KEY_DERIVATION,
|
||||
'TLS12_PSK_TO_MS': AlgorithmCategory.KEY_DERIVATION,
|
||||
'PBKDF': AlgorithmCategory.KEY_DERIVATION,
|
||||
'ECDH': AlgorithmCategory.KEY_AGREEMENT,
|
||||
'FFDH': AlgorithmCategory.KEY_AGREEMENT,
|
||||
# KEY_AGREEMENT(...) is a key derivation with a key agreement component
|
||||
'KEY_AGREEMENT': AlgorithmCategory.KEY_DERIVATION,
|
||||
'JPAKE': AlgorithmCategory.PAKE,
|
||||
}
|
||||
for x in BLOCK_MAC_MODES:
|
||||
CATEGORY_FROM_HEAD[x] = AlgorithmCategory.MAC
|
||||
for x in BLOCK_CIPHER_MODES:
|
||||
CATEGORY_FROM_HEAD[x] = AlgorithmCategory.CIPHER
|
||||
for x in BLOCK_AEAD_MODES:
|
||||
CATEGORY_FROM_HEAD[x] = AlgorithmCategory.AEAD
|
||||
|
||||
def determine_category(self, expr: str, head: str) -> AlgorithmCategory:
|
||||
"""Return the category of the given algorithm expression.
|
||||
|
||||
This function does not attempt to detect invalid inputs.
|
||||
"""
|
||||
prefix = head
|
||||
while prefix:
|
||||
if prefix in self.CATEGORY_FROM_HEAD:
|
||||
return self.CATEGORY_FROM_HEAD[prefix]
|
||||
if re.match(r'.*[0-9]\Z', prefix):
|
||||
prefix = re.sub(r'_*[0-9]+\Z', r'', prefix)
|
||||
else:
|
||||
prefix = re.sub(r'_*[^_]*\Z', r'', prefix)
|
||||
raise AlgorithmNotRecognized(expr)
|
||||
|
||||
@staticmethod
|
||||
def determine_wildcard(expr) -> bool:
|
||||
"""Whether the given algorithm expression is a wildcard.
|
||||
|
||||
This function does not attempt to detect invalid inputs.
|
||||
"""
|
||||
if re.search(r'\bPSA_ALG_ANY_HASH\b', expr):
|
||||
return True
|
||||
if re.search(r'_AT_LEAST_', expr):
|
||||
return True
|
||||
return False
|
||||
|
||||
def __init__(self, expr: str) -> None:
|
||||
"""Analyze an algorithm value.
|
||||
|
||||
The algorithm must be expressed as a C expression containing only
|
||||
calls to PSA algorithm constructor macros and numeric literals.
|
||||
|
||||
This class is only programmed to handle valid expressions. Invalid
|
||||
expressions may result in exceptions or in nonsensical results.
|
||||
"""
|
||||
self.expression = re.sub(r'\s+', r'', expr)
|
||||
self.base_expression = self.determine_base(self.expression)
|
||||
self.head = self.determine_head(self.base_expression)
|
||||
self.category = self.determine_category(self.base_expression, self.head)
|
||||
self.is_wildcard = self.determine_wildcard(self.expression)
|
||||
|
Reference in New Issue
Block a user