mpmath_special_functions_test_generator.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. #!/usr/bin/env python3
  2. import mpmath as mp
  3. import mpmath_riccati_bessel as mrb
  4. import mpmath_input_arguments as mia
  5. import os.path
  6. class TestData:
  7. def __init__(self, list_to_parse, filetype):
  8. self.filetype = filetype
  9. if self.filetype == 'c++':
  10. self.cpp_parse(list_to_parse)
  11. else:
  12. raise NotImplementedError("Only C++ files *.hpp parsing was implemented")
  13. def cpp_parse(self, list_to_parse):
  14. self.comment = list_to_parse[0]
  15. print(self.comment)
  16. if self.comment[:2] != '//': raise ValueError('Not a comment')
  17. self.typeline = list_to_parse[1]
  18. if 'std::vector' not in self.typeline: raise ValueError('Unexpected C++ container')
  19. self.testname = list_to_parse[2]
  20. self.opening = list_to_parse[3]
  21. if self.opening != '= {': raise ValueError('For C++ we expect opeing with = {');
  22. self.ending = list_to_parse[-1]
  23. if self.ending != '};': raise ValueError('For C++ we expect closing };')
  24. self.evaluated_data = list_to_parse[4:-1]
  25. def get_string(self):
  26. out_sting = self.comment + '\n' + self.typeline + '\n' + self.testname + '\n' + self.opening + '\n'
  27. for result in self.evaluated_data:
  28. out_sting += result + '\n'
  29. out_sting += self.ending + '\n'
  30. return out_sting
  31. class UpdateSpecialFunctionsEvaluations:
  32. def __init__(self, filename='default_out.hpp', complex_arguments=[],
  33. output_dps=16, max_num_elements_of_nlist=51):
  34. self.evaluated_data = []
  35. self.test_setup = []
  36. self.filename = filename
  37. self.read_evaluated_data()
  38. self.complex_arguments = complex_arguments
  39. self.output_dps = output_dps
  40. self.max_num_elements_of_nlist = max_num_elements_of_nlist
  41. def read_evaluated_data(self):
  42. self.filetype = 'undefined'
  43. if self.filename.endswith('.hpp'):
  44. self.filetype = 'c++'
  45. if self.filename.endswith('.f90'):
  46. self.filetype = 'fortran'
  47. if not os.path.exists(self.filename):
  48. print("WARNING! Found no data file:", self.filename)
  49. return
  50. with open(self.filename, 'r') as in_file:
  51. content = in_file.readlines()
  52. content = [x.strip() for x in content]
  53. while '' in content:
  54. record_end_index = content.index('')
  55. new_record = content[:record_end_index]
  56. content = content[record_end_index + 1:]
  57. self.add_record(new_record)
  58. self.add_record(content)
  59. def add_record(self, new_record):
  60. if len(new_record) == 0: return
  61. if len(new_record) < 6: raise ValueError('Not enough lines in record:', new_record)
  62. self.evaluated_data.append(TestData(new_record, self.filetype))
  63. def get_file_content(self):
  64. self.evaluated_data.sort(key=lambda x: x.testname)#, reverse=True)
  65. out_string = ''
  66. for record in self.evaluated_data:
  67. out_string += record.get_string() + '\n'
  68. return out_string[:-1]
  69. def remove(self, testname):
  70. for i, result in enumerate(self.evaluated_data):
  71. if result.testname == testname:
  72. del self.evaluated_data[i]
  73. def get_n_list(self, z, max_number_of_elements=10):
  74. nmax = mrb.LeRu_cutoff(z)
  75. factor = nmax ** (1 / (max_number_of_elements - 2))
  76. n_list = [int(factor ** i) for i in range(max_number_of_elements - 1)]
  77. n_list.append(0)
  78. n_set = set(n_list)
  79. return sorted(n_set)
  80. def get_test_data_nlist(self, z_record, output_dps, n, func):
  81. x = str(z_record[0])
  82. mr = str(z_record[1][0])
  83. mi = str(z_record[1][1])
  84. z_str = ''
  85. try:
  86. z = mp.mpf(x) * mp.mpc(mr, mi)
  87. D1nz = func(n, z)
  88. z_str = ('{{' +
  89. mp.nstr(z.real, output_dps * 2) + ',' +
  90. mp.nstr(z.imag, output_dps * 2) + '},' +
  91. str(n) + ',{' +
  92. mp.nstr(D1nz.real, output_dps) + ',' +
  93. mp.nstr(D1nz.imag, output_dps) + '},' +
  94. mp.nstr(mp.fabs(D1nz.real * 10 ** -output_dps), 2) + ',' +
  95. mp.nstr(mp.fabs(D1nz.imag * 10 ** -output_dps), 2) +
  96. '},')
  97. except:
  98. pass
  99. return z_str
  100. def get_test_data(self, Du_test, output_dps, max_num_elements_of_n_list, func, funcname):
  101. output_list = ['// complex(z), n, complex(D1(n,z)), abs_err_real, abs_err_imag',
  102. 'std::vector< std::tuple< std::complex<double>, int, std::complex<double>, double, double > >',
  103. str(funcname)+'_test_' + str(output_dps) + 'digits','= {']
  104. for z_record in Du_test:
  105. x = str(z_record[0])
  106. mr = str(z_record[1][0])
  107. mi = str(z_record[1][1])
  108. mp.mp.dps = 20
  109. z = mp.mpf(x) * mp.mpc(mr, mi)
  110. n_list = self.get_n_list(z, max_num_elements_of_n_list)
  111. print(z, n_list)
  112. for n in n_list:
  113. mp.mp.dps = 20
  114. old_z_string = self.get_test_data_nlist(z_record, output_dps, n, func)
  115. mp.mp.dps = 37
  116. new_z_string = self.get_test_data_nlist(z_record, output_dps, n, func)
  117. while old_z_string != new_z_string:
  118. new_dps = int(mp.mp.dps * 1.41)
  119. if new_dps > 300: break
  120. mp.mp.dps = new_dps
  121. print("New dps = ", mp.mp.dps, 'n =', n, ' (max ',n_list[-1],') for z =', z, ' ', end='')
  122. old_z_string = new_z_string
  123. new_z_string = self.get_test_data_nlist(z_record, output_dps, n, func)
  124. if new_z_string != '':
  125. output_list.append(new_z_string)
  126. else:
  127. break
  128. output_list.append('};')
  129. return output_list
  130. def run_test(self, func, funcname):
  131. out_list_result = self.get_test_data(mia.complex_arguments, self.output_dps,
  132. self.max_num_elements_of_nlist,
  133. func, funcname)
  134. testname = str(funcname)+'_test_' + str(self.output_dps) + 'digits'
  135. self.remove(testname)
  136. self.add_record(out_list_result)
  137. def main():
  138. sf_evals = UpdateSpecialFunctionsEvaluations(filename='test_spec_functions_data.hpp',
  139. complex_arguments=mia.complex_arguments,
  140. output_dps=16, max_num_elements_of_nlist=51)
  141. # output_dps=3, max_num_elements_of_nlist=3)
  142. # sf_evals.run_test(mrb.D1, 'D1')
  143. with open(sf_evals.filename, 'w') as out_file:
  144. out_file.write(sf_evals.get_file_content())
  145. main()