add guesslang tf model
This commit is contained in:
parent
1ca24c58e5
commit
4ee081df66
|
@ -1 +1,2 @@
|
||||||
/target
|
/target
|
||||||
|
venv
|
File diff suppressed because it is too large
Load Diff
|
@ -10,6 +10,7 @@ actix-web = "4"
|
||||||
silicon = { git = "https://github.com/watzon/silicon.git" }
|
silicon = { git = "https://github.com/watzon/silicon.git" }
|
||||||
lazy_static = "1.4.0"
|
lazy_static = "1.4.0"
|
||||||
serde = { version = "1.0.130", features = ["derive"] }
|
serde = { version = "1.0.130", features = ["derive"] }
|
||||||
|
serde_json = "1.0.107"
|
||||||
structopt = "0.3.26"
|
structopt = "0.3.26"
|
||||||
image = "0.24.7"
|
image = "0.24.7"
|
||||||
anyhow = "1.0.75"
|
anyhow = "1.0.75"
|
||||||
|
@ -19,3 +20,5 @@ font-kit = "0.11.0"
|
||||||
reqwest = "0.11.22"
|
reqwest = "0.11.22"
|
||||||
hyperpolyglot = "0.1.7"
|
hyperpolyglot = "0.1.7"
|
||||||
tempfile = "3.8.0"
|
tempfile = "3.8.0"
|
||||||
|
tensorflow = "0.17.0"
|
||||||
|
clap = { version = "4.4.7", features = ["derive"] }
|
|
@ -1,7 +1,44 @@
|
||||||
|
FROM debian:buster-slim as tensorflow
|
||||||
|
|
||||||
|
WORKDIR /usr/src/build
|
||||||
|
|
||||||
|
# Install dependencies
|
||||||
|
RUN apt-get update && apt-get install -y \
|
||||||
|
git \
|
||||||
|
wget \
|
||||||
|
gnupg \
|
||||||
|
python3 \
|
||||||
|
python3-dev \
|
||||||
|
python3-pip \
|
||||||
|
python3-numpy \
|
||||||
|
llvm \
|
||||||
|
clang
|
||||||
|
|
||||||
|
RUN pip3 install wheel packaging requests opt_einsum
|
||||||
|
RUN pip3 install keras_preprocessing --no-deps
|
||||||
|
|
||||||
|
# Install bazel
|
||||||
|
RUN wget https://github.com/bazelbuild/bazelisk/releases/download/v1.18.0/bazelisk-linux-amd64
|
||||||
|
RUN chmod +x bazelisk-linux-amd64
|
||||||
|
RUN mv bazelisk-linux-amd64 /usr/local/bin/bazel
|
||||||
|
|
||||||
|
# Install tensorflow
|
||||||
|
RUN git clone https://github.com/tensorflow/tensorflow \
|
||||||
|
&& cd tensorflow \
|
||||||
|
&& git checkout v2.5.0
|
||||||
|
RUN cd tensorflow && ./configure
|
||||||
|
RUN cd tensorflow && bazel build --compilation_mode=opt --copt=-march=native --jobs=12 tensorflow:libtensorflow.so
|
||||||
|
|
||||||
FROM rust:1.73.0-buster as builder
|
FROM rust:1.73.0-buster as builder
|
||||||
|
|
||||||
WORKDIR /usr/src/app
|
WORKDIR /usr/src/app
|
||||||
|
|
||||||
|
# Copy tensorflow shared libraries from tensorflow image
|
||||||
|
COPY --from=tensorflow /usr/src/build/tensorflow/bazel-bin/tensorflow/libtensorflow.so* /usr/local/lib/
|
||||||
|
COPY --from=tensorflow /usr/src/build/tensorflow/bazel-bin/tensorflow/libtensorflow_framework.so* /usr/local/lib/
|
||||||
|
|
||||||
|
RUN ldconfig
|
||||||
|
|
||||||
# Install dependencies
|
# Install dependencies
|
||||||
RUN apt-get update && apt-get install -y \
|
RUN apt-get update && apt-get install -y \
|
||||||
libssl-dev \
|
libssl-dev \
|
||||||
|
@ -30,6 +67,8 @@ RUN bash ./download_nerd_fonts.sh
|
||||||
|
|
||||||
FROM debian:buster-slim
|
FROM debian:buster-slim
|
||||||
|
|
||||||
|
WORKDIR /usr/src/app
|
||||||
|
|
||||||
# Install dependencies
|
# Install dependencies
|
||||||
RUN apt-get update && apt-get install -y \
|
RUN apt-get update && apt-get install -y \
|
||||||
libssl-dev \
|
libssl-dev \
|
||||||
|
@ -41,8 +80,15 @@ RUN apt-get update && apt-get install -y \
|
||||||
COPY --from=fonts /data/fonts/nerd_fonts/* /usr/share/fonts/truetype/
|
COPY --from=fonts /data/fonts/nerd_fonts/* /usr/share/fonts/truetype/
|
||||||
RUN fc-cache -fv
|
RUN fc-cache -fv
|
||||||
|
|
||||||
# Copy binary
|
# Copy binary and tensorflow model files
|
||||||
COPY --from=builder /usr/src/app/target/release/inkify /usr/local/bin/inkify
|
COPY --from=builder /usr/src/app/target/release/inkify /usr/src/app/
|
||||||
|
COPY --from=builder /usr/src/app/src/tensorflow /usr/src/app/tensorflow/
|
||||||
|
|
||||||
|
# Copy tensorflow shared libraries from tensorflow image
|
||||||
|
COPY --from=tensorflow /usr/src/build/tensorflow/bazel-bin/tensorflow/libtensorflow.so* /usr/local/lib/
|
||||||
|
COPY --from=tensorflow /usr/src/build/tensorflow/bazel-bin/tensorflow/libtensorflow_framework.so* /usr/local/lib/
|
||||||
|
|
||||||
|
RUN ldconfig
|
||||||
|
|
||||||
ARG PORT=8080
|
ARG PORT=8080
|
||||||
ARG HOST=0.0.0.0
|
ARG HOST=0.0.0.0
|
||||||
|
@ -53,4 +99,4 @@ ENV HOST=$HOST
|
||||||
EXPOSE $PORT
|
EXPOSE $PORT
|
||||||
|
|
||||||
# Run
|
# Run
|
||||||
ENTRYPOINT ["/usr/local/bin/inkify"]
|
CMD ["/usr/src/app/inkify", "--tensorflow-model-dir", "/usr/src/app/tensorflow"]
|
|
@ -1,17 +1,17 @@
|
||||||
use anyhow::Error;
|
use anyhow::Error;
|
||||||
use silicon::formatter::{ImageFormatter, ImageFormatterBuilder};
|
use silicon::formatter::{ImageFormatter, ImageFormatterBuilder};
|
||||||
use silicon::utils::{Background, ShadowAdder};
|
use silicon::utils::{Background, ShadowAdder};
|
||||||
use std::io::Write;
|
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use syntect::highlighting::{Theme, ThemeSet};
|
use syntect::highlighting::{Theme, ThemeSet};
|
||||||
use syntect::parsing::{SyntaxReference, SyntaxSet};
|
use syntect::parsing::{SyntaxReference, SyntaxSet};
|
||||||
|
use tensorflow::{Graph, SavedModelBundle, SessionOptions, Tensor};
|
||||||
|
|
||||||
use crate::rgba::{ImageRgba, Rgba};
|
use crate::rgba::{ImageRgba, Rgba};
|
||||||
|
|
||||||
type FontList = Vec<(String, f32)>;
|
type FontList = Vec<(String, f32)>;
|
||||||
type Lines = Vec<u32>;
|
type Lines = Vec<u32>;
|
||||||
|
|
||||||
#[derive(Debug, Clone, serde::Deserialize)]
|
#[derive(Debug, serde::Deserialize)]
|
||||||
pub struct Config {
|
pub struct Config {
|
||||||
/// Background image URL
|
/// Background image URL
|
||||||
pub background_image: Option<Vec<u8>>,
|
pub background_image: Option<Vec<u8>>,
|
||||||
|
@ -72,6 +72,12 @@ pub struct Config {
|
||||||
|
|
||||||
/// The syntax highlight theme. It can be a theme name or path to a .tmTheme file.
|
/// The syntax highlight theme. It can be a theme name or path to a .tmTheme file.
|
||||||
pub theme: String,
|
pub theme: String,
|
||||||
|
|
||||||
|
#[serde(skip_deserializing)]
|
||||||
|
pub tf_model_graph: Option<Graph>,
|
||||||
|
|
||||||
|
#[serde(skip_deserializing)]
|
||||||
|
pub tf_model: Option<SavedModelBundle>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Config {
|
impl Config {
|
||||||
|
@ -97,9 +103,25 @@ impl Config {
|
||||||
shadow_offset_x: 0,
|
shadow_offset_x: 0,
|
||||||
tab_width: 4,
|
tab_width: 4,
|
||||||
theme: "Dracula".to_owned(),
|
theme: "Dracula".to_owned(),
|
||||||
|
tf_model_graph: None,
|
||||||
|
tf_model: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn load_tensorflow_model(&mut self, export_dir: &str) {
|
||||||
|
let mut graph = Graph::new();
|
||||||
|
let model = match SavedModelBundle::load(&SessionOptions::new(), &["serve"], &mut graph, export_dir) {
|
||||||
|
Ok(model) => model,
|
||||||
|
Err(e) => {
|
||||||
|
eprintln!("Failed to load TensorFlow model: {}", e);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
self.tf_model = Some(model);
|
||||||
|
self.tf_model_graph = Some(graph);
|
||||||
|
}
|
||||||
|
|
||||||
pub fn language<'a>(&self, ps: &'a SyntaxSet) -> Result<&'a SyntaxReference, Error> {
|
pub fn language<'a>(&self, ps: &'a SyntaxSet) -> Result<&'a SyntaxReference, Error> {
|
||||||
let language = match &self.language {
|
let language = match &self.language {
|
||||||
Some(language) => ps
|
Some(language) => ps
|
||||||
|
@ -108,21 +130,52 @@ impl Config {
|
||||||
None => {
|
None => {
|
||||||
let first_line = self.code.lines().next().unwrap_or_default();
|
let first_line = self.code.lines().next().unwrap_or_default();
|
||||||
ps.find_syntax_by_first_line(first_line).unwrap_or_else(|| {
|
ps.find_syntax_by_first_line(first_line).unwrap_or_else(|| {
|
||||||
// hyperpolyglot requires a file, so we need to create a temp file
|
// Try using tensorflow to detect the language
|
||||||
let mut temp_file = tempfile::NamedTempFile::new().unwrap();
|
let input_data = Tensor::new(&[1]).with_values(&[self.code.clone()]).unwrap();
|
||||||
write!(temp_file, "{}", self.code).unwrap();
|
self.predict_language_with_tensorflow(ps, input_data)
|
||||||
let language = hyperpolyglot::detect(temp_file.path()).unwrap();
|
.unwrap_or_else(|_| ps.find_syntax_by_token("log").unwrap())
|
||||||
match language {
|
|
||||||
Some(language) => ps.find_syntax_by_token(language.language()).unwrap(),
|
|
||||||
None => ps.find_syntax_by_token("log").unwrap(),
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
Ok(language)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn predict_language_with_tensorflow<'a>(&self, ps: &'a SyntaxSet, input_data: Tensor<String>) -> Result<&'a SyntaxReference, Error> {
|
||||||
|
if self.tf_model_graph.is_none() || self.tf_model.is_none() {
|
||||||
|
return Err(Error::msg("TensorFlow model not loaded"));
|
||||||
|
}
|
||||||
|
|
||||||
|
let graph = self.tf_model_graph.as_ref().unwrap();
|
||||||
|
let model = self.tf_model.as_ref().unwrap();
|
||||||
|
let mut args = tensorflow::SessionRunArgs::new();
|
||||||
|
|
||||||
|
let input_tensor = graph.operation_by_name_required("Placeholder")?;
|
||||||
|
|
||||||
|
let output_tensor_scores = graph.operation_by_name_required("head/predictions/probabilities")?;
|
||||||
|
|
||||||
|
let output_tensor_classes = graph.operation_by_name_required("head/Tile")?;
|
||||||
|
|
||||||
|
args.add_feed(&input_tensor, 0, &input_data);
|
||||||
|
let output_token_scores = args.request_fetch(&output_tensor_scores, 0);
|
||||||
|
let output_token_classes = args.request_fetch(&output_tensor_classes, 0);
|
||||||
|
|
||||||
|
model.session.run(&mut args)?;
|
||||||
|
|
||||||
|
let scores: Tensor<f32> = args.fetch(output_token_scores)?;
|
||||||
|
|
||||||
|
let classes: Tensor<String> = args.fetch(output_token_classes)?;
|
||||||
|
|
||||||
|
// Find the index of the highest score
|
||||||
|
let max_index = scores.iter().enumerate().max_by(|a, b| a.1.partial_cmp(b.1).unwrap()).unwrap().0;
|
||||||
|
|
||||||
|
|
||||||
|
let language = &classes[max_index];
|
||||||
|
let language = ps.find_syntax_by_token(language).unwrap();
|
||||||
|
|
||||||
Ok(language)
|
Ok(language)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
pub fn theme(&self, ts: &ThemeSet) -> Result<Theme, Error> {
|
pub fn theme(&self, ts: &ThemeSet) -> Result<Theme, Error> {
|
||||||
if let Some(theme) = ts.themes.get(&self.theme) {
|
if let Some(theme) = ts.themes.get(&self.theme) {
|
||||||
Ok(theme.clone())
|
Ok(theme.clone())
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
#[macro_use]
|
#[macro_use]
|
||||||
extern crate anyhow;
|
extern crate anyhow;
|
||||||
|
|
||||||
|
use clap::Parser;
|
||||||
use actix_web::{get, web, App, HttpResponse, HttpServer, Responder};
|
use actix_web::{get, web, App, HttpResponse, HttpServer, Responder};
|
||||||
use anyhow::Error;
|
use anyhow::Error;
|
||||||
use lazy_static::lazy_static;
|
use lazy_static::lazy_static;
|
||||||
|
@ -28,6 +29,13 @@ macro_rules! unwrap_or_return {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
#[command(author, version, about, long_about = None)]
|
||||||
|
struct CliArgs {
|
||||||
|
#[arg(short, long)]
|
||||||
|
tensorflow_model_dir: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
fn parse_font_str(s: &str) -> Vec<(String, f32)> {
|
fn parse_font_str(s: &str) -> Vec<(String, f32)> {
|
||||||
let mut result = vec![];
|
let mut result = vec![];
|
||||||
for font in s.split(';') {
|
for font in s.split(';') {
|
||||||
|
@ -144,6 +152,7 @@ async fn fonts() -> impl Responder {
|
||||||
|
|
||||||
#[get("/generate")]
|
#[get("/generate")]
|
||||||
async fn generate(info: web::Query<config::ConfigQuery>) -> impl Responder {
|
async fn generate(info: web::Query<config::ConfigQuery>) -> impl Responder {
|
||||||
|
let args = CliArgs::parse();
|
||||||
let ha = &*HIGHLIGHTING_ASSETS;
|
let ha = &*HIGHLIGHTING_ASSETS;
|
||||||
|
|
||||||
let (ps, ts) = (&ha.syntax_set, &ha.theme_set);
|
let (ps, ts) = (&ha.syntax_set, &ha.theme_set);
|
||||||
|
@ -156,6 +165,10 @@ async fn generate(info: web::Query<config::ConfigQuery>) -> impl Responder {
|
||||||
.body(r#"{"error": "code parameter is required"}"#);
|
.body(r#"{"error": "code parameter is required"}"#);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if args.tensorflow_model_dir.is_some() {
|
||||||
|
conf.load_tensorflow_model(args.tensorflow_model_dir.unwrap().as_str());
|
||||||
|
}
|
||||||
|
|
||||||
conf.language = info.language.clone();
|
conf.language = info.language.clone();
|
||||||
if let Some(theme) = info.theme.clone() {
|
if let Some(theme) = info.theme.clone() {
|
||||||
conf.theme = theme;
|
conf.theme = theme;
|
||||||
|
@ -227,6 +240,7 @@ async fn generate(info: web::Query<config::ConfigQuery>) -> impl Responder {
|
||||||
.append_header(("Content-Type", "application/json"))
|
.append_header(("Content-Type", "application/json"))
|
||||||
.body(r#"{"error": "Unable to determine language, please provide one explicitly"}"#)
|
.body(r#"{"error": "Unable to determine language, please provide one explicitly"}"#)
|
||||||
);
|
);
|
||||||
|
|
||||||
let theme = unwrap_or_return!(
|
let theme = unwrap_or_return!(
|
||||||
conf.theme(ts),
|
conf.theme(ts),
|
||||||
HttpResponse::BadRequest()
|
HttpResponse::BadRequest()
|
||||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading…
Reference in New Issue