Skip to content

Commit 1976182

Browse files
author
Jean Luc Bouchot
committed
Genlib working for dgemm
1 parent 854e682 commit 1976182

1 file changed

Lines changed: 183 additions & 5 deletions

File tree

run_tapenade_blas.py

Lines changed: 183 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -643,10 +643,11 @@ def parse_fortran_function(file_path: Path, suppress_warnings=False):
643643
complex_vars = set()
644644
integer_vars = set()
645645
char_vars = set()
646+
array_vars = set()
646647

647648
# Find the argument declaration section
648649
lines = content.split('\n')
649-
in_args_section = False
650+
in_args_section = False # This variable is used nowhere
650651

651652
for i, line in enumerate(lines):
652653
line_stripped = line.strip()
@@ -661,6 +662,8 @@ def parse_fortran_function(file_path: Path, suppress_warnings=False):
661662

662663
# Also look for the actual declaration lines (not in comments)
663664
if line_stripped and not line_stripped.startswith('*') and not line_stripped.startswith('C '):
665+
# Look for ARRAY variables
666+
is_array = ('(' in line_stripped) and (')' in line_stripped)
664667

665668
# Parse variable declarations
666669
if line_stripped.startswith('REAL') or line_stripped.startswith('DOUBLE PRECISION') or line_stripped.startswith('FLOAT'):
@@ -677,6 +680,8 @@ def parse_fortran_function(file_path: Path, suppress_warnings=False):
677680
var = re.sub(r'\*.*$', '', var)
678681
if var and re.match(r'^[A-Za-z][A-Za-z0-9_]*$', var):
679682
real_vars.add(var)
683+
if is_array:
684+
array_vars.add(var)
680685

681686
elif line_stripped.startswith('INTEGER'):
682687
int_decl = re.search(r'INTEGER\s+(.+)', line_stripped, re.IGNORECASE)
@@ -689,6 +694,8 @@ def parse_fortran_function(file_path: Path, suppress_warnings=False):
689694
var = re.sub(r'\*.*$', '', var)
690695
if var and re.match(r'^[A-Za-z][A-Za-z0-9_]*$', var):
691696
integer_vars.add(var)
697+
if is_array:
698+
array_vars.add(var)
692699

693700
elif line_stripped.startswith('CHARACTER'):
694701
char_decl = re.search(r'CHARACTER\s+(.+)', line_stripped, re.IGNORECASE)
@@ -701,6 +708,8 @@ def parse_fortran_function(file_path: Path, suppress_warnings=False):
701708
var = re.sub(r'\*.*$', '', var)
702709
if var and re.match(r'^[A-Za-z][A-Za-z0-9_]*$', var):
703710
char_vars.add(var)
711+
if is_array:
712+
array_vars.add(var)
704713

705714
elif line_stripped.startswith('COMPLEX'):
706715
# Extract variable names from COMPLEX declaration
@@ -716,6 +725,8 @@ def parse_fortran_function(file_path: Path, suppress_warnings=False):
716725
var = re.sub(r'\*.*$', '', var)
717726
if var and re.match(r'^[A-Za-z][A-Za-z0-9_]*$', var):
718727
complex_vars.add(var) # Add complex variables to complex_vars
728+
if is_array:
729+
array_vars.add(var)
719730

720731
# For FUNCTIONs with explicit return types, add function name to appropriate variable set
721732
if func_type == 'FUNCTION':
@@ -847,7 +858,8 @@ def parse_fortran_function(file_path: Path, suppress_warnings=False):
847858
'real_vars': real_vars,
848859
'complex_vars': complex_vars,
849860
'integer_vars': integer_vars,
850-
'char_vars': char_vars
861+
'char_vars': char_vars,
862+
'array_vars': array_vars
851863
}
852864

