import os
from mpi4py import MPI
# Handle on the communicator that spans every process of this MPI job.
comm = MPI.COMM_WORLD
# Number of processes in the communicator and this process's index in it.
size, rank = comm.size, comm.rank
# Name of the processor (host) this process is running on.
name = MPI.Get_processor_name()
 
import csv
# Load the text pairs from test.csv: column 0 is collected into `source`,
# column 1 into `target`.  The first row is treated as a header and skipped.
source = []
target = []
# newline="" hands newline handling over to the csv module, which the csv
# documentation requires so that quoted fields containing line breaks are
# parsed correctly.
with open("test.csv", "r", encoding="utf-8", newline="") as f:
    reader = csv.reader(f, delimiter=",")
    next(reader, None)  # drop the header row; an empty file is tolerated
    for row in reader:
        if not row:
            # csv.reader yields an empty list for a blank line; there is
            # nothing to index in it, so skip it instead of crashing.
            continue
        source.append(row[0])
        target.append(row[1])

# Pair each source text with its target text: [(src, tgt), ...].
data_lst = list(zip(source, target))
 
# http://muscle199x.blog.fc2.com/blog-entry-70.html?sp
def split_list(lst, n):
    """Cut *lst* into *n* consecutive slices of near-equal length.

    Every slice gets ``len(lst) // n`` items, and the first
    ``len(lst) % n`` slices receive one extra item each, so slice lengths
    differ by at most one.  Slices may be empty when *n* exceeds the number
    of items.

    Args:
        lst: Sliceable sequence to partition.
        n: Number of slices to produce.

    Returns:
        A list of *n* slices of *lst*, in original order.
    """
    base, extra = divmod(len(lst), n)
    chunks = []
    start = 0
    for idx in range(n):
        # The leading `extra` chunks absorb the remainder, one item apiece.
        stop = start + base + (1 if idx < extra else 0)
        chunks.append(lst[start:stop])
        start = stop
    return chunks
 
from transformers import MT5ForConditionalGeneration, MT5Model, MT5EncoderModel, T5Tokenizer
 
# Every rank loads its own tokenizer and model from the same checkpoint;
# none of this is guarded by a rank check.
_MT5_CHECKPOINT = "google/mt5-base"
tokenizer = T5Tokenizer.from_pretrained(_MT5_CHECKPOINT)
model = MT5ForConditionalGeneration.from_pretrained(_MT5_CHECKPOINT)

# Only the root rank partitions the data set (one chunk per rank); the other
# ranks hold None here and receive their chunk through the scatter call.
data = split_list(data_lst, size) if rank == 0 else None
 
def calc(my_data):
    """Generate a summary for every (source, target) pair in *my_data*.

    Each source text is prefixed with ``"summarize: "``, tokenized with the
    module-level ``tokenizer`` (truncated to 512 tokens) and passed to the
    module-level ``model`` for beam-search generation.

    Args:
        my_data: Iterable of ``(src, tgt)`` pairs.

    Returns:
        A list of ``(src, tgt, summary)`` tuples, in input order.  An empty
        input yields an empty list.
    """
    res = []

    for src, tgt in my_data:
        inputs = tokenizer.encode(
            "summarize: " + src,
            return_tensors='pt',
            max_length=512,
            truncation=True,
        )
        summary_ids = model.generate(
            inputs,
            max_length=50,
            min_length=10,
            length_penalty=5.,
            num_beams=2,
        )
        # Only the first returned sequence is used.  skip_special_tokens
        # keeps the tokenizer's special-token markers (padding,
        # end-of-sequence, ...) out of the text that gets reported.
        summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        res.append((src, tgt, summary))
    return res
 
# Hand each rank its own chunk of the work list prepared on rank 0.
my_data = comm.scatter(data, root=0)

# Report the chunk sizes one rank at a time: on each pass only the rank whose
# turn it is prints, and every rank waits at the barrier before the next pass.
for turn in range(comm.Get_size()):
    if turn == rank:
        print("scatter.[%d] %d" % (rank, len(my_data)))
    comm.Barrier()

# Run summarisation on this rank's chunk.
my_res = calc(my_data)
def flatten(l):
    """Concatenate the items of every sub-iterable of *l* into one list.

    Args:
        l: Iterable whose elements are themselves iterable.

    Returns:
        A flat list of the nested items, in order.  If flattening fails
        (for example because an element is not iterable), the traceback and
        the offending input are printed and ``None`` is returned.
    """
    try:
        return [item for sublist in l for item in sublist]
    except Exception:
        # Catch Exception rather than using a bare except, so that
        # KeyboardInterrupt and SystemExit are not swallowed here.
        import traceback
        traceback.print_exc()
        print(l)
        return None
 
# Collect every rank's result list on rank 0 (other ranks receive None).
res = comm.gather(my_res, root=0)

if rank == 0:
    # gather delivers one list per rank; merge them into a single list.
    res = flatten(res)
    for record in res:
        src, tgt, summary = record
        print(f"source:{src}")
        print(f"target:{tgt}")
        print(f"summary:{summary}")
        print("-----")