Add static type hints

This makes the code easier to understand and navigate, and also detected a few of bugs:

1. Missing brackets on e.upper. (Fixed)
2. Not strictly related to types, but a lot of the regexes were not raw strings and therefore contained invalid escape sequences. Python prints a warning about these in recent versions. (Fixed)
3. Expression in `process_pseudo_instructions()` that is always false. (Not fixed)
4. Missing definition of `log_and_exit()`. (Fixed)

This is validated via pre-commit in CI.
This commit is contained in:
Tim Hutt
2024-10-29 21:49:07 +00:00
parent bd5e598abf
commit 284a5fa0f7
12 changed files with 236 additions and 210 deletions

View File

@@ -27,8 +27,7 @@ repos:
# hooks:
# - id: pylint
# TODO: Enable this when types are added.
# - repo: https://github.com/RobertCraigie/pyright-python
# rev: v1.1.383
# hooks:
# - id: pyright
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.383
hooks:
- id: pyright

View File

@@ -8,7 +8,7 @@ pp = pprint.PrettyPrinter(indent=2)
logging.basicConfig(level=logging.INFO, format="%(levelname)s:: %(message)s")
def make_c(instr_dict):
def make_c(instr_dict: InstrDict):
mask_match_str = ""
declare_insn_str = ""
for i in instr_dict:

View File

@@ -10,7 +10,7 @@ pp = pprint.PrettyPrinter(indent=2)
logging.basicConfig(level=logging.INFO, format="%(levelname)s:: %(message)s")
def make_chisel(instr_dict, spinal_hdl=False):
def make_chisel(instr_dict: InstrDict, spinal_hdl: bool = False):
chisel_names = ""
cause_names_str = ""
@@ -31,7 +31,7 @@ def make_chisel(instr_dict, spinal_hdl=False):
elif "rv_" in e:
e_format = e.replace("rv_", "").upper()
else:
e_format = e.upper
e_format = e.upper()
chisel_names += f' val {e_format+"Type"} = Map(\n'
for instr in e_instrs:
tmp_instr_name = '"' + instr.upper().replace(".", "_") + '"'

View File