853865
return func_name, valid_inputs, valid_outputs, inout_vars, func_type, params, warnings, param_types, has_sufficient_docs
@@ -8018,6 +8030,7 @@ def main():
80188030
help="AD modes to generate: d (forward scalar), dv (forward vector), b (reverse scalar), bv (reverse vector), all (all modes). Default: all")
80198031
ap.add_argument("--nbdirsmax", type=int, default=4, help="Maximum number of derivative directions for vector mode (default: 4)")
80208032
ap.add_argument("--flat", action="store_true", help="Use flat directory structure (all files in function directory, single DIFFSIZES.inc)")
8033+
ap.add_argument("--debug-genlib", default=None, required=False)
80218034
ap.add_argument("--extra", nargs=argparse.REMAINDER, help="Extra args passed to Tapenade after -d/-r", default=[])
80228035
args = ap.parse_args()
80238036

@@ -8110,6 +8123,9 @@ def run_task(task):
81108123

81118124
# Parse the Fortran function to get signature
81128125
func_name, inputs, outputs, inout_vars, func_type, all_params, parse_warnings, param_types, has_sufficient_docs = parse_fortran_function(src)
8126+
print(f"INPUTS = {inputs}")
8127+
print(f"ALL_PARAMS = {all_params}")
8128+
print(f"PARAM_TYPES = {param_types}")
81138129

81148130
if not func_name:
81158131
print(f"Skipping {src}: Could not parse function signature", file=sys.stderr)
@@ -8144,8 +8160,87 @@ def run_task(task):
81448160
# Create output directory structure
81458161
flat_mode = args.flat
81468162
mode_dirs = {}
8147-
8148-
if flat_mode:
8163+
8164+
if (args.debug_genlib):
8165+
# if (False):
8166+
# When generating the general lib useful to Tapenade, we will save everything in a tmp file
8167+
# and only the lib in a local folder used to concatenate everything afterwards.
8168+
tmp_dir = Path("TMPGENLIB").resolve()
8169+
tmp_dir.mkdir(parents=True, exist_ok=True)
8170+
func_out_dir = tmp_dir
8171+
genlib_dir = out_dir
8172+
genlib_dir.mkdir(parents=True, exist_ok=True)
8173+
if run_d:
8174+
mode_dirs['d'] = tmp_dir
8175+
if run_b:
8176+
mode_dirs['b'] = tmp_dir
8177+
if run_dv:
8178+
mode_dirs['dv'] = tmp_dir
8179+
if run_bv:
8180+
mode_dirs['bv'] = tmp_dir
8181+
8182+
def convert_tap_result2genlib_format(l: str) :
8183+
out = []
8184+
infos = l.split("[")[1]
8185+
use_infos = True
8186+
for c in infos[3:]: # Don't bother with the first 3 elements
8187+
if(c == "]"):
8188+
break
8189+
if use_infos:
8190+
if(c == "("):
8191+
use_infos = False
8192+
else:
8193+
out = out + [("0" if c == "." else "1" )]
8194+
else:
8195+
if(c == ")"):
8196+
use_infos = True
8197+
8198+
return out
8199+
8200+
def parse_tap_trace4inout(fname):
8201+
with open(fname, "r") as f:
8202+
sought_after = " ===================== IN-OUT ANALYSIS OF UNIT "
8203+
l = f.readline()
8204+
while(not l.startswith(sought_after)):
8205+
l = f.readline()
8206+
8207+
# Now we read the next one, and start looking at the arguments
8208+
var2idx_mapping = dict()
8209+
idx2var_mapping = dict()
8210+
l = f.readline().strip()
8211+
for v in l.split(" ")[3:]: # The first three variables are useless in the sense that they are only used internally for tapenade
8212+
# v is a string resembling "[id]varName" where id is an identifier internal to tapenade for the zone of variable varName
8213+
not_quite_id, var_name = v.split("]")
8214+
idx = int(not_quite_id[1:])
8215+
var2idx_mapping[var_name] = idx
8216+
idx2var_mapping[idx] = var_name
8217+
8218+
# Now that the mapping has been parsed, we move towards the end of the analysis phase, and extract the summary
8219+
sought_after = "terminateFGForUnit Unit"
8220+
while(not l.startswith(sought_after)):
8221+
l = f.readline()
8222+
# We have found our signal to read the results
8223+
# It is always four lines looking like this
8224+
# N [111111..11(1)111111] ---> corresponds to NotReadNotWritten, probably useless
8225+
# R [...1111111(1)11111.] ---> corresponds to ReadNotWritten
8226+
# W [..1.......(1).....1] ---> corresponds to NotReadThenWritten ==> Need to check what the 1 in third position means
8227+
# RW [..........(1).....1] ---> corresponds to ReadThenWritten
8228+
l = f.readline()
8229+
# Discard the not read not written elements
8230+
l = f.readline()
8231+
# We deal with the ReadNotWritten information
8232+
read_not_written = convert_tap_result2genlib_format(l)
8233+
l = f.readline()
8234+
# Deal with NotReadThenWritten
8235+
not_read_then_written = convert_tap_result2genlib_format(l)
8236+
l = f.readline()
8237+
# Deal with ReadThenWritten
8238+
read_then_written = convert_tap_result2genlib_format(l)
8239+
8240+
return read_not_written, not_read_then_written, read_then_written, var2idx_mapping
8241+
8242+
8243+
elif flat_mode:
81498244
# Flat mode with organized subdirectories: src/, test/, include/
81508245
src_dir = out_dir / 'src'
81518246
test_dir = out_dir / 'test'
@@ -8188,7 +8283,7 @@ def run_task(task):
81888283
mode_dirs['bv'].mkdir(parents=True, exist_ok=True)
81898284

