Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others


python - How can I implement the make_one_shot_iterator() function of TensorFlow 1.x in TensorFlow 2.x?

The first sample in this tutorial uses TensorFlow 1.x. I want to implement that code in TensorFlow 2.x.

There are two parts to the code:

import tensorflow as tf
from time import sleep
from time import time

# data generator
def py_gen(gen_name):
    gen_name = gen_name.decode('utf-8')
    for num in range(20):
        sleep(0.3)
        yield '{} yields {}'.format(gen_name, num)


# model operation
def model(data):
    sleep(0.1)

and

Dataset = tf.data.Dataset
name = 'Gen_0'
ds = Dataset.from_generator(py_gen,
                            output_types=(tf.string),
                            args=(name,))
data_tf = ds.make_one_shot_iterator().get_next()

and the run is:

def run_session(data_tf):
    with tf.Session() as sess:
        while True:
            try:
                t1 = time()
                data_py = sess.run(data_tf)
                t2 = time()
                t = t2 - t1
                model(data_py)
                msg = 'elapsed time: {:.3f}, {}'.format(t, data_py.decode('utf-8'))
                print(msg)
            except tf.errors.OutOfRangeError:
                print('data generator(s) are exhausted')
                break

The make_one_shot_iterator() method is not available on datasets in TensorFlow 2.x, but there is tf.compat.v1.data.make_one_shot_iterator, which provides the same functionality.

However, I want to implement this with native TF 2.x only, without using tf.compat.v1.data.make_one_shot_iterator.

How can I do this?



1 Reply


If you are planning to iterate over a dataset, you can just do:

iterator = iter(ds)
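
As a minimal sketch of what this looks like in TF 2.x eager mode: `iter(ds)` returns an ordinary Python iterator, so `next()` or a plain `for` loop takes the place of `make_one_shot_iterator().get_next()`. The `count_gen` generator and the variable names below are made up for illustration, and `output_signature` assumes TF 2.4 or newer.

```python
import tensorflow as tf

# Hypothetical toy generator, used only for this illustration.
def count_gen():
    for num in range(3):
        yield 'item {}'.format(num)

ds = tf.data.Dataset.from_generator(
    count_gen,
    output_signature=tf.TensorSpec(shape=(), dtype=tf.string))

# iter() gives a one-shot style iterator over the dataset.
it = iter(ds)

# next() returns an eager scalar tf.string tensor; .numpy() yields bytes.
first = next(it)
first_text = first.numpy().decode('utf-8')

# Looping over the iterator consumes the remaining elements and stops
# on its own when the generator is exhausted.
rest = [elem.numpy().decode('utf-8') for elem in it]

print(first_text, rest)
```

No session is needed here, because each element is already a concrete value when it is returned.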

I have modified your code to do that. FYI, I removed the "decode" call from the printed message because I did not understand what it was supposed to do.

import tensorflow as tf
from time import sleep
from time import time

# data generator
def py_gen(gen_name):
    gen_name = gen_name.decode('utf-8')
    for num in range(20):
        sleep(0.3)
        yield '{} yields {}'.format(gen_name, num)


# model operation
def model(data):
    sleep(0.1)


def run_session(d):
    while True:
        try:
            t1 = time()
            # get_next() raises tf.errors.OutOfRangeError once the iterator is exhausted
            data_py = d.get_next()
            t2 = time()
            t = t2 - t1
            model(data_py)
            msg = 'elapsed time: {:.3f}'.format(t)
            print(msg)
        except tf.errors.OutOfRangeError:
            print('data generator(s) are exhausted')
            break

Dataset = tf.data.Dataset
name = 'Gen_0'
ds = Dataset.from_generator(py_gen,
                            output_types=(tf.string),
                            args=(name,))

it = iter(ds)
run_session(it)
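
The code above prints only the elapsed time. If the yielded string is also wanted in the message, as in the question, one possible sketch is below. It assumes eager execution, where each element is a `tf.Tensor` whose `.numpy()` value is a `bytes` object that can be decoded. The `run_loop` name and the shortened 3-item generator are illustrative, not part of the original code. It also uses `next()`, which signals exhaustion with `StopIteration` rather than `tf.errors.OutOfRangeError`.

```python
import tensorflow as tf
from time import time

# Shortened version of the generator from the question (3 items, no sleep).
def py_gen(gen_name):
    gen_name = gen_name.decode('utf-8')  # string args arrive as bytes
    for num in range(3):
        yield '{} yields {}'.format(gen_name, num)

ds = tf.data.Dataset.from_generator(py_gen,
                                    output_types=tf.string,
                                    args=('Gen_0',))

def run_loop(dataset):
    """Fetch every element, timing each fetch, and return the messages."""
    messages = []
    it = iter(dataset)
    while True:
        t1 = time()
        try:
            data = next(it)  # raises StopIteration when exhausted
        except StopIteration:
            print('data generator(s) are exhausted')
            break
        t = time() - t1
        # Eager tensor -> bytes -> str
        text = data.numpy().decode('utf-8')
        msg = 'elapsed time: {:.3f}, {}'.format(t, text)
        print(msg)
        messages.append(msg)
    return messages

messages = run_loop(ds)
```

Keeping `d.get_next()` with `except tf.errors.OutOfRangeError`, as in the answer's code, should behave the same way. The `next()` form just follows the usual Python iterator protocol.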