@@ -1,6 +1,7 @@
import csv
import re
# TODO: The constants in this file should be in all caps.
overlapping_extensions = {
"rv_zcmt": {"rv_c_d"},
"rv_zcmp": {"rv_c_d"},
@@ -21,29 +22,29 @@ isa_regex = re.compile(
# regex to find <msb>..<lsb>=<val> patterns in instruction
fixed_ranges = re.compile(
"\s*(?P<msb>\d+.?)\.\.(?P<lsb>\d+.?)\s*=\s*(?P<val>\d[\w]*)[\s$]*", re.M
r"\s*(?P<msb>\d+.?)\.\.(?P<lsb>\d+.?)\s*=\s*(?P<val>\d[\w]*)[\s$]*", re.M
)
# regex to find <lsb>=<val> patterns in instructions
# single_fixed = re.compile('\s+(?P<lsb>\d+)=(?P<value>[\w\d]*)[\s$]*', re.M)
single_fixed = re.compile("(?:^|[\s])(?P<lsb>\d+)=(?P<value>[\w]*)((?=\s|$))", re.M)
single_fixed = re.compile(r"(?:^|[\s])(?P<lsb>\d+)=(?P<value>[\w]*)((?=\s|$))", re.M)
# regex to find the overloading condition variable
var_regex = re.compile("(?P<var>[a-zA-Z][\w\d]*)\s*=\s*.*?[\s$]*", re.M)
var_regex = re.compile(r"(?P<var>[a-zA-Z][\w\d]*)\s*=\s*.*?[\s$]*", re.M)
# regex for pseudo op instructions returns the dependent filename, dependent
# instruction, the pseudo op name and the encoding string
pseudo_regex = re.compile(
"^\$pseudo_op\s+(?P<filename>rv[\d]*_[\w].*)::\s*(?P<orig_inst>.*?)\s+(?P<pseudo_inst>.*?)\s+(?P<overload>.*)$",
r"^\$pseudo_op\s+(?P<filename>rv[\d]*_[\w].*)::\s*(?P<orig_inst>.*?)\s+(?P<pseudo_inst>.*?)\s+(?P<overload>.*)$",
re.M,
)
imported_regex = re.compile(
"^\s*\$import\s*(?P<extension>.*)\s*::\s*(?P<instruction>.*)", re.M
r"^\s*\$import\s*(?P<extension>.*)\s*::\s*(?P<instruction>.*)", re.M
)
def read_csv(filename):
def read_csv(filename: str):
"""
Reads a CSV file and returns a list of tuples.
Each tuple contains an integer value (from the first column) and a string (from the second column).
@@ -79,126 +80,99 @@ arg_lut["c_mop_t"] = (10, 8)
# dictionary containing the mapping of the argument to the what the fields in
# the latex table should be
latex_mapping = {}
latex_mapping["imm12"] = "imm[11:0]"
latex_mapping["rs1"] = "rs1"
latex_mapping["rs2"] = "rs2"
latex_mapping["rd"] = "rd"
latex_mapping["imm20"] = "imm[31:12]"
latex_mapping["bimm12hi"] = "imm[12$\\vert$10:5]"
latex_mapping["bimm12lo"] = "imm[4:1$\\vert$11]"
latex_mapping["imm12hi"] = "imm[11:5]"
latex_mapping["imm12lo"] = "imm[4:0]"
latex_mapping["jimm20"] = "imm[20$\\vert$10:1$\\vert$11$\\vert$19:12]"
latex_mapping["zimm"] = "uimm"
latex_mapping["shamtw"] = "shamt"
latex_mapping["shamtd"] = "shamt"
latex_mapping["shamtq"] = "shamt"
latex_mapping["rd_p"] = "rd\\,$'$"
latex_mapping["rs1_p"] = "rs1\\,$'$"
latex_mapping["rs2_p"] = "rs2\\,$'$"
latex_mapping["rd_rs1_n0"] = "rd/rs$\\neq$0"
latex_mapping["rd_rs1_p"] = "rs1\\,$'$/rs2\\,$'$"
latex_mapping["c_rs2"] = "rs2"
latex_mapping["c_rs2_n0"] = "rs2$\\neq$0"
latex_mapping["rd_n0"] = "rd$\\neq$0"
latex_mapping["rs1_n0"] = "rs1$\\neq$0"
latex_mapping["c_rs1_n0"] = "rs1$\\neq$0"
latex_mapping["rd_rs1"] = "rd/rs1"
latex_mapping["zimm6hi"] = "uimm[5]"
latex_mapping["zimm6lo"] = "uimm[4:0]"
latex_mapping["c_nzuimm10"] = "nzuimm[5:4$\\vert$9:6$\\vert$2$\\vert$3]"
latex_mapping["c_uimm7lo"] = "uimm[2$\\vert$6]"
latex_mapping["c_uimm7hi"] = "uimm[5:3]"
latex_mapping["c_uimm8lo"] = "uimm[7:6]"
latex_mapping["c_uimm8hi"] = "uimm[5:3]"
latex_mapping["c_uimm9lo"] = "uimm[7:6]"
latex_mapping["c_uimm9hi"] = "uimm[5:4$\\vert$8]"
latex_mapping["c_nzimm6lo"] = "nzimm[4:0]"
latex_mapping["c_nzimm6hi"] = "nzimm[5]"
latex_mapping["c_imm6lo"] = "imm[4:0]"
latex_mapping["c_imm6hi"] = "imm[5]"
latex_mapping["c_nzimm10hi"] = "nzimm[9]"
latex_mapping["c_nzimm10lo"] = "nzimm[4$\\vert$6$\\vert$8:7$\\vert$5]"
latex_mapping["c_nzimm18hi"] = "nzimm[17]"
latex_mapping["c_nzimm18lo"] = "nzimm[16:12]"
latex_mapping["c_imm12"] = (
"imm[11$\\vert$4$\\vert$9:8$\\vert$10$\\vert$6$\\vert$7$\\vert$3:1$\\vert$5]"
)
latex_mapping["c_bimm9lo"] = "imm[7:6$\\vert$2:1$\\vert$5]"
latex_mapping["c_bimm9hi"] = "imm[8$\\vert$4:3]"
latex_mapping["c_nzuimm5"] = "nzuimm[4:0]"
latex_mapping["c_nzuimm6lo"] = "nzuimm[4:0]"
latex_mapping["c_nzuimm6hi"] = "nzuimm[5]"
latex_mapping["c_uimm8splo"] = "uimm[4:2$\\vert$7:6]"
latex_mapping["c_uimm8sphi"] = "uimm[5]"
latex_mapping["c_uimm8sp_s"] = "uimm[5:2$\\vert$7:6]"
latex_mapping["c_uimm10splo"] = "uimm[4$\\vert$9:6]"
latex_mapping["c_uimm10sphi"] = "uimm[5]"
latex_mapping["c_uimm9splo"] = "uimm[4:3$\\vert$8:6]"
latex_mapping["c_uimm9sphi"] = "uimm[5]"
latex_mapping["c_uimm10sp_s"] = "uimm[5:4$\\vert$9:6]"
latex_mapping["c_uimm9sp_s"] = "uimm[5:3$\\vert$8:6]"
latex_mapping = {
"imm12": "imm[11:0]",
"rs1": "rs1",
"rs2": "rs2",
"rd": "rd",
"imm20": "imm[31:12]",
"bimm12hi": "imm[12$\\vert$10:5]",
"bimm12lo": "imm[4:1$\\vert$11]",
"imm12hi": "imm[11:5]",
"imm12lo": "imm[4:0]",
"jimm20": "imm[20$\\vert$10:1$\\vert$11$\\vert$19:12]",
"zimm": "uimm",
"shamtw": "shamt",
"shamtd": "shamt",
"shamtq": "shamt",
"rd_p": "rd\\,$'$",
"rs1_p": "rs1\\,$'$",
"rs2_p": "rs2\\,$'$",
"rd_rs1_n0": "rd/rs$\\neq$0",
"rd_rs1_p": "rs1\\,$'$/rs2\\,$'$",
"c_rs2": "rs2",
"c_rs2_n0": "rs2$\\neq$0",
"rd_n0": "rd$\\neq$0",
"rs1_n0": "rs1$\\neq$0",
"c_rs1_n0": "rs1$\\neq$0",
"rd_rs1": "rd/rs1",
"zimm6hi": "uimm[5]",
"zimm6lo": "uimm[4:0]",
"c_nzuimm10": "nzuimm[5:4$\\vert$9:6$\\vert$2$\\vert$3]",
"c_uimm7lo": "uimm[2$\\vert$6]",
"c_uimm7hi": "uimm[5:3]",
"c_uimm8lo": "uimm[7:6]",
"c_uimm8hi": "uimm[5:3]",
"c_uimm9lo": "uimm[7:6]",
"c_uimm9hi": "uimm[5:4$\\vert$8]",
"c_nzimm6lo": "nzimm[4:0]",
"c_nzimm6hi": "nzimm[5]",
"c_imm6lo": "imm[4:0]",
"c_imm6hi": "imm[5]",
"c_nzimm10hi": "nzimm[9]",
"c_nzimm10lo": "nzimm[4$\\vert$6$\\vert$8:7$\\vert$5]",
"c_nzimm18hi": "nzimm[17]",
"c_nzimm18lo": "nzimm[16:12]",
"c_imm12": "imm[11$\\vert$4$\\vert$9:8$\\vert$10$\\vert$6$\\vert$7$\\vert$3:1$\\vert$5]",
"c_bimm9lo": "imm[7:6$\\vert$2:1$\\vert$5]",
"c_bimm9hi": "imm[8$\\vert$4:3]",
"c_nzuimm5": "nzuimm[4:0]",
"c_nzuimm6lo": "nzuimm[4:0]",
"c_nzuimm6hi": "nzuimm[5]",
"c_uimm8splo": "uimm[4:2$\\vert$7:6]",
"c_uimm8sphi": "uimm[5]",
"c_uimm8sp_s": "uimm[5:2$\\vert$7:6]",
"c_uimm10splo": "uimm[4$\\vert$9:6]",
"c_uimm10sphi": "uimm[5]",
"c_uimm9splo": "uimm[4:3$\\vert$8:6]",
"c_uimm9sphi": "uimm[5]",
"c_uimm10sp_s": "uimm[5:4$\\vert$9:6]",
"c_uimm9sp_s": "uimm[5:3$\\vert$8:6]",
}
# created a dummy instruction-dictionary like dictionary for all the instruction
# types so that the same logic can be used to create their tables
latex_inst_type = {}
latex_inst_type["R-type"] = {}
latex_inst_type["R-type"]["variable_fields"] = [
"opcode",
"rd",
"funct3",
"rs1",
"rs2",
"funct7",
latex_inst_type = {
"R-type": {
"variable_fields": ["opcode", "rd", "funct3", "rs1", "rs2", "funct7"],
},
"R4-type": {
"variable_fields": ["opcode", "rd", "funct3", "rs1", "rs2", "funct2", "rs3"],
},
"I-type": {
"variable_fields": ["opcode", "rd", "funct3", "rs1", "imm12"],
},
"S-type": {
"variable_fields": ["opcode", "imm12lo", "funct3", "rs1", "rs2", "imm12hi"],
},
"B-type": {
"variable_fields": ["opcode", "bimm12lo", "funct3", "rs1", "rs2", "bimm12hi"],
},
"U-type": {
"variable_fields": ["opcode", "rd", "imm20"],
},
"J-type": {
"variable_fields": ["opcode", "rd", "jimm20"],
},
}
latex_fixed_fields = [
(31, 25),
(24, 20),
(19, 15),
(14, 12),
(11, 7),
(6, 0),
]
latex_inst_type["R4-type"] = {}
latex_inst_type["R4-type"]["variable_fields"] = [
"opcode",
"rd",
"funct3",
"rs1",
"rs2",
"funct2",
"rs3",
]
latex_inst_type["I-type"] = {}
latex_inst_type["I-type"]["variable_fields"] = [
"opcode",
"rd",
"funct3",
"rs1",
"imm12",
]
latex_inst_type["S-type"] = {}
latex_inst_type["S-type"]["variable_fields"] = [
"opcode",
"imm12lo",
"funct3",
"rs1",
"rs2",
"imm12hi",
]
latex_inst_type["B-type"] = {}
latex_inst_type["B-type"]["variable_fields"] = [
"opcode",
"bimm12lo",
"funct3",
"rs1",
"rs2",
"bimm12hi",
]
latex_inst_type["U-type"] = {}
latex_inst_type["U-type"]["variable_fields"] = ["opcode", "rd", "imm20"]
latex_inst_type["J-type"] = {}
latex_inst_type["J-type"]["variable_fields"] = ["opcode", "rd", "jimm20"]
latex_fixed_fields = []
latex_fixed_fields.append((31, 25))
latex_fixed_fields.append((24, 20))
latex_fixed_fields.append((19, 15))
latex_fixed_fields.append((14, 12))
latex_fixed_fields.append((11, 7))
latex_fixed_fields.append((6, 0))
# Pseudo-ops present in the generated encodings.
# By default pseudo-ops are not listed as they are considered aliases

View File

@@ -8,7 +8,7 @@ pp = pprint.PrettyPrinter(indent=2)
logging.basicConfig(level=logging.INFO, format="%(levelname)s:: %(message)s")
def make_go(instr_dict):
def make_go(instr_dict: InstrDict):
args = " ".join(sys.argv)
prelude = f"""// Code generated by {args}; DO NOT EDIT."""

View File

@@ -1,11 +1,6 @@
import collections
import copy
import glob
import logging
import os
import pprint
import re
import sys
from typing import TextIO
from constants import *
from shared_utils import *
@@ -117,7 +112,9 @@ def make_latex_table():
# instructions listed in list_of_instructions will be dumped into latex.
caption = ""
type_list = ["R-type", "I-type", "S-type", "B-type", "U-type", "J-type"]
dataset_list = [(["_i", "32_i"], "RV32I Base Instruction Set", [], False)]
dataset_list: list[tuple[list[str], str, list[str], bool]] = [
(["_i", "32_i"], "RV32I Base Instruction Set", [], False)
]
dataset_list.append((["_i"], "", ["fence_tso", "pause"], True))
make_ext_latex_table(type_list, dataset_list, latex_file, 32, caption)
@@ -184,7 +181,13 @@ def make_latex_table():
latex_file.close()
def make_ext_latex_table(type_list, dataset, latex_file, ilen, caption):
def make_ext_latex_table(
type_list: "list[str]",
dataset: "list[tuple[list[str], str, list[str], bool]]",
latex_file: TextIO,
ilen: int,
caption: str,
):
"""
For a given collection of extensions this function dumps out a complete
latex table which includes the encodings of the instructions.
@@ -285,7 +288,7 @@ def make_ext_latex_table(type_list, dataset, latex_file, ilen, caption):
# iterate ovr each instruction type and create a table entry
for t in type_dict:
fields = []
fields: list[tuple[int, int, str]] = []
# first capture all "arguments" of the type (funct3, funct7, rd, etc)
# and capture their positions using arg_lut.
@@ -332,7 +335,7 @@ def make_ext_latex_table(type_list, dataset, latex_file, ilen, caption):
# for each entry in the dataset create a table
content = ""
for ext_list, title, filter_list, include_pseudo in dataset:
instr_dict = {}
instr_dict: InstrDict = {}
# for all extensions list in ext_list, create a dictionary of
# instructions associated with those extensions.

View File

@@ -5,14 +5,14 @@ import logging
import pprint
import sys
from c_utils import *
from chisel_utils import *
from constants import *
from go_utils import *
from latex_utils import *
from rust_utils import *
from shared_utils import *
from sverilog_utils import *
from c_utils import make_c
from chisel_utils import make_chisel
from constants import emitted_pseudo_ops
from go_utils import make_go
from latex_utils import make_latex_table, make_priv_latex_table
from rust_utils import make_rust
from shared_utils import add_segmented_vls_insn, create_inst_dict
from sverilog_utils import make_sverilog
LOG_FORMAT = "%(levelname)s:: %(message)s"
LOG_LEVEL = logging.INFO
@@ -20,7 +20,8 @@ LOG_LEVEL = logging.INFO
pretty_printer = pprint.PrettyPrinter(indent=2)
logging.basicConfig(level=LOG_LEVEL, format=LOG_FORMAT)
if __name__ == "__main__":
def main():
print(f"Running with args : {sys.argv}")
extensions = sys.argv[1:]
@@ -80,3 +81,7 @@ if __name__ == "__main__":
logging.info("instr-table.tex generated successfully")
make_priv_latex_table()
logging.info("priv-instr-table.tex generated successfully")
if __name__ == "__main__":
main()

4
pyrightconfig.json Normal file
View File

@@ -0,0 +1,4 @@
{
"typeCheckingMode": "strict",
"pythonVersion": "3.6",
}

View File

@@ -10,7 +10,7 @@ pp = pprint.PrettyPrinter(indent=2)
logging.basicConfig(level=logging.INFO, format="%(levelname)s:: %(message)s")
def make_rust(instr_dict):
def make_rust(instr_dict: InstrDict):
mask_match_str = ""
for i in instr_dict:
mask_match_str += f'const MATCH_{i.upper().replace(".","_")}: u32 = {(instr_dict[i]["match"])};\n'

View File

@@ -6,6 +6,7 @@ import os
import pprint
import re
from itertools import chain
from typing import Dict, TypedDict
from constants import *
@@ -16,30 +17,36 @@ pretty_printer = pprint.PrettyPrinter(indent=2)
logging.basicConfig(level=LOG_LEVEL, format=LOG_FORMAT)
def log_and_exit(message: str):
"""
Log an error message and then exit with EXIT_FAILURE.
"""
logging.error(message)
raise SystemExit(1)
# Initialize encoding to 32-bit '-' values
def initialize_encoding(bits=32):
def initialize_encoding(bits: int = 32) -> "list[str]":
"""Initialize encoding with '-' to represent don't care bits."""
return ["-"] * bits
# Validate bit range and value
def validate_bit_range(msb, lsb, entry_value, line):
def validate_bit_range(msb: int, lsb: int, entry_value: int, line: str):
"""Validate the bit range and entry value."""
if msb < lsb:
logging.error(
log_and_exit(
f'{line.split(" ")[0]:<10} has position {msb} less than position {lsb} in its encoding'
)
raise SystemExit(1)
if entry_value >= (1 << (msb - lsb + 1)):
logging.error(
log_and_exit(
f'{line.split(" ")[0]:<10} has an illegal value {entry_value} assigned as per the bit width {msb - lsb}'
)
raise SystemExit(1)
# Split the instruction line into name and remaining part
def parse_instruction_line(line):
def parse_instruction_line(line: str) -> "tuple[str, str]":
"""Parse the instruction name and the remaining encoding details."""
name, remaining = line.split(" ", 1)
name = name.replace(".", "_") # Replace dots for compatibility
@@ -48,17 +55,18 @@ def parse_instruction_line(line):
# Verify Overlapping Bits
def check_overlapping_bits(encoding, ind, line):
def check_overlapping_bits(encoding: "list[str]", ind: int, line: str):
"""Check for overlapping bits in the encoding."""
if encoding[31 - ind] != "-":
logging.error(
log_and_exit(
f'{line.split(" ")[0]:<10} has {ind} bit overlapping in its opcodes'
)
raise SystemExit(1)
# Update encoding for fixed ranges
def update_encoding_for_fixed_range(encoding, msb, lsb, entry_value, line):
def update_encoding_for_fixed_range(
encoding: "list[str]", msb: int, lsb: int, entry_value: int, line: str
):
"""
Update encoding bits for a given bit range.
Checks for overlapping bits and assigns the value accordingly.
@@ -70,7 +78,7 @@ def update_encoding_for_fixed_range(encoding, msb, lsb, entry_value, line):
# Process fixed bit patterns
def process_fixed_ranges(remaining, encoding, line):
def process_fixed_ranges(remaining: str, encoding: "list[str]", line: str):
"""Process fixed bit ranges in the encoding."""
for s2, s1, entry in fixed_ranges.findall(remaining):
msb, lsb, entry_value = int(s2), int(s1), int(entry, 0)
@@ -83,9 +91,9 @@ def process_fixed_ranges(remaining, encoding, line):
# Process single bit assignments
def process_single_fixed(remaining, encoding, line):
def process_single_fixed(remaining: str, encoding: "list[str]", line: str):
"""Process single fixed assignments in the encoding."""
for lsb, value, drop in single_fixed.findall(remaining):
for lsb, value, _drop in single_fixed.findall(remaining):
lsb = int(lsb, 0)
value = int(value, 0)
@@ -94,7 +102,7 @@ def process_single_fixed(remaining, encoding, line):
# Main function to check argument look-up table
def check_arg_lut(args, encoding_args, name):
def check_arg_lut(args: "list[str]", encoding_args: "list[str]", name: str):
"""Check if arguments are present in arg_lut."""
for arg in args:
if arg not in arg_lut:
@@ -104,30 +112,28 @@ def check_arg_lut(args, encoding_args, name):
# Handle missing argument mappings
def handle_arg_lut_mapping(arg, name):
def handle_arg_lut_mapping(arg: str, name: str):
"""Handle cases where an argument needs to be mapped to an existing one."""
parts = arg.split("=")
if len(parts) == 2:
existing_arg, new_arg = parts
existing_arg, _new_arg = parts
if existing_arg in arg_lut:
arg_lut[arg] = arg_lut[existing_arg]
else:
logging.error(
log_and_exit(
f" Found field {existing_arg} in variable {arg} in instruction {name} "
f"whose mapping in arg_lut does not exist"
)
raise SystemExit(1)
else:
logging.error(
log_and_exit(
f" Found variable {arg} in instruction {name} "
f"whose mapping in arg_lut does not exist"
)
raise SystemExit(1)
return arg
# Update encoding args with variables
def update_encoding_args(encoding_args, arg, msb, lsb):
def update_encoding_args(encoding_args: "list[str]", arg: str, msb: int, lsb: int):
"""Update encoding arguments and ensure no overlapping."""
for ind in range(lsb, msb + 1):
check_overlapping_bits(encoding_args, ind, arg)
@@ -135,15 +141,26 @@ def update_encoding_args(encoding_args, arg, msb, lsb):
# Compute match and mask
def convert_encoding_to_match_mask(encoding):
def convert_encoding_to_match_mask(encoding: "list[str]") -> "tuple[str, str]":
"""Convert the encoding list to match and mask strings."""
match = "".join(encoding).replace("-", "0")
mask = "".join(encoding).replace("0", "1").replace("-", "0")
return hex(int(match, 2)), hex(int(mask, 2))
class SingleInstr(TypedDict):
encoding: str
variable_fields: "list[str]"
extension: "list[str]"
match: str
mask: str
InstrDict = Dict[str, SingleInstr]
# Processing main function for a line in the encoding file
def process_enc_line(line, ext):
def process_enc_line(line: str, ext: str) -> "tuple[str, SingleInstr]":
"""
This function processes each line of the encoding files (rv*). As part of
the processing, the function ensures that the encoding is legal through the
@@ -199,13 +216,13 @@ def process_enc_line(line, ext):
# Extract ISA Type
def extract_isa_type(ext_name):
def extract_isa_type(ext_name: str) -> str:
"""Extracts the ISA type from the extension name."""
return ext_name.split("_")[0]
# Verify the types for RV*
def is_rv_variant(type1, type2):
def is_rv_variant(type1: str, type2: str) -> bool:
"""Checks if the types are RV variants (rv32/rv64)."""
return (type2 == "rv" and type1 in {"rv32", "rv64"}) or (
type1 == "rv" and type2 in {"rv32", "rv64"}
@@ -213,77 +230,79 @@ def is_rv_variant(type1, type2):
# Check for same base ISA
def has_same_base_isa(type1, type2):
def has_same_base_isa(type1: str, type2: str) -> bool:
"""Determines if the two ISA types share the same base."""
return type1 == type2 or is_rv_variant(type1, type2)
# Compare the base ISA type of a given extension name against a list of extension names
def same_base_isa(ext_name, ext_name_list):
def same_base_isa(ext_name: str, ext_name_list: "list[str]") -> bool:
"""Checks if the base ISA type of ext_name matches any in ext_name_list."""
type1 = extract_isa_type(ext_name)
return any(has_same_base_isa(type1, extract_isa_type(ext)) for ext in ext_name_list)
# Pad two strings to equal length
def pad_to_equal_length(str1, str2, pad_char="-"):
def pad_to_equal_length(str1: str, str2: str, pad_char: str = "-") -> "tuple[str, str]":
"""Pads two strings to equal length using the given padding character."""
max_len = max(len(str1), len(str2))
return str1.rjust(max_len, pad_char), str2.rjust(max_len, pad_char)
# Check compatibility for two characters
def has_no_conflict(char1, char2):
def has_no_conflict(char1: str, char2: str) -> bool:
"""Checks if two characters are compatible (either matching or don't-care)."""
return char1 == "-" or char2 == "-" or char1 == char2
# Conflict check between two encoded strings
def overlaps(x, y):
def overlaps(x: str, y: str) -> bool:
"""Checks if two encoded strings overlap without conflict."""
x, y = pad_to_equal_length(x, y)
return all(has_no_conflict(x[i], y[i]) for i in range(len(x)))
# Check presence of keys in dictionary.
def is_in_nested_dict(a, key1, key2):
def is_in_nested_dict(a: "dict[str, set[str]]", key1: str, key2: str) -> bool:
"""Checks if key2 exists in the dictionary under key1."""
return key1 in a and key2 in a[key1]
# Overlap allowance
def overlap_allowed(a, x, y):
def overlap_allowed(a: "dict[str, set[str]]", x: str, y: str) -> bool:
"""Determines if overlap is allowed between x and y based on nested dictionary checks"""
return is_in_nested_dict(a, x, y) or is_in_nested_dict(a, y, x)
# Check overlap allowance between extensions
def extension_overlap_allowed(x, y):
def extension_overlap_allowed(x: str, y: str) -> bool:
"""Checks if overlap is allowed between two extensions using the overlapping_extensions dictionary."""
return overlap_allowed(overlapping_extensions, x, y)
# Check overlap allowance between instructions
def instruction_overlap_allowed(x, y):
def instruction_overlap_allowed(x: str, y: str) -> bool:
"""Checks if overlap is allowed between two instructions using the overlapping_instructions dictionary."""
return overlap_allowed(overlapping_instructions, x, y)
# Check 'nf' field
def is_segmented_instruction(instruction):
def is_segmented_instruction(instruction: SingleInstr) -> bool:
"""Checks if an instruction contains the 'nf' field."""
return "nf" in instruction["variable_fields"]
# Expand 'nf' fields
def update_with_expanded_instructions(updated_dict, key, value):
def update_with_expanded_instructions(
updated_dict: InstrDict, key: str, value: SingleInstr
):
"""Expands 'nf' fields in the instruction dictionary and updates it with new instructions."""
for new_key, new_value in expand_nf_field(key, value):
updated_dict[new_key] = new_value
# Process instructions, expanding segmented ones and updating the dictionary
def add_segmented_vls_insn(instr_dict):
def add_segmented_vls_insn(instr_dict: InstrDict) -> InstrDict:
"""Processes instructions, expanding segmented ones and updating the dictionary."""
# Use dictionary comprehension for efficiency
return dict(
@@ -299,7 +318,9 @@ def add_segmented_vls_insn(instr_dict):
# Expand the 'nf' field in the instruction dictionary
def expand_nf_field(name, single_dict):
def expand_nf_field(
name: str, single_dict: SingleInstr
) -> "list[tuple[str, SingleInstr]]":
"""Validate and prepare the instruction dictionary."""
validate_nf_field(single_dict, name)
remove_nf_field(single_dict)
@@ -322,29 +343,33 @@ def expand_nf_field(name, single_dict):
# Validate the presence of 'nf'
def validate_nf_field(single_dict, name):
def validate_nf_field(single_dict: SingleInstr, name: str):
"""Validates the presence of 'nf' in variable fields before expansion."""
if "nf" not in single_dict["variable_fields"]:
logging.error(f"Cannot expand nf field for instruction {name}")
raise SystemExit(1)
log_and_exit(f"Cannot expand nf field for instruction {name}")
# Remove 'nf' from variable fields
def remove_nf_field(single_dict):
def remove_nf_field(single_dict: SingleInstr):
"""Removes 'nf' from variable fields in the instruction dictionary."""
single_dict["variable_fields"].remove("nf")
# Update the mask to include the 'nf' field
def update_mask(single_dict):
def update_mask(single_dict: SingleInstr):
"""Updates the mask to include the 'nf' field in the instruction dictionary."""
single_dict["mask"] = hex(int(single_dict["mask"], 16) | 0b111 << 29)
# Create an expanded instruction
def create_expanded_instruction(
name, single_dict, nf, name_expand_index, base_match, encoding_prefix
):
name: str,
single_dict: SingleInstr,
nf: int,
name_expand_index: int,
base_match: int,
encoding_prefix: str,
) -> "tuple[str, SingleInstr]":
"""Creates an expanded instruction based on 'nf' value."""
new_single_dict = copy.deepcopy(single_dict)
@@ -363,7 +388,7 @@ def create_expanded_instruction(
# Return a list of relevant lines from the specified file
def read_lines(file):
def read_lines(file: str) -> "list[str]":
"""Reads lines from a file and returns non-blank, non-comment lines."""
with open(file) as fp:
lines = (line.rstrip() for line in fp)
@@ -371,7 +396,9 @@ def read_lines(file):
# Update the instruction dictionary
def process_standard_instructions(lines, instr_dict, file_name):
def process_standard_instructions(
lines: "list[str]", instr_dict: InstrDict, file_name: str
):
"""Processes standard instructions from the given lines and updates the instruction dictionary."""
for line in lines:
if "$import" in line or "$pseudo" in line:
@@ -409,7 +436,12 @@ def process_standard_instructions(lines, instr_dict, file_name):
# Incorporate pseudo instructions into the instruction dictionary based on given conditions
def process_pseudo_instructions(
lines, instr_dict, file_name, opcodes_dir, include_pseudo, include_pseudo_ops
lines: "list[str]",
instr_dict: InstrDict,
file_name: str,
opcodes_dir: str,
include_pseudo: bool,
include_pseudo_ops: "list[str]",
):
"""Processes pseudo instructions from the given lines and updates the instruction dictionary."""
for line in lines:
@@ -433,12 +465,15 @@ def process_pseudo_instructions(
else:
if single_dict["match"] != instr_dict[name]["match"]:
instr_dict[f"{name}_pseudo"] = single_dict
elif single_dict["extension"] not in instr_dict[name]["extension"]:
# TODO: This expression is always false since both sides are list[str].
elif single_dict["extension"] not in instr_dict[name]["extension"]: # type: ignore
instr_dict[name]["extension"].extend(single_dict["extension"])
# Integrate imported instructions into the instruction dictionary
def process_imported_instructions(lines, instr_dict, file_name, opcodes_dir):
def process_imported_instructions(
lines: "list[str]", instr_dict: InstrDict, file_name: str, opcodes_dir: str
):
"""Processes imported instructions from the given lines and updates the instruction dictionary."""
for line in lines:
if "$import" not in line:
@@ -464,7 +499,7 @@ def process_imported_instructions(lines, instr_dict, file_name, opcodes_dir):
# Locate the path of the specified extension file, checking fallback directories
def find_extension_file(ext, opcodes_dir):
def find_extension_file(ext: str, opcodes_dir: str):
"""Finds the extension file path, considering the unratified directory if necessary."""
ext_file = f"{opcodes_dir}/{ext}"
if not os.path.exists(ext_file):
@@ -475,7 +510,9 @@ def find_extension_file(ext, opcodes_dir):
# Confirm the presence of an original instruction in the corresponding extension file.
def validate_instruction_in_extension(inst, ext_file, file_name, pseudo_inst):
def validate_instruction_in_extension(
inst: str, ext_file: str, file_name: str, pseudo_inst: str
):
"""Validates if the original instruction exists in the dependent extension."""
found = False
for oline in open(ext_file):
@@ -489,7 +526,11 @@ def validate_instruction_in_extension(inst, ext_file, file_name, pseudo_inst):
# Construct a dictionary of instructions filtered by specified criteria
def create_inst_dict(file_filter, include_pseudo=False, include_pseudo_ops=[]):
def create_inst_dict(
file_filter: "list[str]",
include_pseudo: bool = False,
include_pseudo_ops: "list[str]" = [],
) -> InstrDict:
"""Creates a dictionary of instructions based on the provided file filters."""
"""
@@ -522,7 +563,7 @@ def create_inst_dict(file_filter, include_pseudo=False, include_pseudo_ops=[]):
is not already present; otherwise, it is skipped.
"""
opcodes_dir = os.path.dirname(os.path.realpath(__file__))
instr_dict = {}
instr_dict: InstrDict = {}
file_names = [
file
@@ -559,10 +600,10 @@ def create_inst_dict(file_filter, include_pseudo=False, include_pseudo_ops=[]):
# Extracts the extensions used in an instruction dictionary
def instr_dict_2_extensions(instr_dict):
def instr_dict_2_extensions(instr_dict: InstrDict) -> "list[str]":
return list({item["extension"][0] for item in instr_dict.values()})
# Returns signed interpretation of a value within a given width
def signed(value, width):
def signed(value: int, width: int) -> int:
return value if 0 <= value < (1 << (width - 1)) else value - (1 << width)

View File

@@ -7,7 +7,7 @@ pp = pprint.PrettyPrinter(indent=2)
logging.basicConfig(level=logging.INFO, format="%(levelname)s:: %(message)s")
def make_sverilog(instr_dict):
def make_sverilog(instr_dict: InstrDict):
names_str = ""
for i in instr_dict:
names_str += f" localparam [31:0] {i.upper().replace('.','_'):<18s} = 32'b{instr_dict[i]['encoding'].replace('-','?')};\n"

View File

@@ -12,7 +12,7 @@ class EncodingLineTest(unittest.TestCase):
logger = logging.getLogger()
logger.disabled = True
def assertError(self, string):
def assertError(self, string: str):
self.assertRaises(SystemExit, process_enc_line, string, "rv_i")
def test_lui(self):