81908285
# Update log path to be in the function subdirectory
8191-
func_log_path = func_out_dir / (src.stem + ".tapenade.log")
8286+
# func_log_path = func_out_dir / (src.stem + ".tapenade.log") # ISNT THIS COMPLETELY USELESS??
81928287

81938288
# Find dependency files
81948289
called_functions = parse_function_calls(src)
@@ -8240,11 +8335,14 @@ def run_task(task):
82408335
for dep_file in main_file_removed:
82418336
cmd.append(str(dep_file))
82428337
cmd.extend(list(args.extra))
8338+
if (args.debug_genlib):
8339+
cmd = cmd + ["-traceinout", src.stem]
82438340

82448341
try:
82458342
with open(mode_log_path, "w") as logf:
82468343
logf.write(f"Mode: FORWARD (scalar)\n")
82478344
# Format command for logging (properly quoted for shell copy-paste)
8345+
print("CMD:", cmd)
82488346
cmd_str = ' '.join(shlex.quote(str(arg)) for arg in cmd)
82498347
logf.write(f"Command: {cmd_str}\n")
82508348
logf.write(f"Function: {func_name}\n")
@@ -8285,6 +8383,42 @@ def run_task(task):
82858383
pass
82868384
print(f" ERROR: Exception during forward mode execution: {e}", file=sys.stderr)
82878385
return_codes["forward"] = 999
8386+
8387+
if (args.debug_genlib) : # Everything went well, and we are trying to generate the external lib
8388+
read_not_written, not_read_then_written, read_then_written, var2idx = parse_tap_trace4inout(mode_log_path)
8389+
print(var2idx)
8390+
param_2_tap_reordering = [var2idx[p.lower()]-3 for p in all_params]
8391+
print("Reordering is {}".format(param_2_tap_reordering))
8392+
with open("externalLib", "a") as f:
8393+
f.write(("function " if func_type == 'FUNCTION' else "subroutine ") + src.stem + ":\n")
8394+
indent = " "
8395+
f.write(indent + "external:\n")
8396+
shape = "(" + ", ".join(["param " + str(i) for i in range(1,len(all_params)+1)]) + ")" ## TODO: Need to add ', return' in case of a function,. dpeending on whether it is within the all params or not
8397+
f.write(indent + "shape: " + shape + "\n")
8398+
types = []
8399+
for p in all_params:
8400+
current_type = ""
8401+
if p in param_types['real_vars']:
8402+
current_type = "metavar float" # We should probably be more precise in order to handle mixed precision things
8403+
# Namely, adapt to
8404+
# modifiedType(modifiers(ident double), float() for double / REAL*8
8405+
# float() for single precision
8406+
elif p in param_types['complex_vars']:
8407+
current_type = "metavar complex"
8408+
# Similar to the real variables, we should be able to be more precise in terms of precision of the complex variable
8409+
elif p in param_types['integer_vars']:
8410+
current_type = "metavar integer"
8411+
elif p in param_types['char_vars']:
8412+
current_type = "character()"
8413+
if p in param_types['array_vars']: # Will be "is_matrix_var" or "is_array_var" or something along those lines
8414+
current_type = "arrayType(" + current_type + ", dimColons())"
8415+
8416+
types.append(current_type)
8417+
types = "(" + ", ".join(types) + ")"
8418+
f.write(indent + "type: " + types + "\n")
8419+
f.write(indent + "ReadNotWritten: (" + ", ".join([read_not_written[i] for i in param_2_tap_reordering]) + ")\n")
8420+
f.write(indent + "NotReadThenWritten: (" + ", ".join([not_read_then_written[i] for i in param_2_tap_reordering]) + ")\n")
8421+
f.write(indent + "ReadThenWritten: (" + ", ".join([read_then_written[i] for i in param_2_tap_reordering]) + ")\n")
82888422

