diff --git a/shogi/CSA.py b/shogi/CSA.py index 7485c08..db1a4db 100644 --- a/shogi/CSA.py +++ b/shogi/CSA.py @@ -65,7 +65,27 @@ class Parser: def parse_file(path): with open(path) as f: return Parser.parse_str(f.read()) - + + @staticmethod + def parse_comment(comment, sfen): + board = shogi.Board(sfen) + # parse comment to get pv and value + m = re.match(r"'\*\*\s(\-?[\d]+)\s?(.*)", comment) + # m is pv and value + if m is not None: + value = m.groups()[0] + move_csa_list = m.groups()[1].split(" ") + move_str_list = [] + for move_csa in move_csa_list: + if move_csa == "": + continue + (color, move) = Parser.parse_move_str(move_csa, board) + board.push(shogi.Move.from_usi(move)) + move_str_list.append(move) + return int(value), move_str_list, None + else: + return None, None, comment + @staticmethod def parse_str(csa_str): line_no = 1 @@ -77,11 +97,31 @@ def parse_str(csa_str): current_turn_str = None moves = [] lose_color = None + move_start = False + values = [] + pvs = [] + comments = [] + temp_values = [] + temp_pvs = [] + temp_comments = [] + for line in csa_str.split('\n'): if line == '': pass elif line[0] == "'": - pass + if move_start: + try: + value, pv, comment = Parser.parse_comment(line, board.sfen()) + if value is not None: + temp_values.append(value) + if pv is not None: + temp_pvs.append(pv) + if comment is not None: + temp_comments.append(comment) + except Exception: + # skip the invalid comments + pass + elif line[0] == 'V': # Currently just ignoring version pass @@ -91,7 +131,7 @@ def parse_str(csa_str): # Currently just ignoring information pass elif line[0] == 'P': - position_lines.append(line) + position_lines.append(line) elif line[0] in COLOR_SYMBOLS: if len(line) == 1: current_turn_str = line[0] @@ -101,6 +141,13 @@ def parse_str(csa_str): (color, move) = Parser.parse_move_str(line, board) moves.append(move) board.push(shogi.Move.from_usi(move)) + move_start = True + pvs.append([v for v in temp_pvs]) + values.append([v for v in temp_values]) + comments.append([v for v in temp_comments]) + temp_pvs = [] + temp_values = [] + temp_comments = [] elif line[0] == 'T': # Currently just ignoring consumed time pass @@ -116,6 +163,9 @@ def parse_str(csa_str): lose_color = shogi.BLACK elif line == '%-ILLEGAL_ACTION': lose_color = shogi.WHITE + pvs.append([v for v in temp_pvs]) + values.append([v for v in temp_values]) + comments.append([v for v in temp_comments]) # TODO: Support %MATTA etc. break @@ -144,7 +194,10 @@ def parse_str(csa_str): 'names': names, 'sfen': sfen, 'moves': moves, - 'win': win + 'win': win, + 'values' : values[1:], + 'pvs' : pvs[1:], + 'comments' : comments[1:], } # NOTE: for future support of multiple matches return [summary] diff --git a/tests/board_test.py b/tests/board_test.py index c4d5f93..5ed2a13 100644 --- a/tests/board_test.py +++ b/tests/board_test.py @@ -205,7 +205,6 @@ def test_issue_17(self): def test_usi_command(self): board = shogi.Board() - board.push_usi_position_cmd("position startpos moves 7g7f") self.assertEqual(board.sfen(), 'lnsgkgsnl/1r5b1/ppppppppp/9/9/2P6/PP1PPPPPP/1B5R1/LNSGKGSNL w - 2') board.push_usi_position_cmd("position sfen ln1g3+Rl/1ks4s1/pp1gppbpp/2p3N2/9/5P1P1/PPPP1S1bP/2K1R1G2/LNSG3NL w 4p 42") diff --git a/tests/csa_test.py b/tests/csa_test.py index 47ee918..2f0970d 100644 --- a/tests/csa_test.py +++ b/tests/csa_test.py @@ -56,15 +56,18 @@ + '指し手と消費時間(optional) +2726FU +'** 22 -8384FU T12 -3334FU T6 +'** 0 +2625FU -8384FU +6978KI -8485FU +3938GI -7172GI +9796FU +7776FU +'using csa format is a kind of torment! %TORYO '--------------------------------------------------------- """ -TEST_CSA_SUMMARY = {'moves': ['2g2f', '3c3d', '7g7f'], 'sfen': 'lnsgkgsnl/1r5b1/ppppppppp/9/9/9/PPPPPPPPP/1B5R1/LNSGKGSNL b - 1', 'names': ['NAKAHARA', 'YONENAGA'], 'win': 'b'} +TEST_CSA_SUMMARY = {'moves': ['2g2f', '3c3d', '7g7f'], 'sfen': 'lnsgkgsnl/1r5b1/ppppppppp/9/9/9/PPPPPPPPP/1B5R1/LNSGKGSNL b - 1', 'names': ['NAKAHARA', 'YONENAGA'], 'win': 'b', 'values': [[22], [0], []], 'comments' : ['', '', ["'using csa format is a kind of torment!"]], 'pvs': [[['8c8d']], [['2f2e', '8c8d', '6i7h', '8d8e', '3i3h', '7a7b', '9g9f']], []], 'comments': [[], [], ["'using csa format is a kind of torment!"]]} TEST_CSA_WITH_PI = ''' V2.2 @@ -85,15 +88,19 @@ 'moves': ['7g7f', '8c8d'], 'sfen': 'lnsgkgsnl/9/ppppppppp/9/9/9/PPPPPPPPP/1B5R1/LNSGKGSNL b - 1', 'names': ['先手', '後手'], - 'win': 'w' + 'win': 'w', + 'comments' : [[],[]], + 'pvs' : [[], []], + 'values' : [[], []], } class ParserTest(unittest.TestCase): - def parse_str_test(self): + def test_parse_str_test(self): result = CSA.Parser.parse_str(TEST_CSA) + print(result[0]) self.assertEqual(result[0], TEST_CSA_SUMMARY) - def parse_str_test_with_PI(self): + def test_parse_str_test_with_PI(self): result = CSA.Parser.parse_str(TEST_CSA_WITH_PI) self.assertEqual(result[0], TEST_CSA_SUMMARY_WITH_PI)