# pakkau/pakkau.py~ — backup copy of pakkau.py (~302 lines, 10 KiB, Python).
# NOTE: this copy was captured through a web code viewer; the viewer's chrome
# ("Raw Permalink Blame History") and its ambiguous-Unicode warning referred
# to the full-width CJK punctuation used in the code below.
import re
import pandas as pd
import math
from functools import reduce
import argparse
import os
import sqlite3
from itertools import chain
model_filename = "model.db"
def genmod(corpus_path="./corpus/", db_path=None):
    """Build the hanji-conversion model database from the CSV corpus.

    Reads every ``*.csv`` file in *corpus_path* (two columns: hanji,
    lomaji), aligns each Han character with its romanized syllable, and
    writes three frequency tables into a fresh SQLite database:

    - ``pronounce(hanji, lomaji, freq)`` — reading counts per character,
    - ``initial(char, freq)`` — phrase-initial character counts,
    - ``transition(prev_char, next_char, freq)`` — character bigram counts.

    Args:
        corpus_path: directory holding the corpus CSVs; must end with a
            path separator because file names are appended directly.
        db_path: target SQLite file; defaults to the module-level
            ``model_filename``.  Any existing file is removed first.

    Raises:
        ValueError: when a row's hanji and romanization do not align
            one-to-one after punctuation is stripped.
    """
    if db_path is None:
        db_path = model_filename
    frames = []
    for file in os.listdir(corpus_path):
        if file.endswith(".csv"):
            frames.append(pd.read_csv(corpus_path + file, header=0,
                                      names=['hanji', 'lomaji']))
    df = pd.concat(frames)
    df['lomaji'] = df['lomaji'].str.lower()
    new_data = []
    for _, row in df.iterrows():
        # Drop CJK punctuation from the hanji side and split/strip
        # separators on the romanized side so both sequences align one
        # syllable per character.
        hanji = [c for c in row['hanji']
                 if re.match("[^、();:,。!?「」『』]", c)]
        tl = re.split(r'(:?[!?;,.\"\'\(\):]|[-]+|\s+)', row['lomaji'])
        tl2 = [t for t in tl
               if re.match(r"([^\(\)^!:?; \'\",.\-\u3000])", t)]
        if len(hanji) != len(tl2):
            raise ValueError(
                f"length of hanji {hanji} is different from romaji {tl2}.")
        new_data.append((hanji, tl2))
    try:
        os.remove(db_path)
    except OSError:
        pass  # first run: no previous model to delete
    con = sqlite3.connect(db_path)
    try:
        cur = con.cursor()
        # Reading counts: how often each character carries each reading.
        cur.execute("CREATE TABLE pronounce(hanji, lomaji, freq)")
        char_to_pronounce = {}
        for hanji, lomaji in new_data:
            for h, l in zip(hanji, lomaji):
                readings = char_to_pronounce.setdefault(h, {})
                readings[l] = readings.get(l, 0) + 1
        cur.executemany(
            "INSERT INTO pronounce VALUES(?, ?, ?)",
            [(h, l, f)
             for h, readings in char_to_pronounce.items()
             for l, f in readings.items()])
        # Counts of characters opening a phrase/sentence.
        cur.execute("CREATE TABLE initial(char, freq)")
        init_freq = {}
        for hanji, _ in new_data:
            head = hanji[0]
            init_freq[head] = init_freq.get(head, 0) + 1
        # Smoothing: give every known character a small non-zero initial
        # weight so unseen heads are never impossible.  (The original code
        # defined min_weight but used the literal 0.1 — fixed here.)
        min_weight = 0.1
        for c in char_to_pronounce:
            if c not in init_freq:
                init_freq[c] = min_weight
        cur.executemany("INSERT INTO initial VALUES(?, ?)",
                        init_freq.items())
        # Character bigram counts within each aligned phrase.
        cur.execute("CREATE TABLE transition(prev_char, next_char, freq)")
        char_transition = {}
        for hanji, _ in new_data:
            for prev, nxt in zip(hanji, hanji[1:]):
                nexts = char_transition.setdefault(prev, {})
                nexts[nxt] = nexts.get(nxt, 0) + 1
        cur.executemany(
            "INSERT INTO transition VALUES(?, ?, ?)",
            [(p, n, f)
             for p, nexts in char_transition.items()
             for n, f in nexts.items()])
        con.commit()
    finally:
        # Always release the connection, even if an insert fails.
        con.close()
def get_homophones(pron, cur, con):
    """Return every hanji recorded in the model with reading *pron*.

    Args:
        pron: a lower-cased romanized syllable.
        cur: object exposing ``execute`` on the model database.
        con: unused; kept so existing call sites keep working.
            NOTE(review): convert_one_sentence calls this with (con, cur)
            swapped; that only works because sqlite3.Connection also
            exposes execute().

    Returns:
        List of hanji strings, possibly empty, in table order.
    """
    rows = cur.execute(
        "select hanji FROM pronounce where lomaji = ?", (pron, )).fetchall()
    return [row[0] for row in rows]