82898423
# Run scalar reverse mode (b)
82908424
if run_b:
@@ -8383,6 +8517,7 @@ def run_task(task):
83838517
try:
83848518
with open(mode_log_path, "w") as logf:
83858519
logf.write(f"Mode: FORWARD VECTOR\n")
8520+
print("CMD:", cmd)
83868521
# Format command for logging (properly quoted for shell copy-paste)
83878522
cmd_str = ' '.join(shlex.quote(str(arg)) for arg in cmd)
83888523
logf.write(f"Command: {cmd_str}\n")
@@ -8449,6 +8584,7 @@ def run_task(task):
84498584
try:
84508585
with open(mode_log_path, "w") as logf:
84518586
logf.write(f"Mode: REVERSE VECTOR\n")
8587+
print("CMD:", cmd)
84528588
# Format command for logging (properly quoted for shell copy-paste)
84538589
cmd_str = ' '.join(shlex.quote(str(arg)) for arg in cmd)
84548590
logf.write(f"Command: {cmd_str}\n")
@@ -8679,6 +8815,48 @@ def run_task(task):
86798815
final_rc = max(return_codes.values()) if return_codes else 999
86808816
return (src, final_rc)
86818817

8818+
8819+
8820+
8821+
if (args.debug_genlib):
8822+
'''
8823+
WORKING HERE
8824+
XXXXXXXXXX
8825+
Need to figure out:
8826+
-> The various parameters of a subroutine, their types
8827+
-> Convert these types into Tapenade's GenLib format
8828+
-> Link the number of a parameter in the prototype of the subroutine with its rank in tapenade's inout analysis
8829+
-> Dump everything into a single genlib file
8830+
'''
8831+
# Add tests for existence of file / correct extension / ...
8832+
file_path = fortran_dir / args.debug_genlib
8833+
8834+
tasks = []
8835+
func_stem = file_path.stem.lower()
8836+
rel = file_path.relative_to(fortran_dir)
8837+
out_dir = out_root / rel.parent
8838+
out_dir.mkdir(parents=True, exist_ok=True)
8839+
log_path = out_dir / (file_path.stem + ".tapenade.log")
8840+
# Explicitely force the diff modes for now:
8841+
run_d, run_dv, run_b, run_bv = True, False, True, False
8842+
tasks.append((file_path, out_dir, log_path, run_d, run_dv, run_b, run_bv))
8843+
constraints = parse_parameter_constraints(file_path)
8844+
parsed_function = parse_fortran_function(file_path, suppress_warnings=False)
8845+
args_are = ["function_name", "inputs", "outputs", "inout_vars", "func_type", "params", "warnings", "param_types", "has_sufficient_docs"]
8846+
for idx, v in enumerate(parsed_function):
8847+
print(f"{args_are[idx]} == {v}")
8848+
8849+
run_task(tasks[0])
8850+
8851+
return
8852+
8853+
8854+
8855+
8856+
8857+
8858+
8859+
86828860
# Serial or parallel execution
86838861
results = []
86848862
if args.jobs <= 1:

0 commit comments

Comments
 (0)