-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathequation_generator.py
More file actions
101 lines (82 loc) · 4.03 KB
/
equation_generator.py
File metadata and controls
101 lines (82 loc) · 4.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
"""
Script to generate equations for math-eval dataset.
Usage:
python equation_generator.py --output_file <output_file> [--num <num_problems>] [--vars <num_variables>] [--const_max <max_constant>]
"""
import random
import numpy as np
import argparse
def _exact_determinant(rows):
    """Return the exact determinant of a square matrix of integers.

    Uses fraction-free (Bareiss) elimination on Python ints, so the result
    is exact. A floating-point determinant can come out as a tiny non-zero
    number for a singular integer matrix, which is not a reliable test.
    """
    work = [[int(value) for value in row] for row in rows]
    size = len(work)
    sign = 1
    prev_pivot = 1
    for col in range(size - 1):
        if work[col][col] == 0:
            # Bring up a lower row with a non-zero entry in this column;
            # if there is none the whole column is zero and det == 0.
            swap = next(
                (r for r in range(col + 1, size) if work[r][col] != 0), None
            )
            if swap is None:
                return 0
            work[col], work[swap] = work[swap], work[col]
            sign = -sign
        pivot = work[col][col]
        for r in range(col + 1, size):
            for c in range(col + 1, size):
                # The division by the previous pivot is always exact in
                # Bareiss elimination, so integer floor division is safe.
                work[r][c] = (
                    work[r][c] * pivot - work[r][col] * work[col][c]
                ) // prev_pivot
        prev_pivot = pivot
    return sign * work[-1][-1]


def generate_invertible_matrix(k, coef_min=1, coef_max=10):
    """
    Generate a random k x k matrix with integer entries in the range
    [coef_min, coef_max] that is invertible (i.e., has a non-zero determinant).

    Args:
        k: Matrix dimension; must be at least 1.
        coef_min: Smallest allowed entry (inclusive).
        coef_max: Largest allowed entry (inclusive).

    Returns:
        A (k, k) numpy array of dtype int whose exact determinant is non-zero.

    Raises:
        ValueError: If k < 1, if coef_min > coef_max, or if the range cannot
            produce any invertible matrix (a single allowed value with k > 1,
            or a single allowed value of 0), which would otherwise make the
            sampling loop run forever.
    """
    if k < 1:
        raise ValueError(f"k must be at least 1, got {k}")
    if coef_min > coef_max:
        raise ValueError(
            f"coef_min ({coef_min}) must not exceed coef_max ({coef_max})"
        )
    if coef_min == coef_max and (k > 1 or coef_min == 0):
        # Every candidate would be a constant matrix, which is singular.
        raise ValueError(
            "no invertible matrix exists with all entries equal to "
            f"{coef_min} for k={k}"
        )
    while True:
        # Entries are drawn row by row with random.randint, both ends inclusive.
        rows = [
            [random.randint(coef_min, coef_max) for _ in range(k)]
            for _ in range(k)
        ]
        # Exact integer test instead of comparing a float determinant to 0.
        if _exact_determinant(rows) != 0:
            return np.array(rows, dtype=int)
def format_equation_row(coeffs, var_names, constant):
    """
    Render one linear equation as text.

    Each coefficient is paired with the variable name at the same position
    and written as "<coeff> <var>"; the coefficient is always printed, even
    when it is 1. Terms are joined with ' + ', so the result reads naturally
    only when the coefficients are positive.

    Example: coeffs = [2, 3], var_names = ['a', 'b'], constant = 26
    gives "2 a + 3 b = 26".
    """
    # Left-hand side: one "<coeff> <var>" term per (coefficient, name) pair.
    left_side = " + ".join(
        f"{value} {name}" for value, name in zip(coeffs, var_names)
    )
    # Attach the right-hand constant after the equals sign.
    return " = ".join((left_side, str(constant)))