def convert(sentences):
    """Convert romanized text to hanji, print it, and return the string.

    The input is cut at ASCII punctuation (the punctuation marks are kept
    as their own pieces), each piece is converted by convert_one_sentence,
    and the results are glued back together.
    """
    pieces = re.split(r'(:?[!?;,.\"\'\(\):])', sentences)
    pieces = [p for p in pieces if p != '']
    converted = [convert_one_sentence(p) for p in pieces]
    # Each per-piece result is a list whose items are strings or
    # single-item lists, so flatten two levels before joining.
    result_string = "".join(chain.from_iterable(chain.from_iterable(converted)))
    print(result_string)
    return result_string
def convert_one_sentence(sentence):
    """Convert one romanized sentence (or a lone punctuation mark) to hanji.

    Runs a Viterbi-style search over the homophone lattice of the input
    syllables: sentence-initial character frequencies score the first
    position, bigram transition frequencies score each following position,
    and the best path is recovered by backtracking over stored prev_char
    links.  Syllables absent from the model are passed through verbatim.

    Returns a list whose entries are hanji strings (or, for a lone ASCII
    punctuation input, a single-item list holding its full-width form).
    """
    # NOTE(review): full_width is meant to hold the full-width CJK
    # counterparts of half_width, index for index; the characters may
    # render as empty/ambiguous glyphs in some viewers — verify they
    # survived copying.
    full_width = ["", "", "","","","", "", ""]
    half_width = ["!", "?", ";", ":", ",", ".", "(", ")"]
    # A single-character input that is ASCII punctuation maps straight to
    # its full-width counterpart.
    if len(sentence) == 1:
        for i in range(len(half_width)):
            if sentence[0] == half_width[i]:
                return [[full_width[i]]]
    weight = 2/3  # interpolation weight: bigram evidence vs unigram backoff
    # Split on syllable separators (single/double hyphen, whitespace) and
    # drop the separator tokens themselves.
    splitted = re.split(r'(--?|\s+)', sentence)
    filtered = list(filter(lambda x :not re.match(r'(--?|\s+)', x), splitted))
    small_capized = list(map(lambda x : x.lower(), filtered))
    print("======", small_capized)
    con = sqlite3.connect(model_filename)
    cur = con.cursor()
    # NOTE(review): arguments are swapped w.r.t. the get_homophones
    # signature (pron, cur, con); it works only because
    # sqlite3.Connection also exposes execute().
    homophones_sequence_raw = list(map(lambda x : get_homophones(x, con, cur), small_capized))
    # Lattice: one node dict per candidate hanji per syllable position.
    homophones_sequence = [list(map (lambda x : {"char": x,
                                                 "prev_char": None,
                                                 "prob" : 1}, i)) for i in homophones_sequence_raw]
    # Sentence-initial frequencies of every candidate for the first syllable.
    head_freqs = list(map(lambda x : x[0], cur.execute('''select initial.freq FROM initial
        INNER JOIN pronounce ON pronounce.hanji = initial.char
        WHERE pronounce.lomaji = ?''', (small_capized[0], )).fetchall()))
    return_result = [None] * len(small_capized)
    if head_freqs == []:
        # Unknown first syllable: pass it through verbatim with prob 1.
        return_result[0] = filtered[0]
        homophones_sequence[0] = [{"char": filtered[0],
                                   "prev_char": None,
                                   "prob" : 1}]
    else:
        # Normalize the initial frequencies into probabilities.
        head_freq_total = reduce(lambda x , y : x + y, head_freqs)
        for i in homophones_sequence[0]:
            # NOTE(review): (i['char']) is not a tuple — the comma is
            # missing.  It only works because a one-character string is
            # itself a length-1 parameter sequence.
            i_freq = cur.execute('''select initial.freq FROM initial
                WHERE initial.char = ?''', (i['char'])).fetchall()[0][0]
            i['prob'] = i_freq / head_freq_total
            print(i)
    print("+++++", return_result)
    if len(small_capized) == 1:
        # Single syllable: just pick the most probable initial candidate.
        max_prob = -math.inf
        max_prob_char = None
        for i in homophones_sequence[0]:
            if i['prob'] > max_prob:
                max_prob_char = i['char']
                max_prob = i['prob']
        return_result[0] = max_prob_char
    else:
        # Forward pass: score every candidate at position i against every
        # candidate at position i-1.
        for i in range(1,len(small_capized)):
            char_freqs = list(map(lambda x : x[0], cur.execute('''select initial.freq FROM initial
                INNER JOIN pronounce ON pronounce.hanji = initial.char
                WHERE pronounce.lomaji = ?''', (small_capized[i], )).fetchall()))
            if char_freqs == []:
                # Unknown syllable: pass it through and inherit the best
                # probability / predecessor from the previous position.
                return_result[i] = filtered[i]
                homophones_sequence[i] = [{"char": filtered[i],
                                           "prev_char": None,
                                           "prob" : 1}]
                prev_char = ""
                max_prob = -math.inf
                for m in homophones_sequence[i-1]:
                    if m['prob'] > max_prob:
                        max_prob = m['prob']
                        prev_char = m['char']
                homophones_sequence[i][0]['prob'] = max_prob
                homophones_sequence[i][0]['prev_char'] = prev_char
            else:
                # Total bigram mass between the previous and current
                # homophone sets (denominator of the transition prob).
                total_transition_freq = cur.execute('''
                    SELECT sum(t.freq)
                    FROM transition as t
                    INNER JOIN pronounce as p1 ON p1.hanji = t.prev_char
                    INNER JOIN pronounce as p2 ON p2.hanji = t.next_char
                    where p2.lomaji = ? and p1.lomaji = ?''',
                    (small_capized[i], small_capized[i-1])).fetchall()[0][0]
                for j in homophones_sequence[i]:
                    prev_char = None
                    max_prob = -math.inf
                    for k in homophones_sequence[i-1]:
                        k_to_j_freq_raw = cur.execute('''select freq from transition
                            where prev_char = ? and next_char = ? ''', (k["char"], j["char"])).fetchall()
                        if k_to_j_freq_raw == []:
                            # Unseen bigram: back off to j's unigram share
                            # among its homophones, scaled by (1 - weight).
                            den = cur.execute('''
                                SELECT sum(p.freq)
                                FROM pronounce as p
                                inner join pronounce as p2
                                on p.hanji = p2.hanji where p2.lomaji = ?''', (small_capized[i],)).fetchall()[0][0]  # denominator
                            # numerator
                            num = cur.execute(''' SELECT sum(freq) FROM pronounce as p where hanji = ?''', (j["char"],)).fetchall()[0][0]
                            print("+++", num, den)
                            k_to_j_freq = num/den * (1-weight)
                        else:
                            num = k_to_j_freq_raw[0][0]
                            don = total_transition_freq
                            k_to_j_freq =num/don * weight
                        print("k_to_j_fr", k["char"], j["char"], k_to_j_freq)
                        # Viterbi update: keep the best predecessor for j.
                        if k_to_j_freq * k["prob"] > max_prob:
                            max_prob = k_to_j_freq * k["prob"]
                            prev_char = k["char"]
                    print("~-~_~-~-~-~-", prev_char, j["char"], max_prob)
                    j["prob"] = max_prob
                    j["prev_char"] = prev_char
        # Backward pass: pick the best final candidate, then backtrack
        # through the stored prev_char links to fill return_result.
        max_prob = -math.inf
        current = ""
        prev_char = ""
        for i in homophones_sequence[len(homophones_sequence)-1]:
            if i["prob"] > max_prob:
                max_prob = i["prob"]
                current = i["char"]
                prev_char = i["prev_char"]
        print("~tail~~", current)
        print(homophones_sequence)
        return_result[len(homophones_sequence)-1] = current
        for i in range(len(homophones_sequence)-2, -1, -1):
            current_ls = list(filter(lambda x : x["char"] == prev_char,
                                     homophones_sequence[i]))
            print(prev_char)
            return_result[i] = prev_char
            current = current_ls[0]["char"]
            prev_char = current_ls[0]["prev_char"]
    print(return_result)
    return return_result
