使用Go與Tensorflow做影像辨識

程式解說

  • /api/main.go 宣告部分
// Package-level state shared by the functions below: loadModel assigns both
// variables, and the request handling code reads them afterwards.

// graph is the TensorFlow graph.
var graph *tf.Graph

// labels holds the label strings that accompany the trained model.
var labels []string
  • 撰寫載入程式(載入模型與標籤檔)
// loadModel reads the serialized inception model and its label file from
// /model and publishes them through the package-level graph and labels
// variables.
//
// Everything is loaded into local variables first and the globals are only
// assigned once every step has succeeded. A failed call therefore leaves the
// previous state untouched, and calling loadModel again replaces the labels
// instead of appending a second copy.
//
// It returns the first error met while reading the model, importing it into
// the graph, opening the label file, or scanning it.
func loadModel() error {
	// Load inception model.
	model, err := ioutil.ReadFile("/model/tensorflow_inception_graph.pb")
	if err != nil {
		return err
	}
	loadedGraph := tf.NewGraph()
	if err := loadedGraph.Import(model, ""); err != nil {
		return err
	}

	// Load labels.
	labelsFile, err := os.Open("/model/imagenet_comp_graph_label_strings.txt")
	if err != nil {
		return err
	}
	defer labelsFile.Close()

	// Labels are separated by newlines, one label per line.
	var loadedLabels []string
	scanner := bufio.NewScanner(labelsFile)
	for scanner.Scan() {
		loadedLabels = append(loadedLabels, scanner.Text())
	}
	if err := scanner.Err(); err != nil {
		return err
	}

	// Commit only after every step succeeded.
	graph, labels = loadedGraph, loadedLabels
	return nil
}
  • 回傳資訊回client端
// responseError sends an error reply to the client. It sets the Content-Type
// header to application/json, writes the given HTTP status code, and encodes
// a JSON object of the form {"error": message} as the body.
//
// A failure to write the body cannot be reported to the client any more
// (the status line is already sent), so it is logged instead of being
// silently dropped.
func responseError(w http.ResponseWriter, message string, code int) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(code)
	if err := json.NewEncoder(w).Encode(map[string]string{"error": message}); err != nil {
		log.Printf("responseError: writing response body: %v", err)
	}
}

// responseJSON sends a successful reply to the client. It sets the
// Content-Type header to application/json and encodes data as the JSON body;
// no status code is written explicitly, so net/http's default applies.
//
// If data cannot be encoded, or the body cannot be written, the error is
// logged rather than silently ignored.
func responseJSON(w http.ResponseWriter, data interface{}) {
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(data); err != nil {
		log.Printf("responseJSON: writing response body: %v", err)
	}
}
  • 主程式 main.go
// main is the program entry point. It loads the model and labels once at
// start-up, registers the POST /recognize route, and serves HTTP on :8080.
func main() {
	// log.Fatal terminates the process, so nothing may follow it in this
	// branch (the original trailing return was unreachable).
	if err := loadModel(); err != nil {
		log.Fatal(err)
	}

	r := httprouter.New()
	r.POST("/recognize", recognizeHandler) // image recognition endpoint
	// ListenAndServe only returns on failure; report that error and exit.
	log.Fatal(http.ListenAndServe(":8080", r))
}
  • 收到POST來的檔案後,開始辨識(recognizeHandler)
func recognizeHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
    // Read image
    imageFile, header, err := r.FormFile("image")
    // Will contain filename and extension
    imageName := strings.Split(header.Filename, ".")
    if err != nil {
        responseError(w, "Could not read image", http.StatusBadRequest)
        return
    }
    defer imageFile.Close()
    var imageBuffer bytes.Buffer
    // Copy image data to a buffer
    io.Copy(&imageBuffer, imageFile)

  tensor, err := makeTensorFromImage(&imageBuffer, imageName[:1][0])
  if err != nil {
     responseError(w, "Invalid image", http.StatusBadRequest)
     return
  }

  session, err := tf.NewSession(graph, nil)
  if err != nil {
     log.Fatal(err)
  }
  defer session.Close()
  // output[0].Value() 會包含可能的標籤及機率
  output, err := session.Run(
     map[tf.Output]*tf.Tensor{
        graph.Operation("input").Output(0): tensor,
     },
     []tf.Output{
        graph.Operation("output").Output(0),
     }, nil
   )
   if err != nil {
     responseError(w, "Could not run inference", http.StatusInternalServerError)
     return
  }
}
  • makeTensorFromImage:將上傳的影像轉換為 tensor
