Java教程

pytorh dataloader 迭代类型数据链式处理分析

本文主要是介绍pytorh dataloader 迭代类型数据链式处理分析,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!

https://github.com/wenet-e2e/wenet wenet官方代码,在最新的UIO模式中加入链式处理数据

import time
import random

class Process():
    def __init__(self ,data ,f):
        self.data = data
        self.f = f
    def __iter__(self):
        return self.f(iter(self.data))
# data = [[j + str(i) for i in range(10)] for j in ['a','b', 'c'] ]
data = ['a0', 'a1', 'a2', 'a3', 'a4', 'a5', 'a6', 'a7', 'a8', 'a9','b0', 'b1', 'b2', 'b3', 'b4', 'b5', 'b6', 'b7', 'b8', 'b9','c0', 'c1', 'c2', 'c3', 'c4', 'c5', 'c6', 'c7', 'c8', 'c9']

def travel(d):
    for i in d:
        yield i

def shuffle(d , sf_size=15):
    buf = []
    for i in d:
        buf.append(i)
        if len(buf) >= sf_size:
            random.shuffle(buf)
            for j in buf:
#                 print('shuffle',j)
                yield j
            buf = []
    for k in buf :
        yield k

def sort(d):
    buf = []
    for i in d:
        buf.append(i)
        if len(buf) >= 5:
            for i in buf:
#                 print('sort' , i )
                yield i 
            buf = []
    for k in buf:
        yield k

def batch(d):
    buf = []
    for i in d:
        buf.append(i)
        if len(buf) >= 4:
            for i in buf:
#                 print('batch' , i )
                yield i 
            buf = []
            
p = Process(data , travel)
p = Process(p , shuffle)
# p = Process(p , sort)
p = Process(p , batch)

for i in p:
    print(i , 'train')
这篇关于pytorh dataloader 迭代类型数据链式处理分析的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!