def generate_problem(k, sol_min=1, sol_max=10, coef_min=1, coef_max=10, constant_max=100):
    """
    Generate one system of k linear equations in k variables together with
    its integer solution.

    A random solution vector and a random invertible coefficient matrix are
    drawn repeatedly until every right-hand constant (matrix times solution)
    is at most constant_max.

    Args:
        k: Number of variables/equations; 1 to 26 (variables are named
            'a', 'b', 'c', ... in order).
        sol_min, sol_max: Inclusive range for each solution value.
        coef_min, coef_max: Inclusive range for each coefficient.
        constant_max: Largest allowed right-hand-side constant.

    Returns:
        A string in the format:
            eq1 , eq2 , ... <sep> a = sol_a , b = sol_b , ...

    Raises:
        ValueError: If k is outside 1..26, or if constant_max is below the
            smallest constant the given non-negative ranges can produce
            (the retry loop could then never finish).

    Note: If constant_max is only slightly above the minimum, the generation
    process may still take many iterations.
    """
    alphabet = "abcdefghijklmnopqrstuvwxyz"
    if not 1 <= k <= len(alphabet):
        # With fewer names than columns, zip() would silently drop terms.
        raise ValueError(f"k must be between 1 and {len(alphabet)}, got {k}")
    if sol_min >= 0 and coef_min >= 0 and k * sol_min * coef_min > constant_max:
        # Each constant is a sum of k products, each >= coef_min * sol_min,
        # so no sample can ever satisfy the bound.
        raise ValueError(
            f"constant_max={constant_max} is below the smallest possible "
            f"constant {k * sol_min * coef_min} for these ranges"
        )

    # Variable names, e.g. ['a', 'b', 'c'] for k=3.
    var_names = list(alphabet[:k])

    # Try generating a valid system until the constants are within the bound.
    while True:
        # Generate a random solution vector.
        solution = [random.randint(sol_min, sol_max) for _ in range(k)]
        # Generate an invertible coefficient matrix.
        A = generate_invertible_matrix(k, coef_min, coef_max)
        # Compute the constants: A * solution.
        constants = A.dot(solution)
        # Accept only if the largest constant is within the allowed bound.
        if np.max(constants) <= constant_max:
            break

    # Format each equation from its coefficient row and constant.
    equations = [
        format_equation_row(row, var_names, const)
        for row, const in zip(A, constants)
    ]
    # Join equations and format the solution string.
    equations_str = " , ".join(equations)
    solution_str = " , ".join(
        f"{var} = {val}" for var, val in zip(var_names, solution)
    )
    return equations_str + " <sep> " + solution_str
def generate_problems(num_problems, k, output_file, constant_max=100):
    """
    Write num_problems generated systems to output_file, one problem per line.

    The file is opened in write mode, so any existing content is replaced.
    Each problem is produced by generate_problem(k, constant_max=constant_max)
    and followed by a newline.
    """
    with open(output_file, "w") as out:
        # Problems are generated lazily, one at a time, as they are written.
        out.writelines(
            generate_problem(k, constant_max=constant_max) + "\n"
            for _ in range(num_problems)
        )
def _build_arg_parser():
    """Create the command-line parser for the equation generator script."""
    cli = argparse.ArgumentParser(description="Generate equations for math-eval dataset.")
    cli.add_argument('--output_file', type=str, required=True, help='Output file for generated equations')
    cli.add_argument(
        "--num",
        type=int,
        default=10,
        help="Number of problems to generate (default: 10)",
    )
    cli.add_argument(
        "--vars",
        type=int,
        default=2,
        help="Number of variables in each system (default: 2; common choices are 2 or 3)",
    )
    cli.add_argument(
        "--const_max",
        type=int,
        default=100,
        help="Maximum allowed constant in the equation (default: 100)",
    )
    return cli


def main():
    """
    Command-line entry point.

    Parses --output_file (required), --num, --vars and --const_max, then
    generates the requested problems and writes them to the output file.
    """
    options = _build_arg_parser().parse_args()
    # Generate problems and write to the specified output file.
    generate_problems(
        options.num,
        options.vars,
        options.output_file,
        constant_max=options.const_max,
    )


# Run the CLI only when this file is executed as a script.
if __name__ == "__main__":
    main()