def poj_to_tl(sentence):
    """Convert a POJ-orthography sentence to Tâi-lô.

    TODO: currently a stub — the input is returned unchanged, so POJ
    input is only handled where its spelling coincides with Tâi-lô.
    """
    return sentence
def _main():
    """Parse the command line and dispatch.

    ``--genmod`` rebuilds the model database; otherwise the positional
    SENTENCE is converted (after a POJ-to-TL pass when ``--form poj``,
    the default).  With no arguments the help text is printed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--genmod', help='generate the model',
                        action='store_true', required=False)
    parser.add_argument('sentence', metavar='SENTENCE', nargs='?',
                        help='the sentence to be converted')
    parser.add_argument('--form', metavar='FORM', choices=["poj", "tl"],
                        nargs=1, default=['poj'],
                        help='the orthography to be used (poj or tl). Default is poj.')
    args = parser.parse_args()
    if args.genmod:  # store_true yields a bool; `== True` was redundant
        genmod()
    elif args.sentence is not None:  # identity check per PEP 8, not `!= None`
        if args.form == ['poj']:
            sentence = poj_to_tl(args.sentence)
            convert(sentence)
        else:
            convert(args.sentence)
    else:
        parser.print_help()


# Guard so importing this module no longer parses sys.argv as a side
# effect; running it as a script behaves exactly as before.
if __name__ == "__main__":
    _main()