// makeTensorFromImage turns the raw image bytes held in imageBuffer into a
// tensor. The bytes are first wrapped in a string tensor, which is then fed
// through the graph built by makeTransformImageGraph for imageFormat; the
// first tensor produced by that run is returned.
//
// Any error from creating the tensor, building the graph, opening the
// session or running it is returned unchanged together with a nil tensor.
func makeTensorFromImage(imageBuffer *bytes.Buffer, imageFormat string) (*tf.Tensor, error) {
	raw, err := tf.NewTensor(imageBuffer.String())
	if err != nil {
		return nil, err
	}

	transformGraph, in, out, err := makeTransformImageGraph(imageFormat)
	if err != nil {
		return nil, err
	}

	sess, err := tf.NewSession(transformGraph, nil)
	if err != nil {
		return nil, err
	}
	defer sess.Close()

	feeds := map[tf.Output]*tf.Tensor{in: raw}
	fetches := []tf.Output{out}
	results, err := sess.Run(feeds, fetches, nil)
	if err != nil {
		return nil, err
	}
	return results[0], nil
}
  • makeTransformImageGraph:將影像縮放為 224×224 並正規化像素
// makeTransformImageGraph builds a graph that decodes an encoded image and
// prepares it as model input. input is a string placeholder for the encoded
// image bytes; output is the result of the final Div op.
//
// The steps are: decode as PNG when imageFormat is exactly "png" (JPEG for
// every other value) with 3 channels, cast to float, add a batch dimension
// at axis 0, resize bilinearly to 224x224, then compute
// (value - mean) / scale for each pixel.
//
// The returned err is whatever Finalize reports for the scope.
func makeTransformImageGraph(imageFormat string) (graph *tf.Graph, input, output tf.Output, err error) {
	const (
		height, width = 224, 224
		mean          = float32(117)
		scale         = float32(1)
	)

	s := op.NewScope()
	input = op.Placeholder(s, tf.String)

	// Decode PNG or JPEG.
	var pixels tf.Output
	switch imageFormat {
	case "png":
		pixels = op.DecodePng(s, input, op.DecodePngChannels(3))
	default:
		pixels = op.DecodeJpeg(s, input, op.DecodeJpegChannels(3))
	}

	// Use the decoded pixel values as floats.
	asFloat := op.Cast(s, pixels, tf.Float)
	// Create a batch containing a single image.
	batch := op.ExpandDims(s, asFloat, op.Const(s.SubScope("make_batch"), int32(0)))
	// Resize to 224x224 with bilinear interpolation.
	resized := op.ResizeBilinear(s, batch, op.Const(s.SubScope("size"), []int32{height, width}))
	// Sub and Div perform (value-mean)/scale for each pixel.
	centered := op.Sub(s, resized, op.Const(s.SubScope("mean"), mean))
	output = op.Div(s, centered, op.Const(s.SubScope("scale"), scale))

	graph, err = s.Finalize()
	return graph, input, output, err
}
  • 回到 main.go,宣告輸出結構
type (
	// ClassifyResult is the JSON reply for one classified image: the name
	// of the file and the labels reported for it.
	ClassifyResult struct {
		Filename string        `json:"filename"`
		Labels   []LabelResult `json:"labels"`
	}

	// LabelResult pairs one label with its probability.
	LabelResult struct {
		Label       string  `json:"label"`
		Probability float32 `json:"probability"`
	}
)
  • findBestLabels 找出機率最高的標籤
// ByProbability implements sort.Interface for []LabelResult, ordering the
// elements by descending Probability.
type ByProbability []LabelResult

func (a ByProbability) Len() int           { return len(a) }
func (a ByProbability) Swap(i, j int)      { a[i], a[j] = a[j], a[i] }
func (a ByProbability) Less(i, j int) bool { return a[i].Probability > a[j].Probability }

// findBestLabels pairs each probability with the label at the same index in
// the package-level labels slice and returns the pairs with the highest
// probabilities, best first.
//
// At most five results are returned. Probabilities without a matching label
// are ignored, and when fewer than five pairs exist all of them are returned
// (the result is empty, never nil, for empty input) instead of panicking on
// an out-of-range slice.
func findBestLabels(probabilities []float32) []LabelResult {
	const maxResults = 5

	// Only indices that have both a probability and a label can be paired.
	n := len(probabilities)
	if n > len(labels) {
		n = len(labels)
	}

	// Make a list of label/probability pairs.
	resultLabels := make([]LabelResult, 0, n)
	for i := 0; i < n; i++ {
		resultLabels = append(resultLabels, LabelResult{Label: labels[i], Probability: probabilities[i]})
	}

	// Sort by probability, highest first.
	sort.Sort(ByProbability(resultLabels))

	// Keep the top entries, guarding against short inputs.
	if len(resultLabels) > maxResults {
		resultLabels = resultLabels[:maxResults]
	}
	return resultLabels
}
  • 執行 docker
docker-compose -f docker-compose.yaml up -d --build
  • 執行辨識
curl localhost:8080/recognize -F 'image=@./cat.jpg'
  • 執行結果
{
  "filename": "cat.jpg",
  "labels": [
    { "label": "Egyptian cat", "probability": 0.39229771 },
    { "label": "weasel", "probability": 0.19872947 },
    { "label": "Arctic fox", "probability": 0.14527217 },
    { "label": "tabby", "probability": 0.062454574 },
    { "label": "kit fox", "probability": 0.043656528 }
  ]
}

相關工具

參考工具

本篇文章原始碼